1use std::{
19 fmt::Debug,
20 sync::{
21 Arc,
22 atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23 },
24 time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::{
31 consts::NAUTILUS_USER_AGENT,
32 nanos::UnixNanos,
33 time::{AtomicTime, get_atomic_clock_realtime},
34};
35use nautilus_model::{
36 enums::{OrderSide, OrderType, TimeInForce},
37 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
38 instruments::{Instrument, InstrumentAny},
39 types::{Price, Quantity},
40};
41use nautilus_network::{
42 backoff::ExponentialBackoff,
43 mode::ConnectionMode,
44 websocket::{
45 AuthTracker, PingHandler, WebSocketClient, WebSocketConfig, channel_message_handler,
46 },
47};
48use ustr::Ustr;
49
50use super::handler::{FeedHandler, HandlerCommand, WsOrderInfo};
51use crate::{
52 common::{
53 consts::AX_NAUTILUS_TAG,
54 enums::{AxOrderRequestType, AxOrderSide, AxOrderType, AxTimeInForce},
55 parse::{client_order_id_to_cid, quantity_to_contracts},
56 },
57 websocket::messages::{AxOrdersWsMessage, AxWsPlaceOrder, OrderMetadata},
58};
59
/// Heartbeat interval (seconds) used when the caller does not supply one.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Result type returned by AX orders WebSocket client operations.
pub type AxOrdersWsResult<T> = Result<T, AxOrdersWsClientError>;

/// Errors produced by the AX orders WebSocket client.
///
/// Each variant carries a human-readable message. `PartialEq`/`Eq` are derived so callers
/// (and tests) can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AxOrdersWsClientError {
    /// Failure establishing or using the underlying WebSocket transport.
    Transport(String),
    /// Failure queuing a command on the channel to the background handler.
    ChannelError(String),
    /// Authentication failure.
    AuthenticationError(String),
    /// Client-side usage or validation error (e.g. unsupported order type, missing instrument).
    ClientError(String),
}

impl core::fmt::Display for AxOrdersWsClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Transport(msg) => write!(f, "Transport error: {msg}"),
            Self::ChannelError(msg) => write!(f, "Channel error: {msg}"),
            Self::AuthenticationError(msg) => write!(f, "Authentication error: {msg}"),
            Self::ClientError(msg) => write!(f, "Client error: {msg}"),
        }
    }
}

impl std::error::Error for AxOrdersWsClientError {}

/// Converts a static message into [`AxOrdersWsClientError::ClientError`], allowing `?` on
/// fallible conversions whose error type is `&'static str`.
impl From<&'static str> for AxOrdersWsClientError {
    fn from(msg: &'static str) -> Self {
        Self::ClientError(msg.to_string())
    }
}
97
98#[cfg_attr(
103 feature = "python",
104 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.architect")
105)]
106pub struct AxOrdersWebSocketClient {
107 clock: &'static AtomicTime,
108 url: String,
109 heartbeat: Option<u64>,
110 connection_mode: Arc<ArcSwap<AtomicU8>>,
111 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
112 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxOrdersWsMessage>>>,
113 signal: Arc<AtomicBool>,
114 task_handle: Option<tokio::task::JoinHandle<()>>,
115 auth_tracker: AuthTracker,
116 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
117 orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
118 venue_to_client_id: Arc<DashMap<VenueOrderId, ClientOrderId>>,
119 cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
120 request_id_counter: Arc<AtomicI64>,
121 account_id: AccountId,
122 trader_id: TraderId,
123}
124
125impl Debug for AxOrdersWebSocketClient {
126 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
127 f.debug_struct(stringify!(AxOrdersWebSocketClient))
128 .field("url", &self.url)
129 .field("heartbeat", &self.heartbeat)
130 .field("account_id", &self.account_id)
131 .finish()
132 }
133}
134
135impl Clone for AxOrdersWebSocketClient {
136 fn clone(&self) -> Self {
137 Self {
138 clock: self.clock,
139 url: self.url.clone(),
140 heartbeat: self.heartbeat,
141 connection_mode: Arc::clone(&self.connection_mode),
142 cmd_tx: Arc::clone(&self.cmd_tx),
143 out_rx: None, signal: Arc::clone(&self.signal),
145 task_handle: None,
146 auth_tracker: self.auth_tracker.clone(),
147 instruments_cache: Arc::clone(&self.instruments_cache),
148 orders_metadata: Arc::clone(&self.orders_metadata),
149 venue_to_client_id: Arc::clone(&self.venue_to_client_id),
150 cid_to_client_order_id: Arc::clone(&self.cid_to_client_order_id),
151 request_id_counter: Arc::clone(&self.request_id_counter),
152 account_id: self.account_id,
153 trader_id: self.trader_id,
154 }
155 }
156}
157
158impl AxOrdersWebSocketClient {
159 #[must_use]
161 pub fn new(
162 url: String,
163 account_id: AccountId,
164 trader_id: TraderId,
165 heartbeat: Option<u64>,
166 ) -> Self {
167 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
168
169 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
170 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
171
172 Self {
173 clock: get_atomic_clock_realtime(),
174 url,
175 heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
176 connection_mode,
177 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
178 out_rx: None,
179 signal: Arc::new(AtomicBool::new(false)),
180 task_handle: None,
181 auth_tracker: AuthTracker::default(),
182 instruments_cache: Arc::new(DashMap::new()),
183 orders_metadata: Arc::new(DashMap::new()),
184 venue_to_client_id: Arc::new(DashMap::new()),
185 cid_to_client_order_id: Arc::new(DashMap::new()),
186 request_id_counter: Arc::new(AtomicI64::new(1)),
187 account_id,
188 trader_id,
189 }
190 }
191
192 fn generate_ts_init(&self) -> UnixNanos {
193 self.clock.get_time_ns()
194 }
195
196 #[must_use]
198 pub fn url(&self) -> &str {
199 &self.url
200 }
201
202 #[must_use]
204 pub fn account_id(&self) -> AccountId {
205 self.account_id
206 }
207
208 #[must_use]
210 pub fn is_active(&self) -> bool {
211 let connection_mode_arc = self.connection_mode.load();
212 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
213 && !self.signal.load(Ordering::Acquire)
214 }
215
216 #[must_use]
218 pub fn is_closed(&self) -> bool {
219 let connection_mode_arc = self.connection_mode.load();
220 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
221 || self.signal.load(Ordering::Acquire)
222 }
223
224 fn next_request_id(&self) -> i64 {
226 self.request_id_counter.fetch_add(1, Ordering::Relaxed)
227 }
228
229 pub fn cache_instrument(&self, instrument: InstrumentAny) {
231 let symbol = instrument.symbol().inner();
232 self.instruments_cache.insert(symbol, instrument.clone());
233
234 if self.is_active() {
236 let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
237 let cmd_tx = self.cmd_tx.clone();
238 get_runtime().spawn(async move {
239 let guard = cmd_tx.read().await;
240 let _ = guard.send(cmd);
241 });
242 }
243 }
244
245 #[must_use]
247 pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
248 self.instruments_cache.get(symbol).map(|r| r.clone())
249 }
250
251 #[must_use]
253 pub fn orders_metadata(&self) -> &Arc<DashMap<ClientOrderId, OrderMetadata>> {
254 &self.orders_metadata
255 }
256
257 #[must_use]
259 pub fn cid_to_client_order_id(&self) -> &Arc<DashMap<u64, ClientOrderId>> {
260 &self.cid_to_client_order_id
261 }
262
263 #[must_use]
265 pub fn resolve_cid(&self, cid: u64) -> Option<ClientOrderId> {
266 self.cid_to_client_order_id.get(&cid).map(|v| *v)
267 }
268
269 pub fn register_external_order(
276 &self,
277 client_order_id: ClientOrderId,
278 venue_order_id: VenueOrderId,
279 instrument_id: InstrumentId,
280 strategy_id: StrategyId,
281 ) -> bool {
282 if self.orders_metadata.contains_key(&client_order_id) {
283 return true;
284 }
285
286 let symbol = instrument_id.symbol.inner();
288 let Some(instrument) = self.get_cached_instrument(&symbol) else {
289 log::warn!(
290 "Cannot register external order {client_order_id}: \
291 instrument {instrument_id} not in cache"
292 );
293 return false;
294 };
295
296 let metadata = OrderMetadata {
297 trader_id: self.trader_id,
298 strategy_id,
299 instrument_id,
300 client_order_id,
301 venue_order_id: Some(venue_order_id),
302 ts_init: self.generate_ts_init(),
303 size_precision: instrument.size_precision(),
304 price_precision: instrument.price_precision(),
305 quote_currency: instrument.quote_currency(),
306 };
307
308 self.orders_metadata.insert(client_order_id, metadata);
309 self.venue_to_client_id
310 .insert(venue_order_id, client_order_id);
311
312 log::debug!(
313 "Registered external order {client_order_id} ({venue_order_id}) for {instrument_id} [{strategy_id}]"
314 );
315
316 true
317 }
318
319 pub async fn connect(&mut self, bearer_token: &str) -> AxOrdersWsResult<()> {
329 const MAX_RETRIES: u32 = 5;
330 const CONNECTION_TIMEOUT_SECS: u64 = 10;
331
332 self.signal.store(false, Ordering::Release);
333
334 let (raw_handler, raw_rx) = channel_message_handler();
335
336 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
338 });
340
341 let config = WebSocketConfig {
342 url: self.url.clone(),
343 headers: vec![
344 ("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string()),
345 (
346 "Authorization".to_string(),
347 format!("Bearer {bearer_token}"),
348 ),
349 ],
350 heartbeat: self.heartbeat,
351 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
353 reconnect_delay_initial_ms: Some(500),
354 reconnect_delay_max_ms: Some(5_000),
355 reconnect_backoff_factor: Some(1.5),
356 reconnect_jitter_ms: Some(250),
357 reconnect_max_attempts: None,
358 };
359
360 let mut backoff = ExponentialBackoff::new(
362 Duration::from_millis(500),
363 Duration::from_millis(5000),
364 2.0,
365 250,
366 false,
367 )
368 .map_err(|e| AxOrdersWsClientError::Transport(e.to_string()))?;
369
370 let mut last_error: String;
371 let mut attempt = 0;
372
373 let client = loop {
374 attempt += 1;
375
376 match tokio::time::timeout(
377 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
378 WebSocketClient::connect(
379 config.clone(),
380 Some(raw_handler.clone()),
381 Some(ping_handler.clone()),
382 None,
383 vec![],
384 None,
385 ),
386 )
387 .await
388 {
389 Ok(Ok(client)) => {
390 if attempt > 1 {
391 log::info!("WebSocket connection established after {attempt} attempts");
392 }
393 break client;
394 }
395 Ok(Err(e)) => {
396 last_error = e.to_string();
397 log::warn!(
398 "WebSocket connection attempt failed: attempt={attempt}, max_retries={MAX_RETRIES}, url={}, error={last_error}",
399 self.url
400 );
401 }
402 Err(_) => {
403 last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
404 log::warn!(
405 "WebSocket connection attempt timed out: attempt={attempt}, max_retries={MAX_RETRIES}, url={}",
406 self.url
407 );
408 }
409 }
410
411 if attempt >= MAX_RETRIES {
412 return Err(AxOrdersWsClientError::Transport(format!(
413 "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
414 self.url,
415 if last_error.is_empty() {
416 "unknown error"
417 } else {
418 &last_error
419 }
420 )));
421 }
422
423 let delay = backoff.next_duration();
424 log::debug!(
425 "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
426 attempt + 1
427 );
428 tokio::time::sleep(delay).await;
429 };
430
431 self.connection_mode.store(client.connection_mode_atomic());
432
433 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxOrdersWsMessage>();
434 self.out_rx = Some(Arc::new(out_rx));
435
436 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
437 *self.cmd_tx.write().await = cmd_tx.clone();
438
439 self.send_cmd(HandlerCommand::SetClient(client)).await?;
440
441 if !self.instruments_cache.is_empty() {
442 let cached_instruments: Vec<InstrumentAny> = self
443 .instruments_cache
444 .iter()
445 .map(|entry| entry.value().clone())
446 .collect();
447 self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
448 .await?;
449 }
450
451 self.send_cmd(HandlerCommand::Authenticate {
453 token: bearer_token.to_string(),
454 })
455 .await?;
456
457 let signal = Arc::clone(&self.signal);
458 let auth_tracker = self.auth_tracker.clone();
459 let account_id = self.account_id;
460 let orders_metadata = Arc::clone(&self.orders_metadata);
461 let venue_to_client_id = Arc::clone(&self.venue_to_client_id);
462 let cid_to_client_order_id = Arc::clone(&self.cid_to_client_order_id);
463
464 let stream_handle = get_runtime().spawn(async move {
465 let mut handler = FeedHandler::new(
466 signal.clone(),
467 cmd_rx,
468 raw_rx,
469 auth_tracker.clone(),
470 account_id,
471 orders_metadata,
472 venue_to_client_id,
473 cid_to_client_order_id,
474 );
475
476 while let Some(msg) = handler.next().await {
477 if matches!(msg, AxOrdersWsMessage::Reconnected) {
478 log::info!("WebSocket reconnected, authentication will be restored");
479 }
480
481 if out_tx.send(msg).is_err() {
482 log::debug!("Output channel closed");
483 break;
484 }
485 }
486
487 log::debug!("Handler loop exited");
488 });
489
490 self.task_handle = Some(stream_handle);
491
492 Ok(())
493 }
494
495 #[allow(clippy::too_many_arguments)]
510 pub async fn submit_order(
511 &self,
512 trader_id: TraderId,
513 strategy_id: StrategyId,
514 instrument_id: InstrumentId,
515 client_order_id: ClientOrderId,
516 order_side: OrderSide,
517 order_type: OrderType,
518 quantity: Quantity,
519 time_in_force: TimeInForce,
520 price: Option<Price>,
521 trigger_price: Option<Price>,
522 post_only: bool,
523 ) -> AxOrdersWsResult<i64> {
524 if !matches!(
525 order_type,
526 OrderType::Market | OrderType::Limit | OrderType::StopLimit
527 ) {
528 return Err(AxOrdersWsClientError::ClientError(format!(
529 "Unsupported order type: {order_type:?}. AX supports MARKET, LIMIT and STOP_LIMIT."
530 )));
531 }
532
533 let symbol = instrument_id.symbol.inner();
535 let instrument = self.get_cached_instrument(&symbol).ok_or_else(|| {
536 AxOrdersWsClientError::ClientError(format!(
537 "Instrument {instrument_id} not found in cache"
538 ))
539 })?;
540
541 let ax_side = AxOrderSide::try_from(order_side)?;
542
543 let qty_contracts = quantity_to_contracts(quantity)
544 .map_err(|e| AxOrdersWsClientError::ClientError(e.to_string()))?;
545
546 let request_id = self.next_request_id();
549
550 let (ax_price, ax_tif, ax_post_only, ax_order_type, ax_trigger_price) = match order_type {
551 OrderType::Market => {
552 let market_price = price.ok_or_else(|| {
553 AxOrdersWsClientError::ClientError(
554 "Market order requires price (calculated from quote)".to_string(),
555 )
556 })?;
557 (
558 market_price.as_decimal(),
559 AxTimeInForce::Ioc,
560 false,
561 None,
562 None,
563 )
564 }
565 OrderType::Limit => {
566 let ax_tif = AxTimeInForce::try_from(time_in_force)?;
567 let limit_price = price.ok_or_else(|| {
568 AxOrdersWsClientError::ClientError("Limit order requires price".to_string())
569 })?;
570 (limit_price.as_decimal(), ax_tif, post_only, None, None)
571 }
572 OrderType::StopLimit => {
573 let ax_tif = AxTimeInForce::try_from(time_in_force)?;
574 let limit_price = price.ok_or_else(|| {
575 AxOrdersWsClientError::ClientError(
576 "Stop-limit order requires price".to_string(),
577 )
578 })?;
579 let stop_price = trigger_price.ok_or_else(|| {
580 AxOrdersWsClientError::ClientError(
581 "Stop-limit order requires trigger price".to_string(),
582 )
583 })?;
584 (
585 limit_price.as_decimal(),
586 ax_tif,
587 false,
588 Some(AxOrderType::StopLossLimit),
589 Some(stop_price.as_decimal()),
590 )
591 }
592 _ => {
593 return Err(AxOrdersWsClientError::ClientError(format!(
594 "Unsupported order type: {order_type:?}"
595 )));
596 }
597 };
598
599 let metadata = OrderMetadata {
601 trader_id,
602 strategy_id,
603 instrument_id,
604 client_order_id,
605 venue_order_id: None,
606 ts_init: self.generate_ts_init(),
607 size_precision: instrument.size_precision(),
608 price_precision: instrument.price_precision(),
609 quote_currency: instrument.quote_currency(),
610 };
611 self.orders_metadata.insert(client_order_id, metadata);
612
613 let cid = client_order_id_to_cid(&client_order_id);
615 self.cid_to_client_order_id.insert(cid, client_order_id);
616
617 let order = AxWsPlaceOrder {
618 rid: request_id,
619 t: AxOrderRequestType::PlaceOrder,
620 s: symbol,
621 d: ax_side,
622 q: qty_contracts,
623 p: ax_price,
624 tif: ax_tif,
625 po: ax_post_only,
626 tag: Some(AX_NAUTILUS_TAG.to_string()),
627 cid: Some(cid),
628 order_type: ax_order_type,
629 trigger_price: ax_trigger_price,
630 };
631
632 let order_info = WsOrderInfo {
633 client_order_id,
634 symbol,
635 };
636
637 let result = self
638 .send_cmd(HandlerCommand::PlaceOrder {
639 request_id,
640 order,
641 order_info,
642 })
643 .await;
644
645 if result.is_err() {
646 self.orders_metadata.remove(&client_order_id);
647 self.cid_to_client_order_id.remove(&cid);
648 }
649
650 result?;
651 Ok(request_id)
652 }
653
654 pub async fn cancel_order(
662 &self,
663 client_order_id: ClientOrderId,
664 venue_order_id: Option<VenueOrderId>,
665 ) -> AxOrdersWsResult<i64> {
666 let order_id = venue_order_id.map(|v| v.to_string()).ok_or_else(|| {
667 AxOrdersWsClientError::ClientError(format!(
668 "Cannot cancel order {client_order_id}: missing venue_order_id"
669 ))
670 })?;
671
672 let request_id = self.next_request_id();
673
674 self.send_cmd(HandlerCommand::CancelOrder {
675 request_id,
676 order_id,
677 })
678 .await?;
679
680 Ok(request_id)
681 }
682
683 pub async fn get_open_orders(&self) -> AxOrdersWsResult<i64> {
689 let request_id = self.next_request_id();
690
691 self.send_cmd(HandlerCommand::GetOpenOrders { request_id })
692 .await?;
693
694 Ok(request_id)
695 }
696
697 pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxOrdersWsMessage> + 'static {
703 let rx = self
704 .out_rx
705 .take()
706 .expect("Stream receiver already taken or client not connected - stream() can only be called once");
707 let mut rx = Arc::try_unwrap(rx).expect(
708 "Cannot take ownership of stream - client was cloned and other references exist",
709 );
710 async_stream::stream! {
711 while let Some(msg) = rx.recv().await {
712 yield msg;
713 }
714 }
715 }
716
717 pub async fn disconnect(&self) {
719 log::debug!("Disconnecting WebSocket");
720 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
721 }
722
723 pub async fn close(&mut self) {
725 log::debug!("Closing WebSocket client");
726
727 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
729 tokio::time::sleep(Duration::from_millis(50)).await;
730 self.signal.store(true, Ordering::Release);
731
732 if let Some(handle) = self.task_handle.take() {
733 const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
734 let abort_handle = handle.abort_handle();
735
736 match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
737 Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
738 Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
739 Err(_) => {
740 log::warn!("Handler task did not complete within timeout, aborting");
741 abort_handle.abort();
742 }
743 }
744 }
745 }
746
747 async fn send_cmd(&self, cmd: HandlerCommand) -> AxOrdersWsResult<()> {
748 let guard = self.cmd_tx.read().await;
749 guard
750 .send(cmd)
751 .map_err(|e| AxOrdersWsClientError::ChannelError(e.to_string()))
752 }
753}
754
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Builds a disconnected client with fixed test identifiers.
    fn test_client() -> AxOrdersWebSocketClient {
        AxOrdersWebSocketClient::new(
            "wss://example.com/orders/ws".to_string(),
            AccountId::from("AX-001"),
            TraderId::from("TRADER-001"),
            Some(30),
        )
    }

    #[tokio::test]
    async fn test_cancel_order_rejects_without_venue_order_id() {
        let client = test_client();

        let outcome = client
            .cancel_order(ClientOrderId::from("CID-123"), None)
            .await;

        match outcome {
            Err(AxOrdersWsClientError::ClientError(msg)) => {
                assert!(
                    msg.contains("missing venue_order_id"),
                    "unexpected error message: {msg}"
                );
            }
            other => panic!("expected ClientError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_cancel_order_sends_known_venue_order_id() {
        let mut client = test_client();

        // Replace the placeholder channel with one whose receiver the test can observe.
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
        client.cmd_tx = Arc::new(tokio::sync::RwLock::new(sender));

        let issued_id = client
            .cancel_order(
                ClientOrderId::from("CID-456"),
                Some(VenueOrderId::from("V-ORDER-789")),
            )
            .await
            .unwrap();
        assert_eq!(issued_id, 1);

        let received = receiver.recv().await.unwrap();
        match received {
            HandlerCommand::CancelOrder {
                request_id,
                order_id,
            } => {
                assert_eq!(request_id, 1);
                assert_eq!(order_id, "V-ORDER-789");
            }
            unexpected => panic!("unexpected command: {unexpected:?}"),
        }
    }
}