nautilus_okx/websocket/
subscription.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Subscription tracking helpers for the OKX WebSocket client.
17
18use std::sync::Arc;
19
20use ahash::AHashSet;
21use dashmap::DashMap;
22use ustr::Ustr;
23
24use crate::{
25    common::enums::OKXInstrumentType,
26    websocket::{
27        enums::OKXWsChannel,
28        messages::{OKXSubscriptionArg, OKXWebSocketArg},
29    },
30};
31
32fn topic_from_parts(
33    channel: &OKXWsChannel,
34    inst_id: Option<&Ustr>,
35    inst_family: Option<&Ustr>,
36    inst_type: Option<&OKXInstrumentType>,
37    bar: Option<&Ustr>,
38) -> String {
39    let base = channel.as_ref();
40
41    if let Some(inst_id) = inst_id {
42        let inst_id = inst_id.as_str();
43        if let Some(bar) = bar {
44            format!("{base}:{inst_id}:{}", bar.as_str())
45        } else {
46            format!("{base}:{inst_id}")
47        }
48    } else if let Some(inst_family) = inst_family {
49        format!("{base}:{}", inst_family.as_str())
50    } else if let Some(inst_type) = inst_type {
51        format!("{base}:{}", inst_type.as_ref())
52    } else {
53        base.to_string()
54    }
55}
56
57pub(crate) fn topic_from_subscription_arg(arg: &OKXSubscriptionArg) -> String {
58    topic_from_parts(
59        &arg.channel,
60        arg.inst_id.as_ref(),
61        arg.inst_family.as_ref(),
62        arg.inst_type.as_ref(),
63        None,
64    )
65}
66
67pub(crate) fn topic_from_websocket_arg(arg: &OKXWebSocketArg) -> String {
68    topic_from_parts(
69        &arg.channel,
70        arg.inst_id.as_ref(),
71        arg.inst_family.as_ref(),
72        arg.inst_type.as_ref(),
73        arg.bar.as_ref(),
74    )
75}
76
/// Splits a topic at its first `:` into `(channel, symbol)`.
///
/// Everything after the first colon (which may itself contain colons, e.g. `inst_id:bar`)
/// is returned as the symbol. A topic without a colon yields `(topic, None)`.
pub(crate) fn split_topic(topic: &str) -> (&str, Option<&str>) {
    match topic.split_once(':') {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
82
83pub(crate) fn track_topic(
84    map: &DashMap<String, AHashSet<Ustr>>,
85    channel: &str,
86    symbol: Option<&str>,
87) {
88    if let Some(symbol) = symbol {
89        let mut entry = map.entry(channel.to_string()).or_default();
90        entry.insert(Ustr::from(symbol));
91    } else {
92        map.entry(channel.to_string()).or_default();
93    }
94}
95
96pub(crate) fn untrack_topic(
97    map: &DashMap<String, AHashSet<Ustr>>,
98    channel: &str,
99    symbol: Option<&str>,
100) {
101    if let Some(symbol) = symbol {
102        let symbol_ustr = Ustr::from(symbol);
103        let mut remove_channel = false;
104        if let Some(mut entry) = map.get_mut(channel) {
105            entry.remove(&symbol_ustr);
106            remove_channel = entry.is_empty();
107        }
108        if remove_channel {
109            map.remove(channel);
110        }
111    } else {
112        map.remove(channel);
113    }
114}
115
116#[derive(Clone, Debug)]
117pub(crate) struct SubscriptionState {
118    confirmed: Arc<DashMap<String, AHashSet<Ustr>>>,
119    pending: Arc<DashMap<String, AHashSet<Ustr>>>,
120}
121
122impl SubscriptionState {
123    pub(crate) fn new() -> Self {
124        Self {
125            confirmed: Arc::new(DashMap::new()),
126            pending: Arc::new(DashMap::new()),
127        }
128    }
129
130    pub(crate) fn confirmed(&self) -> Arc<DashMap<String, AHashSet<Ustr>>> {
131        Arc::clone(&self.confirmed)
132    }
133
134    pub(crate) fn pending(&self) -> Arc<DashMap<String, AHashSet<Ustr>>> {
135        Arc::clone(&self.pending)
136    }
137
138    pub(crate) fn len(&self) -> usize {
139        self.confirmed.len()
140    }
141
142    pub(crate) fn mark_subscribe(&self, topic: &str) {
143        let (channel, symbol) = split_topic(topic);
144        track_topic(&self.pending, channel, symbol);
145    }
146
147    pub(crate) fn mark_unsubscribe(&self, topic: &str) {
148        let (channel, symbol) = split_topic(topic);
149        track_topic(&self.pending, channel, symbol);
150        untrack_topic(&self.confirmed, channel, symbol);
151    }
152
153    pub(crate) fn confirm(&self, topic: &str) {
154        let (channel, symbol) = split_topic(topic);
155        untrack_topic(&self.pending, channel, symbol);
156        track_topic(&self.confirmed, channel, symbol);
157    }
158
159    pub(crate) fn mark_failure(&self, topic: &str) {
160        let (channel, symbol) = split_topic(topic);
161        untrack_topic(&self.confirmed, channel, symbol);
162        track_topic(&self.pending, channel, symbol);
163    }
164
165    pub(crate) fn clear_pending(&self, topic: &str) {
166        let (channel, symbol) = split_topic(topic);
167        untrack_topic(&self.pending, channel, symbol);
168    }
169}