1use std::{
19 fmt::Debug,
20 sync::{
21 Arc,
22 atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23 },
24 time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::consts::NAUTILUS_USER_AGENT;
31use nautilus_model::instruments::{Instrument, InstrumentAny};
32use nautilus_network::{
33 backoff::ExponentialBackoff,
34 mode::ConnectionMode,
35 websocket::{
36 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
37 },
38};
39use ustr::Ustr;
40
41use super::handler::{FeedHandler, HandlerCommand};
42use crate::{
43 common::enums::{AxCandleWidth, AxMarketDataLevel},
44 websocket::messages::NautilusDataWsMessage,
45};
46
/// Heartbeat interval (seconds) used when the caller does not supply one.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Separator between the channel part and the symbol part of a subscription topic.
const AX_TOPIC_DELIMITER: char = ':';

/// Result alias for AX market data WebSocket client operations.
pub type AxWsResult<T> = core::result::Result<T, AxWsClientError>;

/// Errors produced by the AX market data WebSocket client.
#[derive(Debug, Clone)]
pub enum AxWsClientError {
    /// Connection-level failure (e.g. all connect attempts failed or timed out).
    Transport(String),
    /// A command could not be delivered to the feed handler over its channel.
    ChannelError(String),
}

impl core::fmt::Display for AxWsClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Both variants share the "<Kind> error: <detail>" shape.
        let (kind, detail) = match self {
            Self::Transport(detail) => ("Transport", detail),
            Self::ChannelError(detail) => ("Channel", detail),
        };
        write!(f, "{kind} error: {detail}")
    }
}

impl std::error::Error for AxWsClientError {}
75
76#[derive(Debug, Default, Clone)]
77pub(crate) struct SymbolDataTypes {
78 pub(crate) quotes: bool,
79 pub(crate) trades: bool,
80 pub(crate) book_level: Option<AxMarketDataLevel>,
81}
82
83impl SymbolDataTypes {
84 pub(crate) fn effective_level(&self) -> Option<AxMarketDataLevel> {
85 if let Some(level) = self.book_level {
86 return Some(level);
87 }
88 if self.quotes || self.trades {
89 return Some(AxMarketDataLevel::Level1);
90 }
91 None
92 }
93
94 fn is_empty(&self) -> bool {
95 !self.quotes && !self.trades && self.book_level.is_none()
96 }
97}
98
99#[cfg_attr(
104 feature = "python",
105 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.architect")
106)]
107pub struct AxMdWebSocketClient {
108 url: String,
109 heartbeat: Option<u64>,
110 auth_token: Option<String>,
111 connection_mode: Arc<ArcSwap<AtomicU8>>,
112 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
113 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusDataWsMessage>>>,
114 signal: Arc<AtomicBool>,
115 task_handle: Option<tokio::task::JoinHandle<()>>,
116 subscriptions: SubscriptionState,
117 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
118 request_id_counter: Arc<AtomicI64>,
119 subscribe_lock: Arc<tokio::sync::Mutex<()>>,
120 symbol_data_types: Arc<DashMap<String, SymbolDataTypes>>,
121}
122
123impl Debug for AxMdWebSocketClient {
124 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
125 f.debug_struct(stringify!(AxMdWebSocketClient))
126 .field("url", &self.url)
127 .field("heartbeat", &self.heartbeat)
128 .field("confirmed_subscriptions", &self.subscriptions.len())
129 .finish()
130 }
131}
132
133impl Clone for AxMdWebSocketClient {
134 fn clone(&self) -> Self {
135 Self {
136 url: self.url.clone(),
137 heartbeat: self.heartbeat,
138 auth_token: self.auth_token.clone(),
139 connection_mode: Arc::clone(&self.connection_mode),
140 cmd_tx: Arc::clone(&self.cmd_tx),
141 out_rx: None,
142 signal: Arc::clone(&self.signal),
143 task_handle: None,
144 subscriptions: self.subscriptions.clone(),
145 subscribe_lock: Arc::clone(&self.subscribe_lock),
146 instruments_cache: Arc::clone(&self.instruments_cache),
147 request_id_counter: Arc::clone(&self.request_id_counter),
148 symbol_data_types: Arc::clone(&self.symbol_data_types),
149 }
150 }
151}
152
153impl AxMdWebSocketClient {
154 #[must_use]
158 pub fn new(url: String, auth_token: String, heartbeat: Option<u64>) -> Self {
159 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
160
161 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
162 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
163
164 Self {
165 url,
166 heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
167 auth_token: Some(auth_token),
168 connection_mode,
169 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
170 out_rx: None,
171 signal: Arc::new(AtomicBool::new(false)),
172 task_handle: None,
173 subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
174 instruments_cache: Arc::new(DashMap::new()),
175 request_id_counter: Arc::new(AtomicI64::new(1)),
176 subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
177 symbol_data_types: Arc::new(DashMap::new()),
178 }
179 }
180
181 #[must_use]
185 pub fn without_auth(url: String, heartbeat: Option<u64>) -> Self {
186 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
187
188 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
189 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
190
191 Self {
192 url,
193 heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
194 auth_token: None,
195 connection_mode,
196 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
197 out_rx: None,
198 signal: Arc::new(AtomicBool::new(false)),
199 task_handle: None,
200 subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
201 instruments_cache: Arc::new(DashMap::new()),
202 request_id_counter: Arc::new(AtomicI64::new(1)),
203 subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
204 symbol_data_types: Arc::new(DashMap::new()),
205 }
206 }
207
208 #[must_use]
210 pub fn url(&self) -> &str {
211 &self.url
212 }
213
214 pub fn set_auth_token(&mut self, token: String) {
218 self.auth_token = Some(token);
219 }
220
221 #[must_use]
223 pub fn is_active(&self) -> bool {
224 let connection_mode_arc = self.connection_mode.load();
225 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
226 && !self.signal.load(Ordering::Acquire)
227 }
228
229 #[must_use]
231 pub fn is_closed(&self) -> bool {
232 let connection_mode_arc = self.connection_mode.load();
233 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
234 || self.signal.load(Ordering::Acquire)
235 }
236
237 #[must_use]
239 pub fn subscription_count(&self) -> usize {
240 self.subscriptions.len()
241 }
242
243 fn next_request_id(&self) -> i64 {
244 self.request_id_counter.fetch_add(1, Ordering::Relaxed)
245 }
246
247 fn is_subscribed_topic(&self, topic: &str) -> bool {
248 let (channel, symbol) = topic
249 .split_once(AX_TOPIC_DELIMITER)
250 .map_or((topic, None), |(c, s)| (c, Some(s)));
251 let channel_ustr = Ustr::from(channel);
252 let symbol_ustr = symbol.map_or_else(|| Ustr::from(""), Ustr::from);
253 self.subscriptions
254 .is_subscribed(&channel_ustr, &symbol_ustr)
255 }
256
257 pub fn cache_instrument(&self, instrument: InstrumentAny) {
259 let symbol = instrument.symbol().inner();
260 self.instruments_cache.insert(symbol, instrument.clone());
261
262 if self.is_active() {
263 let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
264 let cmd_tx = self.cmd_tx.clone();
265 get_runtime().spawn(async move {
266 let guard = cmd_tx.read().await;
267 let _ = guard.send(cmd);
268 });
269 }
270 }
271
272 #[must_use]
274 pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
275 self.instruments_cache.get(symbol).map(|r| r.clone())
276 }
277
278 pub async fn connect(&mut self) -> AxWsResult<()> {
284 const MAX_RETRIES: u32 = 5;
285 const CONNECTION_TIMEOUT_SECS: u64 = 10;
286
287 self.signal.store(false, Ordering::Release);
288
289 let (raw_handler, raw_rx) = channel_message_handler();
290
291 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {});
293
294 let mut headers = vec![("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string())];
295 if let Some(ref token) = self.auth_token {
296 headers.push(("Authorization".to_string(), format!("Bearer {token}")));
297 }
298
299 let config = WebSocketConfig {
300 url: self.url.clone(),
301 headers,
302 heartbeat: self.heartbeat,
303 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
305 reconnect_delay_initial_ms: Some(500),
306 reconnect_delay_max_ms: Some(5_000),
307 reconnect_backoff_factor: Some(1.5),
308 reconnect_jitter_ms: Some(250),
309 reconnect_max_attempts: None,
310 };
311
312 let mut backoff = ExponentialBackoff::new(
314 Duration::from_millis(500),
315 Duration::from_millis(5000),
316 2.0,
317 250,
318 false,
319 )
320 .map_err(|e| AxWsClientError::Transport(e.to_string()))?;
321
322 let mut last_error: String;
323 let mut attempt = 0;
324
325 let client = loop {
326 attempt += 1;
327
328 match tokio::time::timeout(
329 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
330 WebSocketClient::connect(
331 config.clone(),
332 Some(raw_handler.clone()),
333 Some(ping_handler.clone()),
334 None,
335 vec![],
336 None,
337 ),
338 )
339 .await
340 {
341 Ok(Ok(client)) => {
342 if attempt > 1 {
343 log::info!("WebSocket connection established after {attempt} attempts");
344 }
345 break client;
346 }
347 Ok(Err(e)) => {
348 last_error = e.to_string();
349 log::warn!(
350 "WebSocket connection attempt failed: attempt={attempt}/{MAX_RETRIES}, url={}, error={last_error}",
351 self.url
352 );
353 }
354 Err(_) => {
355 last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
356 log::warn!(
357 "WebSocket connection attempt timed out: attempt={attempt}/{MAX_RETRIES}, url={}",
358 self.url
359 );
360 }
361 }
362
363 if attempt >= MAX_RETRIES {
364 return Err(AxWsClientError::Transport(format!(
365 "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
366 self.url,
367 if last_error.is_empty() {
368 "unknown error"
369 } else {
370 &last_error
371 }
372 )));
373 }
374
375 let delay = backoff.next_duration();
376 log::debug!(
377 "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
378 attempt + 1
379 );
380 tokio::time::sleep(delay).await;
381 };
382
383 self.connection_mode.store(client.connection_mode_atomic());
384
385 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusDataWsMessage>();
386 self.out_rx = Some(Arc::new(out_rx));
387
388 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
389 *self.cmd_tx.write().await = cmd_tx.clone();
390
391 self.send_cmd(HandlerCommand::SetClient(client)).await?;
392
393 if !self.instruments_cache.is_empty() {
394 let cached_instruments: Vec<InstrumentAny> = self
395 .instruments_cache
396 .iter()
397 .map(|entry| entry.value().clone())
398 .collect();
399 self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
400 .await?;
401 }
402
403 let signal = Arc::clone(&self.signal);
404 let subscriptions = self.subscriptions.clone();
405 let symbol_data_types = Arc::clone(&self.symbol_data_types);
406
407 let stream_handle = get_runtime().spawn(async move {
408 let mut handler = FeedHandler::new(
409 signal.clone(),
410 cmd_rx,
411 raw_rx,
412 out_tx.clone(),
413 subscriptions.clone(),
414 symbol_data_types,
415 );
416
417 while let Some(msg) = handler.next().await {
418 if matches!(msg, NautilusDataWsMessage::Reconnected) {
419 log::info!("WebSocket reconnected, subscriptions will be replayed");
420 }
421
422 if out_tx.send(msg).is_err() {
423 log::debug!("Output channel closed");
424 break;
425 }
426 }
427
428 log::debug!("Handler loop exited");
429 });
430
431 self.task_handle = Some(stream_handle);
432
433 Ok(())
434 }
435
436 pub async fn subscribe_book_deltas(
445 &self,
446 symbol: &str,
447 level: AxMarketDataLevel,
448 ) -> AxWsResult<()> {
449 let _guard = self.subscribe_lock.lock().await;
450
451 let entry = self
452 .symbol_data_types
453 .entry(symbol.to_string())
454 .or_default();
455
456 if entry.book_level.is_some() {
458 log::debug!("Book deltas already subscribed for {symbol}, skipping");
459 return Ok(());
460 }
461
462 let old_level = entry.effective_level();
463 let mut next = entry.clone();
464 next.book_level = Some(level);
465 let new_level = next.effective_level();
466 drop(entry);
467
468 self.update_data_subscription(symbol, old_level, new_level)
469 .await?;
470
471 self.symbol_data_types
472 .entry(symbol.to_string())
473 .or_default()
474 .book_level = Some(level);
475
476 Ok(())
477 }
478
479 pub async fn subscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
488 let _guard = self.subscribe_lock.lock().await;
489
490 let entry = self
491 .symbol_data_types
492 .entry(symbol.to_string())
493 .or_default();
494 let old_level = entry.effective_level();
495 let mut next = entry.clone();
496 next.quotes = true;
497 let new_level = next.effective_level();
498 drop(entry);
499
500 self.update_data_subscription(symbol, old_level, new_level)
501 .await?;
502
503 self.symbol_data_types
504 .entry(symbol.to_string())
505 .or_default()
506 .quotes = true;
507
508 Ok(())
509 }
510
511 pub async fn subscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
520 let _guard = self.subscribe_lock.lock().await;
521
522 let entry = self
523 .symbol_data_types
524 .entry(symbol.to_string())
525 .or_default();
526 let old_level = entry.effective_level();
527 let mut next = entry.clone();
528 next.trades = true;
529 let new_level = next.effective_level();
530 drop(entry);
531
532 self.update_data_subscription(symbol, old_level, new_level)
533 .await?;
534
535 self.symbol_data_types
536 .entry(symbol.to_string())
537 .or_default()
538 .trades = true;
539
540 Ok(())
541 }
542
543 pub async fn unsubscribe_book_deltas(&self, symbol: &str) -> AxWsResult<()> {
552 let _guard = self.subscribe_lock.lock().await;
553
554 let Some(entry) = self.symbol_data_types.get(symbol) else {
555 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe book deltas");
556 return Ok(());
557 };
558 let old_level = entry.effective_level();
559 let mut next = entry.clone();
560 next.book_level = None;
561 let new_level = next.effective_level();
562 drop(entry);
563
564 self.update_data_subscription(symbol, old_level, new_level)
565 .await?;
566
567 if let Some(mut entry) = self.symbol_data_types.get_mut(symbol) {
568 entry.book_level = None;
569 if entry.is_empty() {
570 drop(entry);
571 self.symbol_data_types.remove(symbol);
572 }
573 }
574
575 Ok(())
576 }
577
578 pub async fn unsubscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
587 let _guard = self.subscribe_lock.lock().await;
588
589 let Some(entry) = self.symbol_data_types.get(symbol) else {
590 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe quotes");
591 return Ok(());
592 };
593 let old_level = entry.effective_level();
594 let mut next = entry.clone();
595 next.quotes = false;
596 let new_level = next.effective_level();
597 drop(entry);
598
599 self.update_data_subscription(symbol, old_level, new_level)
600 .await?;
601
602 if let Some(mut entry) = self.symbol_data_types.get_mut(symbol) {
603 entry.quotes = false;
604 if entry.is_empty() {
605 drop(entry);
606 self.symbol_data_types.remove(symbol);
607 }
608 }
609
610 Ok(())
611 }
612
613 pub async fn unsubscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
622 let _guard = self.subscribe_lock.lock().await;
623
624 let Some(entry) = self.symbol_data_types.get(symbol) else {
625 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe trades");
626 return Ok(());
627 };
628 let old_level = entry.effective_level();
629 let mut next = entry.clone();
630 next.trades = false;
631 let new_level = next.effective_level();
632 drop(entry);
633
634 self.update_data_subscription(symbol, old_level, new_level)
635 .await?;
636
637 if let Some(mut entry) = self.symbol_data_types.get_mut(symbol) {
638 entry.trades = false;
639 if entry.is_empty() {
640 drop(entry);
641 self.symbol_data_types.remove(symbol);
642 }
643 }
644
645 Ok(())
646 }
647
648 async fn update_data_subscription(
649 &self,
650 symbol: &str,
651 old_level: Option<AxMarketDataLevel>,
652 new_level: Option<AxMarketDataLevel>,
653 ) -> AxWsResult<()> {
654 if old_level == new_level {
655 return Ok(());
656 }
657
658 match (old_level, new_level) {
659 (None, Some(level)) => {
660 log::debug!("Subscribing {symbol} at {level:?}");
661 self.send_subscribe(symbol, level).await
662 }
663 (Some(_), None) => {
664 log::debug!("Unsubscribing {symbol} (no remaining data types)");
665 self.send_unsubscribe(symbol).await
666 }
667 (Some(old), Some(new)) => {
668 log::debug!("Resubscribing {symbol}: {old:?} -> {new:?}");
669 self.send_unsubscribe(symbol).await?;
670 if let Err(e) = self.send_subscribe(symbol, new).await {
671 log::warn!("Resubscribe failed for {symbol} at {new:?}: {e}");
672 if let Err(restore_err) = self.send_subscribe(symbol, old).await {
673 log::error!(
675 "Failed to restore {symbol} at {old:?}: {restore_err}, \
676 reconnection required"
677 );
678 let old_topic = format!("{symbol}:{old:?}");
679 self.subscriptions.mark_subscribe(&old_topic);
680 }
681 return Err(e);
682 }
683 Ok(())
684 }
685 (None, None) => Ok(()),
686 }
687 }
688
689 async fn send_subscribe(&self, symbol: &str, level: AxMarketDataLevel) -> AxWsResult<()> {
690 let topic = format!("{symbol}:{level:?}");
691 let request_id = self.next_request_id();
692
693 self.subscriptions.mark_subscribe(&topic);
694
695 if let Err(e) = self
696 .send_cmd(HandlerCommand::Subscribe {
697 request_id,
698 symbol: Ustr::from(symbol),
699 level,
700 })
701 .await
702 {
703 self.subscriptions.mark_unsubscribe(&topic);
704 return Err(e);
705 }
706
707 Ok(())
708 }
709
710 async fn send_unsubscribe(&self, symbol: &str) -> AxWsResult<()> {
711 let request_id = self.next_request_id();
712
713 self.send_cmd(HandlerCommand::Unsubscribe {
714 request_id,
715 symbol: Ustr::from(symbol),
716 })
717 .await?;
718
719 for level in [
720 AxMarketDataLevel::Level1,
721 AxMarketDataLevel::Level2,
722 AxMarketDataLevel::Level3,
723 ] {
724 let topic = format!("{symbol}:{level:?}");
725 self.subscriptions.mark_unsubscribe(&topic);
726 }
727
728 Ok(())
729 }
730
731 pub async fn subscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
739 let _guard = self.subscribe_lock.lock().await;
740 let topic = format!("candles:{symbol}:{width:?}");
741
742 if self.is_subscribed_topic(&topic) {
744 log::debug!("Already subscribed to {topic}, skipping");
745 return Ok(());
746 }
747
748 let request_id = self.next_request_id();
749
750 self.subscriptions.mark_subscribe(&topic);
752
753 if let Err(e) = self
754 .send_cmd(HandlerCommand::SubscribeCandles {
755 request_id,
756 symbol: Ustr::from(symbol),
757 width,
758 })
759 .await
760 {
761 self.subscriptions.mark_unsubscribe(&topic);
763 return Err(e);
764 }
765
766 Ok(())
767 }
768
769 pub async fn unsubscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
775 let request_id = self.next_request_id();
776 let topic = format!("candles:{symbol}:{width:?}");
777
778 self.subscriptions.mark_unsubscribe(&topic);
779
780 self.send_cmd(HandlerCommand::UnsubscribeCandles {
781 request_id,
782 symbol: Ustr::from(symbol),
783 width,
784 })
785 .await
786 }
787
788 pub fn stream(&mut self) -> impl futures_util::Stream<Item = NautilusDataWsMessage> + 'static {
794 let rx = self
795 .out_rx
796 .take()
797 .expect("Stream receiver already taken or client not connected - stream() can only be called once");
798 let mut rx = Arc::try_unwrap(rx).expect(
799 "Cannot take ownership of stream - client was cloned and other references exist",
800 );
801 async_stream::stream! {
802 while let Some(msg) = rx.recv().await {
803 yield msg;
804 }
805 }
806 }
807
808 pub async fn disconnect(&self) {
810 log::debug!("Disconnecting WebSocket");
811 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
812 }
813
814 pub async fn close(&mut self) {
816 log::debug!("Closing WebSocket client");
817
818 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
820 tokio::time::sleep(Duration::from_millis(50)).await;
821 self.signal.store(true, Ordering::Release);
822
823 if let Some(handle) = self.task_handle.take() {
824 const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
825 let abort_handle = handle.abort_handle();
826
827 match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
828 Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
829 Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
830 Err(_) => {
831 log::warn!("Handler task did not complete within timeout, aborting");
832 abort_handle.abort();
833 }
834 }
835 }
836 }
837
838 async fn send_cmd(&self, cmd: HandlerCommand) -> AxWsResult<()> {
839 let guard = self.cmd_tx.read().await;
840 guard
841 .send(cmd)
842 .map_err(|e| AxWsClientError::ChannelError(e.to_string()))
843 }
844}