nautilus_architect_ax/websocket/data/
client.rs1use std::{
19 fmt::Debug,
20 sync::{
21 Arc,
22 atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23 },
24 time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::consts::NAUTILUS_USER_AGENT;
31use nautilus_model::instruments::{Instrument, InstrumentAny};
32use nautilus_network::{
33 backoff::ExponentialBackoff,
34 mode::ConnectionMode,
35 websocket::{
36 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
37 },
38};
39use ustr::Ustr;
40
41use super::handler::{FeedHandler, HandlerCommand};
42use crate::{
43 common::enums::{AxCandleWidth, AxMarketDataLevel},
44 websocket::messages::NautilusWsMessage,
45};
46
/// Heartbeat interval, in seconds, applied when the caller does not supply one.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Separator handed to the subscription tracker for splitting topic strings.
const AX_TOPIC_DELIMITER: char = ':';
52
/// Result alias used by every fallible operation of the WebSocket client.
pub type AxWsResult<T> = Result<T, AxWsClientError>;

/// Failures surfaced by the market data WebSocket client.
#[derive(Debug, Clone)]
pub enum AxWsClientError {
    /// The connection could not be set up or established.
    Transport(String),
    /// A command could not be delivered over the internal handler channel.
    ChannelError(String),
}

impl core::fmt::Display for AxWsClientError {
    /// Writes a fixed, variant-specific prefix followed by the carried message.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Transport(msg) => f.write_fmt(format_args!("Transport error: {msg}")),
            Self::ChannelError(msg) => f.write_fmt(format_args!("Channel error: {msg}")),
        }
    }
}

impl std::error::Error for AxWsClientError {}
75
76#[cfg_attr(
81 feature = "python",
82 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.architect")
83)]
84pub struct AxMdWebSocketClient {
85 url: String,
86 heartbeat: Option<u64>,
87 auth_token: Option<String>,
88 connection_mode: Arc<ArcSwap<AtomicU8>>,
89 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
90 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
91 signal: Arc<AtomicBool>,
92 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
93 subscriptions: SubscriptionState,
94 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
95 request_id_counter: Arc<AtomicI64>,
96}
97
98impl Debug for AxMdWebSocketClient {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 f.debug_struct(stringify!(AxMdWebSocketClient))
101 .field("url", &self.url)
102 .field("heartbeat", &self.heartbeat)
103 .field("confirmed_subscriptions", &self.subscriptions.len())
104 .finish()
105 }
106}
107
108impl Clone for AxMdWebSocketClient {
109 fn clone(&self) -> Self {
110 Self {
111 url: self.url.clone(),
112 heartbeat: self.heartbeat,
113 auth_token: self.auth_token.clone(),
114 connection_mode: Arc::clone(&self.connection_mode),
115 cmd_tx: Arc::clone(&self.cmd_tx),
116 out_rx: None, signal: Arc::clone(&self.signal),
118 task_handle: None, subscriptions: self.subscriptions.clone(),
120 instruments_cache: Arc::clone(&self.instruments_cache),
121 request_id_counter: Arc::clone(&self.request_id_counter),
122 }
123 }
124}
125
126impl AxMdWebSocketClient {
127 #[must_use]
131 pub fn new(url: String, auth_token: String, heartbeat: Option<u64>) -> Self {
132 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
133
134 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
135 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
136
137 Self {
138 url,
139 heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
140 auth_token: Some(auth_token),
141 connection_mode,
142 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
143 out_rx: None,
144 signal: Arc::new(AtomicBool::new(false)),
145 task_handle: None,
146 subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
147 instruments_cache: Arc::new(DashMap::new()),
148 request_id_counter: Arc::new(AtomicI64::new(1)),
149 }
150 }
151
152 #[must_use]
154 pub fn url(&self) -> &str {
155 &self.url
156 }
157
158 pub fn set_auth_token(&mut self, token: String) {
162 self.auth_token = Some(token);
163 }
164
165 #[must_use]
167 pub fn is_active(&self) -> bool {
168 let connection_mode_arc = self.connection_mode.load();
169 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
170 && !self.signal.load(Ordering::Acquire)
171 }
172
173 #[must_use]
175 pub fn is_closed(&self) -> bool {
176 let connection_mode_arc = self.connection_mode.load();
177 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
178 || self.signal.load(Ordering::Acquire)
179 }
180
181 #[must_use]
183 pub fn subscription_count(&self) -> usize {
184 self.subscriptions.len()
185 }
186
187 fn next_request_id(&self) -> i64 {
189 self.request_id_counter.fetch_add(1, Ordering::Relaxed)
190 }
191
192 pub fn cache_instrument(&self, instrument: InstrumentAny) {
194 let symbol = instrument.symbol().inner();
195 self.instruments_cache.insert(symbol, instrument.clone());
196
197 if self.is_active() {
198 let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
199 let cmd_tx = self.cmd_tx.clone();
200 get_runtime().spawn(async move {
201 let guard = cmd_tx.read().await;
202 let _ = guard.send(cmd);
203 });
204 }
205 }
206
207 #[must_use]
209 pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
210 self.instruments_cache.get(symbol).map(|r| r.clone())
211 }
212
213 pub async fn connect(&mut self) -> AxWsResult<()> {
219 const MAX_RETRIES: u32 = 5;
220 const CONNECTION_TIMEOUT_SECS: u64 = 10;
221
222 self.signal.store(false, Ordering::Relaxed);
223
224 let (raw_handler, raw_rx) = channel_message_handler();
225
226 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {});
228
229 let mut headers = vec![("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string())];
230 if let Some(ref token) = self.auth_token {
231 headers.push(("Authorization".to_string(), format!("Bearer {token}")));
232 }
233
234 let config = WebSocketConfig {
235 url: self.url.clone(),
236 headers,
237 heartbeat: self.heartbeat,
238 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
240 reconnect_delay_initial_ms: Some(500),
241 reconnect_delay_max_ms: Some(5_000),
242 reconnect_backoff_factor: Some(1.5),
243 reconnect_jitter_ms: Some(250),
244 reconnect_max_attempts: None,
245 };
246
247 let mut backoff = ExponentialBackoff::new(
249 Duration::from_millis(500),
250 Duration::from_millis(5000),
251 2.0,
252 250,
253 false,
254 )
255 .map_err(|e| AxWsClientError::Transport(e.to_string()))?;
256
257 let mut last_error: String;
258 let mut attempt = 0;
259
260 let client = loop {
261 attempt += 1;
262
263 match tokio::time::timeout(
264 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
265 WebSocketClient::connect(
266 config.clone(),
267 Some(raw_handler.clone()),
268 Some(ping_handler.clone()),
269 None,
270 vec![],
271 None,
272 ),
273 )
274 .await
275 {
276 Ok(Ok(client)) => {
277 if attempt > 1 {
278 log::info!("WebSocket connection established after {attempt} attempts");
279 }
280 break client;
281 }
282 Ok(Err(e)) => {
283 last_error = e.to_string();
284 log::warn!(
285 "WebSocket connection attempt failed: attempt={attempt}/{MAX_RETRIES}, url={}, error={last_error}",
286 self.url
287 );
288 }
289 Err(_) => {
290 last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
291 log::warn!(
292 "WebSocket connection attempt timed out: attempt={attempt}/{MAX_RETRIES}, url={}",
293 self.url
294 );
295 }
296 }
297
298 if attempt >= MAX_RETRIES {
299 return Err(AxWsClientError::Transport(format!(
300 "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
301 self.url,
302 if last_error.is_empty() {
303 "unknown error"
304 } else {
305 &last_error
306 }
307 )));
308 }
309
310 let delay = backoff.next_duration();
311 log::debug!(
312 "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
313 attempt + 1
314 );
315 tokio::time::sleep(delay).await;
316 };
317
318 self.connection_mode.store(client.connection_mode_atomic());
319
320 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
321 self.out_rx = Some(Arc::new(out_rx));
322
323 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
324 *self.cmd_tx.write().await = cmd_tx.clone();
325
326 self.send_cmd(HandlerCommand::SetClient(client)).await?;
327
328 if !self.instruments_cache.is_empty() {
329 let cached_instruments: Vec<InstrumentAny> = self
330 .instruments_cache
331 .iter()
332 .map(|entry| entry.value().clone())
333 .collect();
334 self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
335 .await?;
336 }
337
338 let signal = Arc::clone(&self.signal);
339 let subscriptions = self.subscriptions.clone();
340
341 let stream_handle = get_runtime().spawn(async move {
342 let mut handler = FeedHandler::new(
343 signal.clone(),
344 cmd_rx,
345 raw_rx,
346 out_tx.clone(),
347 subscriptions.clone(),
348 );
349
350 while let Some(msg) = handler.next().await {
351 if matches!(msg, NautilusWsMessage::Reconnected) {
352 log::info!("WebSocket reconnected, resubscribing...");
353 }
355
356 if out_tx.send(msg).is_err() {
357 log::debug!("Output channel closed");
358 break;
359 }
360 }
361
362 log::debug!("Handler loop exited");
363 });
364
365 self.task_handle = Some(Arc::new(stream_handle));
366
367 Ok(())
368 }
369
370 pub async fn subscribe(&self, symbol: &str, level: AxMarketDataLevel) -> AxWsResult<()> {
376 let request_id = self.next_request_id();
377 let topic = format!("{symbol}:{level:?}");
378
379 self.subscriptions.mark_subscribe(&topic);
380
381 self.send_cmd(HandlerCommand::Subscribe {
382 request_id,
383 symbol: symbol.to_string(),
384 level,
385 })
386 .await
387 }
388
389 pub async fn unsubscribe(&self, symbol: &str) -> AxWsResult<()> {
395 let request_id = self.next_request_id();
396
397 for level in [
398 AxMarketDataLevel::Level1,
399 AxMarketDataLevel::Level2,
400 AxMarketDataLevel::Level3,
401 ] {
402 let topic = format!("{symbol}:{level:?}");
403 self.subscriptions.mark_unsubscribe(&topic);
404 }
405
406 self.send_cmd(HandlerCommand::Unsubscribe {
407 request_id,
408 symbol: symbol.to_string(),
409 })
410 .await
411 }
412
413 pub async fn subscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
419 let request_id = self.next_request_id();
420 let topic = format!("candles:{symbol}:{width:?}");
421
422 self.subscriptions.mark_subscribe(&topic);
423
424 self.send_cmd(HandlerCommand::SubscribeCandles {
425 request_id,
426 symbol: symbol.to_string(),
427 width,
428 })
429 .await
430 }
431
432 pub async fn unsubscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
438 let request_id = self.next_request_id();
439 let topic = format!("candles:{symbol}:{width:?}");
440
441 self.subscriptions.mark_unsubscribe(&topic);
442
443 self.send_cmd(HandlerCommand::UnsubscribeCandles {
444 request_id,
445 symbol: symbol.to_string(),
446 width,
447 })
448 .await
449 }
450
451 pub fn stream(&mut self) -> impl futures_util::Stream<Item = NautilusWsMessage> + use<'_> {
457 let rx = self
458 .out_rx
459 .take()
460 .expect("Stream receiver already taken or client not connected");
461 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
462 async_stream::stream! {
463 while let Some(msg) = rx.recv().await {
464 yield msg;
465 }
466 }
467 }
468
469 pub async fn disconnect(&self) {
471 log::debug!("Disconnecting WebSocket");
472 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
473 }
474
475 pub async fn close(&mut self) {
477 log::debug!("Closing WebSocket client");
478 self.signal.store(true, Ordering::Relaxed);
479
480 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
481
482 if let Some(handle) = self.task_handle.take() {
483 const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
484
485 match tokio::time::timeout(CLOSE_TIMEOUT, async {
486 loop {
487 if Arc::strong_count(&handle) == 1 {
488 break;
489 }
490 tokio::time::sleep(Duration::from_millis(50)).await;
491 }
492 })
493 .await
494 {
495 Ok(()) => log::debug!("Handler task completed gracefully"),
496 Err(_) => {
497 log::warn!("Handler task did not complete within timeout, aborting");
498 handle.abort();
499 }
500 }
501 }
502 }
503
504 async fn send_cmd(&self, cmd: HandlerCommand) -> AxWsResult<()> {
505 let guard = self.cmd_tx.read().await;
506 guard
507 .send(cmd)
508 .map_err(|e| AxWsClientError::ChannelError(e.to_string()))
509 }
510}