nautilus_okx/websocket/
subscription.rs1use std::sync::Arc;
19
20use ahash::AHashSet;
21use dashmap::DashMap;
22use ustr::Ustr;
23
24use crate::{
25 common::enums::OKXInstrumentType,
26 websocket::{
27 enums::OKXWsChannel,
28 messages::{OKXSubscriptionArg, OKXWebSocketArg},
29 },
30};
31
32fn topic_from_parts(
33 channel: &OKXWsChannel,
34 inst_id: Option<&Ustr>,
35 inst_family: Option<&Ustr>,
36 inst_type: Option<&OKXInstrumentType>,
37 bar: Option<&Ustr>,
38) -> String {
39 let base = channel.as_ref();
40
41 if let Some(inst_id) = inst_id {
42 let inst_id = inst_id.as_str();
43 if let Some(bar) = bar {
44 format!("{base}:{inst_id}:{}", bar.as_str())
45 } else {
46 format!("{base}:{inst_id}")
47 }
48 } else if let Some(inst_family) = inst_family {
49 format!("{base}:{}", inst_family.as_str())
50 } else if let Some(inst_type) = inst_type {
51 format!("{base}:{}", inst_type.as_ref())
52 } else {
53 base.to_string()
54 }
55}
56
57pub(crate) fn topic_from_subscription_arg(arg: &OKXSubscriptionArg) -> String {
58 topic_from_parts(
59 &arg.channel,
60 arg.inst_id.as_ref(),
61 arg.inst_family.as_ref(),
62 arg.inst_type.as_ref(),
63 None,
64 )
65}
66
67pub(crate) fn topic_from_websocket_arg(arg: &OKXWebSocketArg) -> String {
68 topic_from_parts(
69 &arg.channel,
70 arg.inst_id.as_ref(),
71 arg.inst_family.as_ref(),
72 arg.inst_type.as_ref(),
73 arg.bar.as_ref(),
74 )
75}
76
/// Splits a topic into its channel and optional symbol at the first `:`.
///
/// Everything after the first `:` (including any further `:` segments) is
/// returned as the symbol. A topic without `:` yields `(topic, None)`.
pub(crate) fn split_topic(topic: &str) -> (&str, Option<&str>) {
    match topic.split_once(':') {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
82
83pub(crate) fn track_topic(
84 map: &DashMap<String, AHashSet<Ustr>>,
85 channel: &str,
86 symbol: Option<&str>,
87) {
88 if let Some(symbol) = symbol {
89 let mut entry = map.entry(channel.to_string()).or_default();
90 entry.insert(Ustr::from(symbol));
91 } else {
92 map.entry(channel.to_string()).or_default();
93 }
94}
95
96pub(crate) fn untrack_topic(
97 map: &DashMap<String, AHashSet<Ustr>>,
98 channel: &str,
99 symbol: Option<&str>,
100) {
101 if let Some(symbol) = symbol {
102 let symbol_ustr = Ustr::from(symbol);
103 let mut remove_channel = false;
104 if let Some(mut entry) = map.get_mut(channel) {
105 entry.remove(&symbol_ustr);
106 remove_channel = entry.is_empty();
107 }
108 if remove_channel {
109 map.remove(channel);
110 }
111 } else {
112 map.remove(channel);
113 }
114}
115
116#[derive(Clone, Debug)]
117pub(crate) struct SubscriptionState {
118 confirmed: Arc<DashMap<String, AHashSet<Ustr>>>,
119 pending: Arc<DashMap<String, AHashSet<Ustr>>>,
120}
121
122impl SubscriptionState {
123 pub(crate) fn new() -> Self {
124 Self {
125 confirmed: Arc::new(DashMap::new()),
126 pending: Arc::new(DashMap::new()),
127 }
128 }
129
130 pub(crate) fn confirmed(&self) -> Arc<DashMap<String, AHashSet<Ustr>>> {
131 Arc::clone(&self.confirmed)
132 }
133
134 pub(crate) fn pending(&self) -> Arc<DashMap<String, AHashSet<Ustr>>> {
135 Arc::clone(&self.pending)
136 }
137
138 pub(crate) fn len(&self) -> usize {
139 self.confirmed.len()
140 }
141
142 pub(crate) fn mark_subscribe(&self, topic: &str) {
143 let (channel, symbol) = split_topic(topic);
144 track_topic(&self.pending, channel, symbol);
145 }
146
147 pub(crate) fn mark_unsubscribe(&self, topic: &str) {
148 let (channel, symbol) = split_topic(topic);
149 track_topic(&self.pending, channel, symbol);
150 untrack_topic(&self.confirmed, channel, symbol);
151 }
152
153 pub(crate) fn confirm(&self, topic: &str) {
154 let (channel, symbol) = split_topic(topic);
155 untrack_topic(&self.pending, channel, symbol);
156 track_topic(&self.confirmed, channel, symbol);
157 }
158
159 pub(crate) fn mark_failure(&self, topic: &str) {
160 let (channel, symbol) = split_topic(topic);
161 untrack_topic(&self.confirmed, channel, symbol);
162 track_topic(&self.pending, channel, symbol);
163 }
164
165 pub(crate) fn clear_pending(&self, topic: &str) {
166 let (channel, symbol) = split_topic(topic);
167 untrack_topic(&self.pending, channel, symbol);
168 }
169}