nautilus_architect_ax/websocket/orders/client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Orders WebSocket client for Ax.
17
18use std::{
19    fmt::Debug,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23    },
24    time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::consts::NAUTILUS_USER_AGENT;
31use nautilus_model::{
32    identifiers::{AccountId, ClientOrderId},
33    instruments::{Instrument, InstrumentAny},
34};
35use nautilus_network::{
36    backoff::ExponentialBackoff,
37    mode::ConnectionMode,
38    websocket::{
39        AuthTracker, PingHandler, WebSocketClient, WebSocketConfig, channel_message_handler,
40    },
41};
42use rust_decimal::Decimal;
43use ustr::Ustr;
44
45use super::handler::{FeedHandler, HandlerCommand};
46use crate::{
47    common::enums::{AxOrderSide, AxTimeInForce},
48    websocket::messages::{AxOrdersWsMessage, AxWsPlaceOrder, OrderMetadata},
49};
50
/// Heartbeat interval (seconds) applied when the caller passes `None` to `new`.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Shorthand for results returned by the Ax orders WebSocket client.
pub type AxOrdersWsResult<T> = Result<T, AxOrdersWsClientError>;

/// Failure modes reported by the Ax orders WebSocket client.
#[derive(Debug, Clone)]
pub enum AxOrdersWsClientError {
    /// The underlying transport/connection failed.
    Transport(String),
    /// A command could not be delivered over the internal command channel.
    ChannelError(String),
    /// Authentication was rejected or could not be performed.
    AuthenticationError(String),
}

impl core::fmt::Display for AxOrdersWsClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant renders as "<Kind> error: <detail>".
        let (kind, detail) = match self {
            Self::Transport(detail) => ("Transport", detail),
            Self::ChannelError(detail) => ("Channel", detail),
            Self::AuthenticationError(detail) => ("Authentication", detail),
        };
        write!(f, "{kind} error: {detail}")
    }
}

impl std::error::Error for AxOrdersWsClientError {}
79
80/// Orders WebSocket client for Ax.
81///
82/// Provides authenticated order management including placing, canceling,
83/// and monitoring order status via WebSocket.
84pub struct AxOrdersWebSocketClient {
85    url: String,
86    heartbeat: Option<u64>,
87    connection_mode: Arc<ArcSwap<AtomicU8>>,
88    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
89    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxOrdersWsMessage>>>,
90    signal: Arc<AtomicBool>,
91    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
92    auth_tracker: AuthTracker,
93    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
94    request_id_counter: Arc<AtomicI64>,
95    account_id: AccountId,
96}
97
98impl Debug for AxOrdersWebSocketClient {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        f.debug_struct(stringify!(AxOrdersWebSocketClient))
101            .field("url", &self.url)
102            .field("heartbeat", &self.heartbeat)
103            .field("account_id", &self.account_id)
104            .finish()
105    }
106}
107
108impl Clone for AxOrdersWebSocketClient {
109    fn clone(&self) -> Self {
110        Self {
111            url: self.url.clone(),
112            heartbeat: self.heartbeat,
113            connection_mode: Arc::clone(&self.connection_mode),
114            cmd_tx: Arc::clone(&self.cmd_tx),
115            out_rx: None, // Each clone gets its own receiver
116            signal: Arc::clone(&self.signal),
117            task_handle: None, // Each clone gets its own task handle
118            auth_tracker: self.auth_tracker.clone(),
119            instruments_cache: Arc::clone(&self.instruments_cache),
120            request_id_counter: Arc::clone(&self.request_id_counter),
121            account_id: self.account_id,
122        }
123    }
124}
125
126impl AxOrdersWebSocketClient {
127    /// Creates a new Ax orders WebSocket client.
128    #[must_use]
129    pub fn new(url: String, account_id: AccountId, heartbeat: Option<u64>) -> Self {
130        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
131
132        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
133        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
134
135        Self {
136            url,
137            heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
138            connection_mode,
139            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
140            out_rx: None,
141            signal: Arc::new(AtomicBool::new(false)),
142            task_handle: None,
143            auth_tracker: AuthTracker::default(),
144            instruments_cache: Arc::new(DashMap::new()),
145            request_id_counter: Arc::new(AtomicI64::new(1)),
146            account_id,
147        }
148    }
149
150    /// Returns the WebSocket URL.
151    #[must_use]
152    pub fn url(&self) -> &str {
153        &self.url
154    }
155
156    /// Returns the account ID.
157    #[must_use]
158    pub fn account_id(&self) -> AccountId {
159        self.account_id
160    }
161
162    /// Returns whether the client is currently connected and active.
163    #[must_use]
164    pub fn is_active(&self) -> bool {
165        let connection_mode_arc = self.connection_mode.load();
166        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
167            && !self.signal.load(Ordering::Acquire)
168    }
169
170    /// Returns whether the client is closed.
171    #[must_use]
172    pub fn is_closed(&self) -> bool {
173        let connection_mode_arc = self.connection_mode.load();
174        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
175            || self.signal.load(Ordering::Acquire)
176    }
177
178    /// Generates a unique request ID.
179    fn next_request_id(&self) -> i64 {
180        self.request_id_counter.fetch_add(1, Ordering::Relaxed)
181    }
182
183    /// Caches an instrument for use during message parsing.
184    pub fn cache_instrument(&self, instrument: InstrumentAny) {
185        let symbol = instrument.symbol().inner();
186        self.instruments_cache.insert(symbol, instrument.clone());
187
188        // If connected, also send to handler
189        if self.is_active() {
190            let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
191            let cmd_tx = self.cmd_tx.clone();
192            get_runtime().spawn(async move {
193                let guard = cmd_tx.read().await;
194                let _ = guard.send(cmd);
195            });
196        }
197    }
198
199    /// Returns a cached instrument by symbol.
200    #[must_use]
201    pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
202        self.instruments_cache.get(symbol).map(|r| r.clone())
203    }
204
205    /// Establishes the WebSocket connection with authentication.
206    ///
207    /// # Arguments
208    ///
209    /// * `bearer_token` - The bearer token for authentication.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if the connection cannot be established.
214    pub async fn connect(&mut self, bearer_token: &str) -> AxOrdersWsResult<()> {
215        const MAX_RETRIES: u32 = 5;
216        const CONNECTION_TIMEOUT_SECS: u64 = 10;
217
218        self.signal.store(false, Ordering::Relaxed);
219
220        let (raw_handler, raw_rx) = channel_message_handler();
221
222        // No-op ping handler: handler owns the WebSocketClient and responds to pings directly
223        let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
224            // Handler responds to pings internally via select! loop
225        });
226
227        let config = WebSocketConfig {
228            url: self.url.clone(),
229            headers: vec![
230                ("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string()),
231                (
232                    "Authorization".to_string(),
233                    format!("Bearer {bearer_token}"),
234                ),
235            ],
236            heartbeat: self.heartbeat,
237            heartbeat_msg: None, // Ax server sends heartbeats
238            reconnect_timeout_ms: Some(5_000),
239            reconnect_delay_initial_ms: Some(500),
240            reconnect_delay_max_ms: Some(5_000),
241            reconnect_backoff_factor: Some(1.5),
242            reconnect_jitter_ms: Some(250),
243            reconnect_max_attempts: None,
244        };
245
246        // Retry initial connection with exponential backoff
247        let mut backoff = ExponentialBackoff::new(
248            Duration::from_millis(500),
249            Duration::from_millis(5000),
250            2.0,
251            250,
252            false,
253        )
254        .map_err(|e| AxOrdersWsClientError::Transport(e.to_string()))?;
255
256        let mut last_error: String;
257        let mut attempt = 0;
258
259        let client = loop {
260            attempt += 1;
261
262            match tokio::time::timeout(
263                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
264                WebSocketClient::connect(
265                    config.clone(),
266                    Some(raw_handler.clone()),
267                    Some(ping_handler.clone()),
268                    None,
269                    vec![],
270                    None,
271                ),
272            )
273            .await
274            {
275                Ok(Ok(client)) => {
276                    if attempt > 1 {
277                        log::info!("WebSocket connection established after {attempt} attempts");
278                    }
279                    break client;
280                }
281                Ok(Err(e)) => {
282                    last_error = e.to_string();
283                    log::warn!(
284                        "WebSocket connection attempt failed: attempt={attempt}, max_retries={MAX_RETRIES}, url={}, error={last_error}",
285                        self.url
286                    );
287                }
288                Err(_) => {
289                    last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
290                    log::warn!(
291                        "WebSocket connection attempt timed out: attempt={attempt}, max_retries={MAX_RETRIES}, url={}",
292                        self.url
293                    );
294                }
295            }
296
297            if attempt >= MAX_RETRIES {
298                return Err(AxOrdersWsClientError::Transport(format!(
299                    "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
300                    self.url,
301                    if last_error.is_empty() {
302                        "unknown error"
303                    } else {
304                        &last_error
305                    }
306                )));
307            }
308
309            let delay = backoff.next_duration();
310            log::debug!(
311                "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
312                attempt + 1
313            );
314            tokio::time::sleep(delay).await;
315        };
316
317        self.connection_mode.store(client.connection_mode_atomic());
318
319        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxOrdersWsMessage>();
320        self.out_rx = Some(Arc::new(out_rx));
321
322        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
323        *self.cmd_tx.write().await = cmd_tx.clone();
324
325        self.send_cmd(HandlerCommand::SetClient(client)).await?;
326
327        if !self.instruments_cache.is_empty() {
328            let cached_instruments: Vec<InstrumentAny> = self
329                .instruments_cache
330                .iter()
331                .map(|entry| entry.value().clone())
332                .collect();
333            self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
334                .await?;
335        }
336
337        // Bearer token is passed in connection headers
338        self.send_cmd(HandlerCommand::Authenticate {
339            token: bearer_token.to_string(),
340        })
341        .await?;
342
343        let signal = Arc::clone(&self.signal);
344        let auth_tracker = self.auth_tracker.clone();
345
346        let stream_handle = get_runtime().spawn(async move {
347            let mut handler = FeedHandler::new(
348                signal.clone(),
349                cmd_rx,
350                raw_rx,
351                out_tx.clone(),
352                auth_tracker.clone(),
353            );
354
355            while let Some(msg) = handler.next().await {
356                if matches!(msg, AxOrdersWsMessage::Reconnected) {
357                    log::info!("WebSocket reconnected");
358                    // TODO: Re-authenticate on reconnect if needed
359                }
360
361                if out_tx.send(msg).is_err() {
362                    log::debug!("Output channel closed");
363                    break;
364                }
365            }
366
367            log::debug!("Handler loop exited");
368        });
369
370        self.task_handle = Some(Arc::new(stream_handle));
371
372        Ok(())
373    }
374
375    /// Places an order via WebSocket.
376    ///
377    /// # Errors
378    ///
379    /// Returns an error if the order command cannot be sent.
380    #[allow(clippy::too_many_arguments)]
381    pub async fn place_order(
382        &self,
383        client_order_id: ClientOrderId,
384        symbol: Ustr,
385        side: AxOrderSide,
386        quantity: i64,
387        price: Decimal,
388        time_in_force: AxTimeInForce,
389        post_only: bool,
390        tag: Option<String>,
391    ) -> AxOrdersWsResult<i64> {
392        let request_id = self.next_request_id();
393
394        let order = AxWsPlaceOrder {
395            rid: request_id,
396            t: "p".to_string(),
397            s: symbol.to_string(),
398            d: side,
399            q: quantity,
400            p: price,
401            tif: time_in_force,
402            po: post_only,
403            tag,
404        };
405
406        let metadata = OrderMetadata {
407            client_order_id,
408            symbol,
409        };
410
411        self.send_cmd(HandlerCommand::PlaceOrder {
412            request_id,
413            order,
414            metadata,
415        })
416        .await?;
417
418        Ok(request_id)
419    }
420
421    /// Cancels an order via WebSocket.
422    ///
423    /// # Errors
424    ///
425    /// Returns an error if the cancel command cannot be sent.
426    pub async fn cancel_order(&self, order_id: &str) -> AxOrdersWsResult<i64> {
427        let request_id = self.next_request_id();
428
429        self.send_cmd(HandlerCommand::CancelOrder {
430            request_id,
431            order_id: order_id.to_string(),
432        })
433        .await?;
434
435        Ok(request_id)
436    }
437
438    /// Requests open orders via WebSocket.
439    ///
440    /// # Errors
441    ///
442    /// Returns an error if the request command cannot be sent.
443    pub async fn get_open_orders(&self) -> AxOrdersWsResult<i64> {
444        let request_id = self.next_request_id();
445
446        self.send_cmd(HandlerCommand::GetOpenOrders { request_id })
447            .await?;
448
449        Ok(request_id)
450    }
451
452    /// Returns a stream of messages from the WebSocket.
453    ///
454    /// # Panics
455    ///
456    /// Panics if called more than once or before connecting.
457    pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxOrdersWsMessage> + use<'_> {
458        let rx = self
459            .out_rx
460            .take()
461            .expect("Stream receiver already taken or client not connected");
462        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
463        async_stream::stream! {
464            while let Some(msg) = rx.recv().await {
465                yield msg;
466            }
467        }
468    }
469
470    /// Disconnects the WebSocket connection gracefully.
471    pub async fn disconnect(&self) {
472        log::debug!("Disconnecting WebSocket");
473        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
474    }
475
476    /// Closes the WebSocket connection and cleans up resources.
477    pub async fn close(&mut self) {
478        log::debug!("Closing WebSocket client");
479        self.signal.store(true, Ordering::Relaxed);
480
481        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
482
483        if let Some(handle) = self.task_handle.take() {
484            const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
485
486            match tokio::time::timeout(CLOSE_TIMEOUT, async {
487                loop {
488                    if Arc::strong_count(&handle) == 1 {
489                        break;
490                    }
491                    tokio::time::sleep(Duration::from_millis(50)).await;
492                }
493            })
494            .await
495            {
496                Ok(()) => log::debug!("Handler task completed gracefully"),
497                Err(_) => {
498                    log::warn!("Handler task did not complete within timeout, aborting");
499                    handle.abort();
500                }
501            }
502        }
503    }
504
505    async fn send_cmd(&self, cmd: HandlerCommand) -> AxOrdersWsResult<()> {
506        let guard = self.cmd_tx.read().await;
507        guard
508            .send(cmd)
509            .map_err(|e| AxOrdersWsClientError::ChannelError(e.to_string()))
510    }
511}