Skip to main content

nautilus_architect_ax/websocket/orders/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Orders WebSocket client for Ax.
17
18use std::{
19    fmt::Debug,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23    },
24    time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::{
31    consts::NAUTILUS_USER_AGENT,
32    nanos::UnixNanos,
33    time::{AtomicTime, get_atomic_clock_realtime},
34};
35use nautilus_model::{
36    enums::{OrderSide, OrderType, TimeInForce},
37    identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
38    instruments::{Instrument, InstrumentAny},
39    types::{Price, Quantity},
40};
41use nautilus_network::{
42    backoff::ExponentialBackoff,
43    mode::ConnectionMode,
44    websocket::{
45        AuthTracker, PingHandler, WebSocketClient, WebSocketConfig, channel_message_handler,
46    },
47};
48use ustr::Ustr;
49
50use super::handler::{FeedHandler, HandlerCommand, WsOrderInfo};
51use crate::{
52    common::{
53        consts::AX_NAUTILUS_TAG,
54        enums::{AxOrderRequestType, AxOrderSide, AxOrderType, AxTimeInForce},
55        parse::{client_order_id_to_cid, quantity_to_contracts},
56    },
57    websocket::messages::{AxOrdersWsMessage, AxWsPlaceOrder, OrderMetadata},
58};
59
/// Default heartbeat interval in seconds.
///
/// Applied by [`AxOrdersWebSocketClient::new`] when the caller passes `None` for `heartbeat`.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Result type for Ax orders WebSocket operations.
///
/// Shorthand for a [`Result`] whose error type is [`AxOrdersWsClientError`].
pub type AxOrdersWsResult<T> = Result<T, AxOrdersWsClientError>;
65
/// Error type for the Ax orders WebSocket client.
///
/// Every variant carries a human-readable message. The [`core::fmt::Display`]
/// implementation prefixes that message with the error category, e.g.
/// `"Transport error: <msg>"`.
///
/// The type is comparable (`PartialEq`/`Eq`) so results can be asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AxOrdersWsClientError {
    /// Transport/connection error.
    Transport(String),
    /// Channel send error.
    ChannelError(String),
    /// Authentication error.
    AuthenticationError(String),
    /// Client-side validation error.
    ClientError(String),
}

impl core::fmt::Display for AxOrdersWsClientError {
    /// Writes `"<Category> error: <message>"`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (category, msg) = match self {
            Self::Transport(msg) => ("Transport", msg),
            Self::ChannelError(msg) => ("Channel", msg),
            Self::AuthenticationError(msg) => ("Authentication", msg),
            Self::ClientError(msg) => ("Client", msg),
        };
        write!(f, "{category} error: {msg}")
    }
}

impl std::error::Error for AxOrdersWsClientError {}

impl From<&'static str> for AxOrdersWsClientError {
    /// Wraps a static message as a [`AxOrdersWsClientError::ClientError`], which lets
    /// conversions that fail with `&'static str` be propagated with `?`.
    fn from(msg: &'static str) -> Self {
        Self::ClientError(msg.to_string())
    }
}
97
/// Orders WebSocket client for Ax.
///
/// Provides authenticated order management including placing, canceling,
/// and monitoring order status via WebSocket.
///
/// Clones share the command channel, stop signal and caches with the original,
/// but never carry the output receiver or the handler task handle.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.architect")
)]
pub struct AxOrdersWebSocketClient {
    /// Clock used to produce `ts_init` timestamps for order metadata.
    clock: &'static AtomicTime,
    /// WebSocket endpoint URL.
    url: String,
    /// Heartbeat interval in seconds, forwarded to the WebSocket configuration on connect.
    heartbeat: Option<u64>,
    /// Connection mode atomic; replaced with the underlying client's atomic on connect.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender for commands to the feed handler; swapped for a fresh channel on each connect.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver for messages emitted by the handler task; set by `connect`, taken by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxOrdersWsMessage>>>,
    /// Stop flag: cleared at the start of `connect`, set by `close`.
    signal: Arc<AtomicBool>,
    /// Handle of the spawned handler task, awaited (or aborted) by `close`.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Authentication tracker handed to the feed handler.
    auth_tracker: AuthTracker,
    /// Cached instruments keyed by symbol.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    /// Order metadata keyed by client order ID; shared with the feed handler.
    orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
    /// Venue order ID to client order ID mapping; shared with the feed handler.
    venue_to_client_id: Arc<DashMap<VenueOrderId, ClientOrderId>>,
    /// Numeric `cid` to client order ID mapping; shared with the feed handler.
    cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
    /// Source of request IDs; starts at 1 and is shared between clones.
    request_id_counter: Arc<AtomicI64>,
    /// Account this client operates on.
    account_id: AccountId,
    /// Trader ID recorded on externally registered orders.
    trader_id: TraderId,
}
124
125impl Debug for AxOrdersWebSocketClient {
126    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
127        f.debug_struct(stringify!(AxOrdersWebSocketClient))
128            .field("url", &self.url)
129            .field("heartbeat", &self.heartbeat)
130            .field("account_id", &self.account_id)
131            .finish()
132    }
133}
134
135impl Clone for AxOrdersWebSocketClient {
136    fn clone(&self) -> Self {
137        Self {
138            clock: self.clock,
139            url: self.url.clone(),
140            heartbeat: self.heartbeat,
141            connection_mode: Arc::clone(&self.connection_mode),
142            cmd_tx: Arc::clone(&self.cmd_tx),
143            out_rx: None, // Each clone gets its own receiver
144            signal: Arc::clone(&self.signal),
145            task_handle: None,
146            auth_tracker: self.auth_tracker.clone(),
147            instruments_cache: Arc::clone(&self.instruments_cache),
148            orders_metadata: Arc::clone(&self.orders_metadata),
149            venue_to_client_id: Arc::clone(&self.venue_to_client_id),
150            cid_to_client_order_id: Arc::clone(&self.cid_to_client_order_id),
151            request_id_counter: Arc::clone(&self.request_id_counter),
152            account_id: self.account_id,
153            trader_id: self.trader_id,
154        }
155    }
156}
157
158impl AxOrdersWebSocketClient {
159    /// Creates a new Ax orders WebSocket client.
160    #[must_use]
161    pub fn new(
162        url: String,
163        account_id: AccountId,
164        trader_id: TraderId,
165        heartbeat: Option<u64>,
166    ) -> Self {
167        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
168
169        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
170        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
171
172        Self {
173            clock: get_atomic_clock_realtime(),
174            url,
175            heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
176            connection_mode,
177            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
178            out_rx: None,
179            signal: Arc::new(AtomicBool::new(false)),
180            task_handle: None,
181            auth_tracker: AuthTracker::default(),
182            instruments_cache: Arc::new(DashMap::new()),
183            orders_metadata: Arc::new(DashMap::new()),
184            venue_to_client_id: Arc::new(DashMap::new()),
185            cid_to_client_order_id: Arc::new(DashMap::new()),
186            request_id_counter: Arc::new(AtomicI64::new(1)),
187            account_id,
188            trader_id,
189        }
190    }
191
    /// Returns the current time from the client's clock, used as `ts_init` for order metadata.
    fn generate_ts_init(&self) -> UnixNanos {
        self.clock.get_time_ns()
    }
195
196    /// Returns the WebSocket URL.
197    #[must_use]
198    pub fn url(&self) -> &str {
199        &self.url
200    }
201
    /// Returns the account ID this client was created with.
    #[must_use]
    pub fn account_id(&self) -> AccountId {
        self.account_id
    }
207
208    /// Returns whether the client is currently connected and active.
209    #[must_use]
210    pub fn is_active(&self) -> bool {
211        let connection_mode_arc = self.connection_mode.load();
212        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
213            && !self.signal.load(Ordering::Acquire)
214    }
215
216    /// Returns whether the client is closed.
217    #[must_use]
218    pub fn is_closed(&self) -> bool {
219        let connection_mode_arc = self.connection_mode.load();
220        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
221            || self.signal.load(Ordering::Acquire)
222    }
223
    /// Generates a unique request ID.
    ///
    /// Returns the counter's current value and then increments it, so the first
    /// ID handed out is the counter's initial value. The counter is shared between
    /// clones of this client.
    fn next_request_id(&self) -> i64 {
        self.request_id_counter.fetch_add(1, Ordering::Relaxed)
    }
228
229    /// Caches an instrument for use during message parsing.
230    pub fn cache_instrument(&self, instrument: InstrumentAny) {
231        let symbol = instrument.symbol().inner();
232        self.instruments_cache.insert(symbol, instrument.clone());
233
234        // If connected, also send to handler
235        if self.is_active() {
236            let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
237            let cmd_tx = self.cmd_tx.clone();
238            get_runtime().spawn(async move {
239                let guard = cmd_tx.read().await;
240                let _ = guard.send(cmd);
241            });
242        }
243    }
244
245    /// Returns a cached instrument by symbol.
246    #[must_use]
247    pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
248        self.instruments_cache.get(symbol).map(|r| r.clone())
249    }
250
    /// Returns the orders metadata cache.
    ///
    /// The map is keyed by client order ID and is the same shared instance that
    /// is handed to the feed handler on connect.
    #[must_use]
    pub fn orders_metadata(&self) -> &Arc<DashMap<ClientOrderId, OrderMetadata>> {
        &self.orders_metadata
    }
256
    /// Returns the cid to client order ID mapping for order correlation.
    ///
    /// Entries are added by `submit_order` and removed again if sending the order fails.
    #[must_use]
    pub fn cid_to_client_order_id(&self) -> &Arc<DashMap<u64, ClientOrderId>> {
        &self.cid_to_client_order_id
    }
262
263    /// Resolves a cid to a ClientOrderId if the mapping exists.
264    #[must_use]
265    pub fn resolve_cid(&self, cid: u64) -> Option<ClientOrderId> {
266        self.cid_to_client_order_id.get(&cid).map(|v| *v)
267    }
268
269    /// Registers an external order with the WebSocket handler for event tracking.
270    ///
271    /// This allows the handler to create proper events (e.g., OrderCanceled, OrderFilled)
272    /// for orders that were reconciled externally and not submitted through this client.
273    ///
274    /// Returns `false` if the instrument is not cached (registration skipped).
275    pub fn register_external_order(
276        &self,
277        client_order_id: ClientOrderId,
278        venue_order_id: VenueOrderId,
279        instrument_id: InstrumentId,
280        strategy_id: StrategyId,
281    ) -> bool {
282        if self.orders_metadata.contains_key(&client_order_id) {
283            return true;
284        }
285
286        // Required for correct precision on fills
287        let symbol = instrument_id.symbol.inner();
288        let Some(instrument) = self.get_cached_instrument(&symbol) else {
289            log::warn!(
290                "Cannot register external order {client_order_id}: \
291                 instrument {instrument_id} not in cache"
292            );
293            return false;
294        };
295
296        let metadata = OrderMetadata {
297            trader_id: self.trader_id,
298            strategy_id,
299            instrument_id,
300            client_order_id,
301            venue_order_id: Some(venue_order_id),
302            ts_init: self.generate_ts_init(),
303            size_precision: instrument.size_precision(),
304            price_precision: instrument.price_precision(),
305            quote_currency: instrument.quote_currency(),
306        };
307
308        self.orders_metadata.insert(client_order_id, metadata);
309        self.venue_to_client_id
310            .insert(venue_order_id, client_order_id);
311
312        log::debug!(
313            "Registered external order {client_order_id} ({venue_order_id}) for {instrument_id} [{strategy_id}]"
314        );
315
316        true
317    }
318
319    /// Establishes the WebSocket connection with authentication.
320    ///
321    /// # Arguments
322    ///
323    /// * `bearer_token` - The bearer token for authentication.
324    ///
325    /// # Errors
326    ///
327    /// Returns an error if the connection cannot be established.
328    pub async fn connect(&mut self, bearer_token: &str) -> AxOrdersWsResult<()> {
329        const MAX_RETRIES: u32 = 5;
330        const CONNECTION_TIMEOUT_SECS: u64 = 10;
331
332        self.signal.store(false, Ordering::Release);
333
334        let (raw_handler, raw_rx) = channel_message_handler();
335
336        // No-op ping handler: handler owns the WebSocketClient and responds to pings directly
337        let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
338            // Handler responds to pings internally via select! loop
339        });
340
341        let config = WebSocketConfig {
342            url: self.url.clone(),
343            headers: vec![
344                ("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string()),
345                (
346                    "Authorization".to_string(),
347                    format!("Bearer {bearer_token}"),
348                ),
349            ],
350            heartbeat: self.heartbeat,
351            heartbeat_msg: None, // Ax server sends heartbeats
352            reconnect_timeout_ms: Some(5_000),
353            reconnect_delay_initial_ms: Some(500),
354            reconnect_delay_max_ms: Some(5_000),
355            reconnect_backoff_factor: Some(1.5),
356            reconnect_jitter_ms: Some(250),
357            reconnect_max_attempts: None,
358        };
359
360        // Retry initial connection with exponential backoff
361        let mut backoff = ExponentialBackoff::new(
362            Duration::from_millis(500),
363            Duration::from_millis(5000),
364            2.0,
365            250,
366            false,
367        )
368        .map_err(|e| AxOrdersWsClientError::Transport(e.to_string()))?;
369
370        let mut last_error: String;
371        let mut attempt = 0;
372
373        let client = loop {
374            attempt += 1;
375
376            match tokio::time::timeout(
377                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
378                WebSocketClient::connect(
379                    config.clone(),
380                    Some(raw_handler.clone()),
381                    Some(ping_handler.clone()),
382                    None,
383                    vec![],
384                    None,
385                ),
386            )
387            .await
388            {
389                Ok(Ok(client)) => {
390                    if attempt > 1 {
391                        log::info!("WebSocket connection established after {attempt} attempts");
392                    }
393                    break client;
394                }
395                Ok(Err(e)) => {
396                    last_error = e.to_string();
397                    log::warn!(
398                        "WebSocket connection attempt failed: attempt={attempt}, max_retries={MAX_RETRIES}, url={}, error={last_error}",
399                        self.url
400                    );
401                }
402                Err(_) => {
403                    last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
404                    log::warn!(
405                        "WebSocket connection attempt timed out: attempt={attempt}, max_retries={MAX_RETRIES}, url={}",
406                        self.url
407                    );
408                }
409            }
410
411            if attempt >= MAX_RETRIES {
412                return Err(AxOrdersWsClientError::Transport(format!(
413                    "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
414                    self.url,
415                    if last_error.is_empty() {
416                        "unknown error"
417                    } else {
418                        &last_error
419                    }
420                )));
421            }
422
423            let delay = backoff.next_duration();
424            log::debug!(
425                "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
426                attempt + 1
427            );
428            tokio::time::sleep(delay).await;
429        };
430
431        self.connection_mode.store(client.connection_mode_atomic());
432
433        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxOrdersWsMessage>();
434        self.out_rx = Some(Arc::new(out_rx));
435
436        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
437        *self.cmd_tx.write().await = cmd_tx.clone();
438
439        self.send_cmd(HandlerCommand::SetClient(client)).await?;
440
441        if !self.instruments_cache.is_empty() {
442            let cached_instruments: Vec<InstrumentAny> = self
443                .instruments_cache
444                .iter()
445                .map(|entry| entry.value().clone())
446                .collect();
447            self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
448                .await?;
449        }
450
451        // Bearer token is passed in connection headers
452        self.send_cmd(HandlerCommand::Authenticate {
453            token: bearer_token.to_string(),
454        })
455        .await?;
456
457        let signal = Arc::clone(&self.signal);
458        let auth_tracker = self.auth_tracker.clone();
459        let account_id = self.account_id;
460        let orders_metadata = Arc::clone(&self.orders_metadata);
461        let venue_to_client_id = Arc::clone(&self.venue_to_client_id);
462        let cid_to_client_order_id = Arc::clone(&self.cid_to_client_order_id);
463
464        let stream_handle = get_runtime().spawn(async move {
465            let mut handler = FeedHandler::new(
466                signal.clone(),
467                cmd_rx,
468                raw_rx,
469                auth_tracker.clone(),
470                account_id,
471                orders_metadata,
472                venue_to_client_id,
473                cid_to_client_order_id,
474            );
475
476            while let Some(msg) = handler.next().await {
477                if matches!(msg, AxOrdersWsMessage::Reconnected) {
478                    log::info!("WebSocket reconnected, authentication will be restored");
479                }
480
481                if out_tx.send(msg).is_err() {
482                    log::debug!("Output channel closed");
483                    break;
484                }
485            }
486
487            log::debug!("Handler loop exited");
488        });
489
490        self.task_handle = Some(stream_handle);
491
492        Ok(())
493    }
494
495    /// Submits an order using Nautilus domain types.
496    ///
497    /// This method handles conversion from Nautilus domain types to AX-specific
498    /// types and stores order metadata for event correlation.
499    ///
500    /// # Errors
501    ///
502    /// Returns an error if:
503    /// - The order type is not supported (only MARKET (simulated), LIMIT and STOP_LIMIT).
504    /// - The time-in-force is not supported.
505    /// - The instrument is not found in the cache.
506    /// - A limit order is missing a price.
507    /// - A stop-loss order is missing a trigger price.
508    /// - The order command cannot be sent.
509    #[allow(clippy::too_many_arguments)]
510    pub async fn submit_order(
511        &self,
512        trader_id: TraderId,
513        strategy_id: StrategyId,
514        instrument_id: InstrumentId,
515        client_order_id: ClientOrderId,
516        order_side: OrderSide,
517        order_type: OrderType,
518        quantity: Quantity,
519        time_in_force: TimeInForce,
520        price: Option<Price>,
521        trigger_price: Option<Price>,
522        post_only: bool,
523    ) -> AxOrdersWsResult<i64> {
524        if !matches!(
525            order_type,
526            OrderType::Market | OrderType::Limit | OrderType::StopLimit
527        ) {
528            return Err(AxOrdersWsClientError::ClientError(format!(
529                "Unsupported order type: {order_type:?}. AX supports MARKET, LIMIT and STOP_LIMIT."
530            )));
531        }
532
533        // Get instrument from cache for precision
534        let symbol = instrument_id.symbol.inner();
535        let instrument = self.get_cached_instrument(&symbol).ok_or_else(|| {
536            AxOrdersWsClientError::ClientError(format!(
537                "Instrument {instrument_id} not found in cache"
538            ))
539        })?;
540
541        let ax_side = AxOrderSide::try_from(order_side)?;
542
543        let qty_contracts = quantity_to_contracts(quantity)
544            .map_err(|e| AxOrdersWsClientError::ClientError(e.to_string()))?;
545
546        // Market orders are simulated as IOC limit orders with aggressive pricing
547        // because Architect does not support native market orders
548        let request_id = self.next_request_id();
549
550        let (ax_price, ax_tif, ax_post_only, ax_order_type, ax_trigger_price) = match order_type {
551            OrderType::Market => {
552                let market_price = price.ok_or_else(|| {
553                    AxOrdersWsClientError::ClientError(
554                        "Market order requires price (calculated from quote)".to_string(),
555                    )
556                })?;
557                (
558                    market_price.as_decimal(),
559                    AxTimeInForce::Ioc,
560                    false,
561                    None,
562                    None,
563                )
564            }
565            OrderType::Limit => {
566                let ax_tif = AxTimeInForce::try_from(time_in_force)?;
567                let limit_price = price.ok_or_else(|| {
568                    AxOrdersWsClientError::ClientError("Limit order requires price".to_string())
569                })?;
570                (limit_price.as_decimal(), ax_tif, post_only, None, None)
571            }
572            OrderType::StopLimit => {
573                let ax_tif = AxTimeInForce::try_from(time_in_force)?;
574                let limit_price = price.ok_or_else(|| {
575                    AxOrdersWsClientError::ClientError(
576                        "Stop-limit order requires price".to_string(),
577                    )
578                })?;
579                let stop_price = trigger_price.ok_or_else(|| {
580                    AxOrdersWsClientError::ClientError(
581                        "Stop-limit order requires trigger price".to_string(),
582                    )
583                })?;
584                (
585                    limit_price.as_decimal(),
586                    ax_tif,
587                    false,
588                    Some(AxOrderType::StopLossLimit),
589                    Some(stop_price.as_decimal()),
590                )
591            }
592            _ => {
593                return Err(AxOrdersWsClientError::ClientError(format!(
594                    "Unsupported order type: {order_type:?}"
595                )));
596            }
597        };
598
599        // Store order metadata for event correlation (after validation to avoid stale entries)
600        let metadata = OrderMetadata {
601            trader_id,
602            strategy_id,
603            instrument_id,
604            client_order_id,
605            venue_order_id: None,
606            ts_init: self.generate_ts_init(),
607            size_precision: instrument.size_precision(),
608            price_precision: instrument.price_precision(),
609            quote_currency: instrument.quote_currency(),
610        };
611        self.orders_metadata.insert(client_order_id, metadata);
612
613        // Store cid -> client_order_id mapping for correlation
614        let cid = client_order_id_to_cid(&client_order_id);
615        self.cid_to_client_order_id.insert(cid, client_order_id);
616
617        let order = AxWsPlaceOrder {
618            rid: request_id,
619            t: AxOrderRequestType::PlaceOrder,
620            s: symbol,
621            d: ax_side,
622            q: qty_contracts,
623            p: ax_price,
624            tif: ax_tif,
625            po: ax_post_only,
626            tag: Some(AX_NAUTILUS_TAG.to_string()),
627            cid: Some(cid),
628            order_type: ax_order_type,
629            trigger_price: ax_trigger_price,
630        };
631
632        let order_info = WsOrderInfo {
633            client_order_id,
634            symbol,
635        };
636
637        let result = self
638            .send_cmd(HandlerCommand::PlaceOrder {
639                request_id,
640                order,
641                order_info,
642            })
643            .await;
644
645        if result.is_err() {
646            self.orders_metadata.remove(&client_order_id);
647            self.cid_to_client_order_id.remove(&cid);
648        }
649
650        result?;
651        Ok(request_id)
652    }
653
654    /// Cancels an order via WebSocket.
655    ///
656    /// Requires a known `venue_order_id`.
657    ///
658    /// # Errors
659    ///
660    /// Returns an error if the cancel command cannot be sent.
661    pub async fn cancel_order(
662        &self,
663        client_order_id: ClientOrderId,
664        venue_order_id: Option<VenueOrderId>,
665    ) -> AxOrdersWsResult<i64> {
666        let order_id = venue_order_id.map(|v| v.to_string()).ok_or_else(|| {
667            AxOrdersWsClientError::ClientError(format!(
668                "Cannot cancel order {client_order_id}: missing venue_order_id"
669            ))
670        })?;
671
672        let request_id = self.next_request_id();
673
674        self.send_cmd(HandlerCommand::CancelOrder {
675            request_id,
676            order_id,
677        })
678        .await?;
679
680        Ok(request_id)
681    }
682
683    /// Requests open orders via WebSocket.
684    ///
685    /// # Errors
686    ///
687    /// Returns an error if the request command cannot be sent.
688    pub async fn get_open_orders(&self) -> AxOrdersWsResult<i64> {
689        let request_id = self.next_request_id();
690
691        self.send_cmd(HandlerCommand::GetOpenOrders { request_id })
692            .await?;
693
694        Ok(request_id)
695    }
696
697    /// Returns a stream of WebSocket messages.
698    ///
699    /// # Panics
700    ///
701    /// Panics if called before `connect()` or if the stream has already been taken.
702    pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxOrdersWsMessage> + 'static {
703        let rx = self
704            .out_rx
705            .take()
706            .expect("Stream receiver already taken or client not connected - stream() can only be called once");
707        let mut rx = Arc::try_unwrap(rx).expect(
708            "Cannot take ownership of stream - client was cloned and other references exist",
709        );
710        async_stream::stream! {
711            while let Some(msg) = rx.recv().await {
712                yield msg;
713            }
714        }
715    }
716
717    /// Disconnects the WebSocket connection gracefully.
718    pub async fn disconnect(&self) {
719        log::debug!("Disconnecting WebSocket");
720        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
721    }
722
723    /// Closes the WebSocket connection and cleans up resources.
724    pub async fn close(&mut self) {
725        log::debug!("Closing WebSocket client");
726
727        // Send disconnect first to allow graceful cleanup before signal
728        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
729        tokio::time::sleep(Duration::from_millis(50)).await;
730        self.signal.store(true, Ordering::Release);
731
732        if let Some(handle) = self.task_handle.take() {
733            const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
734            let abort_handle = handle.abort_handle();
735
736            match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
737                Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
738                Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
739                Err(_) => {
740                    log::warn!("Handler task did not complete within timeout, aborting");
741                    abort_handle.abort();
742                }
743            }
744        }
745    }
746
747    async fn send_cmd(&self, cmd: HandlerCommand) -> AxOrdersWsResult<()> {
748        let guard = self.cmd_tx.read().await;
749        guard
750            .send(cmd)
751            .map_err(|e| AxOrdersWsClientError::ChannelError(e.to_string()))
752    }
753}
754
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Builds a disconnected client with fixed test identifiers.
    fn make_client() -> AxOrdersWebSocketClient {
        AxOrdersWebSocketClient::new(
            "wss://example.com/orders/ws".to_string(),
            AccountId::from("AX-001"),
            TraderId::from("TRADER-001"),
            Some(30),
        )
    }

    #[tokio::test]
    async fn test_cancel_order_rejects_without_venue_order_id() {
        let client = make_client();

        let result = client
            .cancel_order(ClientOrderId::from("CID-123"), None)
            .await;

        assert!(matches!(
            result,
            Err(AxOrdersWsClientError::ClientError(msg))
            if msg.contains("missing venue_order_id")
        ));
    }

    #[tokio::test]
    async fn test_cancel_order_sends_known_venue_order_id() {
        let mut client = make_client();

        // Swap in a channel whose receiver the test owns, so the sent command can be inspected
        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
        client.cmd_tx = Arc::new(tokio::sync::RwLock::new(cmd_tx));

        let returned_id = client
            .cancel_order(
                ClientOrderId::from("CID-456"),
                Some(VenueOrderId::from("V-ORDER-789")),
            )
            .await
            .unwrap();
        assert_eq!(returned_id, 1);

        match cmd_rx.recv().await.unwrap() {
            HandlerCommand::CancelOrder {
                request_id,
                order_id,
            } => {
                assert_eq!(request_id, 1);
                assert_eq!(order_id, "V-ORDER-789");
            }
            other => panic!("unexpected command: {other:?}"),
        }
    }
}