nautilus_architect_ax/websocket/data/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Market data WebSocket client for Ax.
17
18use std::{
19    fmt::Debug,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23    },
24    time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::consts::NAUTILUS_USER_AGENT;
31use nautilus_model::instruments::{Instrument, InstrumentAny};
32use nautilus_network::{
33    backoff::ExponentialBackoff,
34    mode::ConnectionMode,
35    websocket::{
36        PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
37    },
38};
39use ustr::Ustr;
40
41use super::handler::{FeedHandler, HandlerCommand};
42use crate::{
43    common::enums::{AxCandleWidth, AxMarketDataLevel},
44    websocket::messages::NautilusWsMessage,
45};
46
/// Heartbeat interval, in seconds, applied when the caller passes `None` to the constructor.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Delimiter handed to [`SubscriptionState`]; subscription topics built by this client
/// use the same `:` separator (e.g. `{symbol}:{level:?}`).
const AX_TOPIC_DELIMITER: char = ':';
52
53/// Result type for Ax WebSocket operations.
54pub type AxWsResult<T> = Result<T, AxWsClientError>;
55
/// Errors surfaced by the Ax WebSocket client.
#[derive(Debug, Clone)]
pub enum AxWsClientError {
    /// The transport layer failed, e.g. the connection could not be established.
    Transport(String),
    /// A command could not be delivered over an internal channel.
    ChannelError(String),
}

impl std::fmt::Display for AxWsClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each message is prefixed with its category so log lines are self-describing.
        match self {
            Self::ChannelError(msg) => write!(f, "Channel error: {msg}"),
            Self::Transport(msg) => write!(f, "Transport error: {msg}"),
        }
    }
}

// `Debug` + `Display` are sufficient; the default `Error` methods apply (no source).
impl std::error::Error for AxWsClientError {}
75
76/// Market data WebSocket client for Ax.
77///
78/// Provides streaming market data including tickers, trades, order books, and candles.
79/// Requires Bearer token authentication obtained via the HTTP `/api/authenticate` endpoint.
80#[cfg_attr(
81    feature = "python",
82    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.architect")
83)]
84pub struct AxMdWebSocketClient {
85    url: String,
86    heartbeat: Option<u64>,
87    auth_token: Option<String>,
88    connection_mode: Arc<ArcSwap<AtomicU8>>,
89    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
90    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
91    signal: Arc<AtomicBool>,
92    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
93    subscriptions: SubscriptionState,
94    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
95    request_id_counter: Arc<AtomicI64>,
96}
97
98impl Debug for AxMdWebSocketClient {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        f.debug_struct(stringify!(AxMdWebSocketClient))
101            .field("url", &self.url)
102            .field("heartbeat", &self.heartbeat)
103            .field("confirmed_subscriptions", &self.subscriptions.len())
104            .finish()
105    }
106}
107
108impl Clone for AxMdWebSocketClient {
109    fn clone(&self) -> Self {
110        Self {
111            url: self.url.clone(),
112            heartbeat: self.heartbeat,
113            auth_token: self.auth_token.clone(),
114            connection_mode: Arc::clone(&self.connection_mode),
115            cmd_tx: Arc::clone(&self.cmd_tx),
116            out_rx: None, // Each clone gets its own receiver
117            signal: Arc::clone(&self.signal),
118            task_handle: None, // Each clone gets its own task handle
119            subscriptions: self.subscriptions.clone(),
120            instruments_cache: Arc::clone(&self.instruments_cache),
121            request_id_counter: Arc::clone(&self.request_id_counter),
122        }
123    }
124}
125
126impl AxMdWebSocketClient {
127    /// Creates a new Ax market data WebSocket client.
128    ///
129    /// The `auth_token` is a Bearer token obtained from the HTTP `/api/authenticate` endpoint.
130    #[must_use]
131    pub fn new(url: String, auth_token: String, heartbeat: Option<u64>) -> Self {
132        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
133
134        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
135        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
136
137        Self {
138            url,
139            heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
140            auth_token: Some(auth_token),
141            connection_mode,
142            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
143            out_rx: None,
144            signal: Arc::new(AtomicBool::new(false)),
145            task_handle: None,
146            subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
147            instruments_cache: Arc::new(DashMap::new()),
148            request_id_counter: Arc::new(AtomicI64::new(1)),
149        }
150    }
151
152    /// Returns the WebSocket URL.
153    #[must_use]
154    pub fn url(&self) -> &str {
155        &self.url
156    }
157
158    /// Sets the authentication token for subsequent connections.
159    ///
160    /// This should be called before `connect()` if authentication is required.
161    pub fn set_auth_token(&mut self, token: String) {
162        self.auth_token = Some(token);
163    }
164
165    /// Returns whether the client is currently connected and active.
166    #[must_use]
167    pub fn is_active(&self) -> bool {
168        let connection_mode_arc = self.connection_mode.load();
169        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
170            && !self.signal.load(Ordering::Acquire)
171    }
172
173    /// Returns whether the client is closed.
174    #[must_use]
175    pub fn is_closed(&self) -> bool {
176        let connection_mode_arc = self.connection_mode.load();
177        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
178            || self.signal.load(Ordering::Acquire)
179    }
180
181    /// Returns the number of confirmed subscriptions.
182    #[must_use]
183    pub fn subscription_count(&self) -> usize {
184        self.subscriptions.len()
185    }
186
187    /// Generates a unique request ID.
188    fn next_request_id(&self) -> i64 {
189        self.request_id_counter.fetch_add(1, Ordering::Relaxed)
190    }
191
192    /// Caches an instrument for use during message parsing.
193    pub fn cache_instrument(&self, instrument: InstrumentAny) {
194        let symbol = instrument.symbol().inner();
195        self.instruments_cache.insert(symbol, instrument.clone());
196
197        if self.is_active() {
198            let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
199            let cmd_tx = self.cmd_tx.clone();
200            get_runtime().spawn(async move {
201                let guard = cmd_tx.read().await;
202                let _ = guard.send(cmd);
203            });
204        }
205    }
206
207    /// Returns a cached instrument by symbol.
208    #[must_use]
209    pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
210        self.instruments_cache.get(symbol).map(|r| r.clone())
211    }
212
213    /// Establishes the WebSocket connection.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if the connection cannot be established.
218    pub async fn connect(&mut self) -> AxWsResult<()> {
219        const MAX_RETRIES: u32 = 5;
220        const CONNECTION_TIMEOUT_SECS: u64 = 10;
221
222        self.signal.store(false, Ordering::Relaxed);
223
224        let (raw_handler, raw_rx) = channel_message_handler();
225
226        // No-op: ping responses are handled internally by the WebSocketClient
227        let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {});
228
229        let mut headers = vec![("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string())];
230        if let Some(ref token) = self.auth_token {
231            headers.push(("Authorization".to_string(), format!("Bearer {token}")));
232        }
233
234        let config = WebSocketConfig {
235            url: self.url.clone(),
236            headers,
237            heartbeat: self.heartbeat,
238            heartbeat_msg: None, // Ax server sends heartbeats
239            reconnect_timeout_ms: Some(5_000),
240            reconnect_delay_initial_ms: Some(500),
241            reconnect_delay_max_ms: Some(5_000),
242            reconnect_backoff_factor: Some(1.5),
243            reconnect_jitter_ms: Some(250),
244            reconnect_max_attempts: None,
245        };
246
247        // Retry initial connection with exponential backoff
248        let mut backoff = ExponentialBackoff::new(
249            Duration::from_millis(500),
250            Duration::from_millis(5000),
251            2.0,
252            250,
253            false,
254        )
255        .map_err(|e| AxWsClientError::Transport(e.to_string()))?;
256
257        let mut last_error: String;
258        let mut attempt = 0;
259
260        let client = loop {
261            attempt += 1;
262
263            match tokio::time::timeout(
264                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
265                WebSocketClient::connect(
266                    config.clone(),
267                    Some(raw_handler.clone()),
268                    Some(ping_handler.clone()),
269                    None,
270                    vec![],
271                    None,
272                ),
273            )
274            .await
275            {
276                Ok(Ok(client)) => {
277                    if attempt > 1 {
278                        log::info!("WebSocket connection established after {attempt} attempts");
279                    }
280                    break client;
281                }
282                Ok(Err(e)) => {
283                    last_error = e.to_string();
284                    log::warn!(
285                        "WebSocket connection attempt failed: attempt={attempt}/{MAX_RETRIES}, url={}, error={last_error}",
286                        self.url
287                    );
288                }
289                Err(_) => {
290                    last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
291                    log::warn!(
292                        "WebSocket connection attempt timed out: attempt={attempt}/{MAX_RETRIES}, url={}",
293                        self.url
294                    );
295                }
296            }
297
298            if attempt >= MAX_RETRIES {
299                return Err(AxWsClientError::Transport(format!(
300                    "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
301                    self.url,
302                    if last_error.is_empty() {
303                        "unknown error"
304                    } else {
305                        &last_error
306                    }
307                )));
308            }
309
310            let delay = backoff.next_duration();
311            log::debug!(
312                "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
313                attempt + 1
314            );
315            tokio::time::sleep(delay).await;
316        };
317
318        self.connection_mode.store(client.connection_mode_atomic());
319
320        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
321        self.out_rx = Some(Arc::new(out_rx));
322
323        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
324        *self.cmd_tx.write().await = cmd_tx.clone();
325
326        self.send_cmd(HandlerCommand::SetClient(client)).await?;
327
328        if !self.instruments_cache.is_empty() {
329            let cached_instruments: Vec<InstrumentAny> = self
330                .instruments_cache
331                .iter()
332                .map(|entry| entry.value().clone())
333                .collect();
334            self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
335                .await?;
336        }
337
338        let signal = Arc::clone(&self.signal);
339        let subscriptions = self.subscriptions.clone();
340
341        let stream_handle = get_runtime().spawn(async move {
342            let mut handler = FeedHandler::new(
343                signal.clone(),
344                cmd_rx,
345                raw_rx,
346                out_tx.clone(),
347                subscriptions.clone(),
348            );
349
350            while let Some(msg) = handler.next().await {
351                if matches!(msg, NautilusWsMessage::Reconnected) {
352                    log::info!("WebSocket reconnected, resubscribing...");
353                    // TODO: Replay subscriptions on reconnect
354                }
355
356                if out_tx.send(msg).is_err() {
357                    log::debug!("Output channel closed");
358                    break;
359                }
360            }
361
362            log::debug!("Handler loop exited");
363        });
364
365        self.task_handle = Some(Arc::new(stream_handle));
366
367        Ok(())
368    }
369
370    /// Subscribes to market data for a symbol at the specified level.
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if the subscription command cannot be sent.
375    pub async fn subscribe(&self, symbol: &str, level: AxMarketDataLevel) -> AxWsResult<()> {
376        let request_id = self.next_request_id();
377        let topic = format!("{symbol}:{level:?}");
378
379        self.subscriptions.mark_subscribe(&topic);
380
381        self.send_cmd(HandlerCommand::Subscribe {
382            request_id,
383            symbol: symbol.to_string(),
384            level,
385        })
386        .await
387    }
388
389    /// Unsubscribes from market data for a symbol.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the unsubscribe command cannot be sent.
394    pub async fn unsubscribe(&self, symbol: &str) -> AxWsResult<()> {
395        let request_id = self.next_request_id();
396
397        for level in [
398            AxMarketDataLevel::Level1,
399            AxMarketDataLevel::Level2,
400            AxMarketDataLevel::Level3,
401        ] {
402            let topic = format!("{symbol}:{level:?}");
403            self.subscriptions.mark_unsubscribe(&topic);
404        }
405
406        self.send_cmd(HandlerCommand::Unsubscribe {
407            request_id,
408            symbol: symbol.to_string(),
409        })
410        .await
411    }
412
413    /// Subscribes to candle data for a symbol.
414    ///
415    /// # Errors
416    ///
417    /// Returns an error if the subscription command cannot be sent.
418    pub async fn subscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
419        let request_id = self.next_request_id();
420        let topic = format!("candles:{symbol}:{width:?}");
421
422        self.subscriptions.mark_subscribe(&topic);
423
424        self.send_cmd(HandlerCommand::SubscribeCandles {
425            request_id,
426            symbol: symbol.to_string(),
427            width,
428        })
429        .await
430    }
431
432    /// Unsubscribes from candle data for a symbol.
433    ///
434    /// # Errors
435    ///
436    /// Returns an error if the unsubscribe command cannot be sent.
437    pub async fn unsubscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
438        let request_id = self.next_request_id();
439        let topic = format!("candles:{symbol}:{width:?}");
440
441        self.subscriptions.mark_unsubscribe(&topic);
442
443        self.send_cmd(HandlerCommand::UnsubscribeCandles {
444            request_id,
445            symbol: symbol.to_string(),
446            width,
447        })
448        .await
449    }
450
451    /// Returns a stream of messages from the WebSocket.
452    ///
453    /// # Panics
454    ///
455    /// Panics if called more than once or before connecting.
456    pub fn stream(&mut self) -> impl futures_util::Stream<Item = NautilusWsMessage> + use<'_> {
457        let rx = self
458            .out_rx
459            .take()
460            .expect("Stream receiver already taken or client not connected");
461        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
462        async_stream::stream! {
463            while let Some(msg) = rx.recv().await {
464                yield msg;
465            }
466        }
467    }
468
469    /// Disconnects the WebSocket connection gracefully.
470    pub async fn disconnect(&self) {
471        log::debug!("Disconnecting WebSocket");
472        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
473    }
474
475    /// Closes the WebSocket connection and cleans up resources.
476    pub async fn close(&mut self) {
477        log::debug!("Closing WebSocket client");
478        self.signal.store(true, Ordering::Relaxed);
479
480        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
481
482        if let Some(handle) = self.task_handle.take() {
483            const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
484
485            match tokio::time::timeout(CLOSE_TIMEOUT, async {
486                loop {
487                    if Arc::strong_count(&handle) == 1 {
488                        break;
489                    }
490                    tokio::time::sleep(Duration::from_millis(50)).await;
491                }
492            })
493            .await
494            {
495                Ok(()) => log::debug!("Handler task completed gracefully"),
496                Err(_) => {
497                    log::warn!("Handler task did not complete within timeout, aborting");
498                    handle.abort();
499                }
500            }
501        }
502    }
503
504    async fn send_cmd(&self, cmd: HandlerCommand) -> AxWsResult<()> {
505        let guard = self.cmd_tx.read().await;
506        guard
507            .send(cmd)
508            .map_err(|e| AxWsClientError::ChannelError(e.to_string()))
509    }
510}