nautilus_architect_ax/websocket/orders/
client.rs1use std::{
19 fmt::Debug,
20 sync::{
21 Arc,
22 atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23 },
24 time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::consts::NAUTILUS_USER_AGENT;
31use nautilus_model::{
32 identifiers::{AccountId, ClientOrderId},
33 instruments::{Instrument, InstrumentAny},
34};
35use nautilus_network::{
36 backoff::ExponentialBackoff,
37 mode::ConnectionMode,
38 websocket::{
39 AuthTracker, PingHandler, WebSocketClient, WebSocketConfig, channel_message_handler,
40 },
41};
42use rust_decimal::Decimal;
43use ustr::Ustr;
44
45use super::handler::{FeedHandler, HandlerCommand};
46use crate::{
47 common::enums::{AxOrderSide, AxTimeInForce},
48 websocket::messages::{AxOrdersWsMessage, AxWsPlaceOrder, OrderMetadata},
49};
50
/// Heartbeat interval (seconds, per the name) applied by
/// `AxOrdersWebSocketClient::new` when its `heartbeat` argument is `None`.
const DEFAULT_HEARTBEAT_SECS: u64 = 30;

/// Result alias used by every fallible operation of the orders WebSocket client.
pub type AxOrdersWsResult<T> = Result<T, AxOrdersWsClientError>;

/// Errors returned by the orders WebSocket client.
#[derive(Clone, Debug)]
pub enum AxOrdersWsClientError {
    /// Connection-level failure, e.g. exhausting connection retries or an
    /// invalid backoff configuration.
    Transport(String),
    /// A command could not be queued for the handler task.
    ChannelError(String),
    /// Authentication-related error, rendered as `Authentication error: …`.
    AuthenticationError(String),
}

impl core::fmt::Display for AxOrdersWsClientError {
    /// Renders the variant as a human-readable prefix followed by its message.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::Transport(ref msg) => write!(f, "Transport error: {msg}"),
            Self::ChannelError(ref msg) => write!(f, "Channel error: {msg}"),
            Self::AuthenticationError(ref msg) => write!(f, "Authentication error: {msg}"),
        }
    }
}

// Uses the default `source()` (none); the message alone describes the error.
impl std::error::Error for AxOrdersWsClientError {}

80pub struct AxOrdersWebSocketClient {
85 url: String,
86 heartbeat: Option<u64>,
87 connection_mode: Arc<ArcSwap<AtomicU8>>,
88 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
89 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxOrdersWsMessage>>>,
90 signal: Arc<AtomicBool>,
91 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
92 auth_tracker: AuthTracker,
93 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
94 request_id_counter: Arc<AtomicI64>,
95 account_id: AccountId,
96}
97
98impl Debug for AxOrdersWebSocketClient {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 f.debug_struct(stringify!(AxOrdersWebSocketClient))
101 .field("url", &self.url)
102 .field("heartbeat", &self.heartbeat)
103 .field("account_id", &self.account_id)
104 .finish()
105 }
106}
107
108impl Clone for AxOrdersWebSocketClient {
109 fn clone(&self) -> Self {
110 Self {
111 url: self.url.clone(),
112 heartbeat: self.heartbeat,
113 connection_mode: Arc::clone(&self.connection_mode),
114 cmd_tx: Arc::clone(&self.cmd_tx),
115 out_rx: None, signal: Arc::clone(&self.signal),
117 task_handle: None, auth_tracker: self.auth_tracker.clone(),
119 instruments_cache: Arc::clone(&self.instruments_cache),
120 request_id_counter: Arc::clone(&self.request_id_counter),
121 account_id: self.account_id,
122 }
123 }
124}
125
126impl AxOrdersWebSocketClient {
127 #[must_use]
129 pub fn new(url: String, account_id: AccountId, heartbeat: Option<u64>) -> Self {
130 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
131
132 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
133 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
134
135 Self {
136 url,
137 heartbeat: heartbeat.or(Some(DEFAULT_HEARTBEAT_SECS)),
138 connection_mode,
139 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
140 out_rx: None,
141 signal: Arc::new(AtomicBool::new(false)),
142 task_handle: None,
143 auth_tracker: AuthTracker::default(),
144 instruments_cache: Arc::new(DashMap::new()),
145 request_id_counter: Arc::new(AtomicI64::new(1)),
146 account_id,
147 }
148 }
149
150 #[must_use]
152 pub fn url(&self) -> &str {
153 &self.url
154 }
155
156 #[must_use]
158 pub fn account_id(&self) -> AccountId {
159 self.account_id
160 }
161
162 #[must_use]
164 pub fn is_active(&self) -> bool {
165 let connection_mode_arc = self.connection_mode.load();
166 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
167 && !self.signal.load(Ordering::Acquire)
168 }
169
170 #[must_use]
172 pub fn is_closed(&self) -> bool {
173 let connection_mode_arc = self.connection_mode.load();
174 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
175 || self.signal.load(Ordering::Acquire)
176 }
177
178 fn next_request_id(&self) -> i64 {
180 self.request_id_counter.fetch_add(1, Ordering::Relaxed)
181 }
182
183 pub fn cache_instrument(&self, instrument: InstrumentAny) {
185 let symbol = instrument.symbol().inner();
186 self.instruments_cache.insert(symbol, instrument.clone());
187
188 if self.is_active() {
190 let cmd = HandlerCommand::UpdateInstrument(Box::new(instrument));
191 let cmd_tx = self.cmd_tx.clone();
192 get_runtime().spawn(async move {
193 let guard = cmd_tx.read().await;
194 let _ = guard.send(cmd);
195 });
196 }
197 }
198
199 #[must_use]
201 pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
202 self.instruments_cache.get(symbol).map(|r| r.clone())
203 }
204
205 pub async fn connect(&mut self, bearer_token: &str) -> AxOrdersWsResult<()> {
215 const MAX_RETRIES: u32 = 5;
216 const CONNECTION_TIMEOUT_SECS: u64 = 10;
217
218 self.signal.store(false, Ordering::Relaxed);
219
220 let (raw_handler, raw_rx) = channel_message_handler();
221
222 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
224 });
226
227 let config = WebSocketConfig {
228 url: self.url.clone(),
229 headers: vec![
230 ("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string()),
231 (
232 "Authorization".to_string(),
233 format!("Bearer {bearer_token}"),
234 ),
235 ],
236 heartbeat: self.heartbeat,
237 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
239 reconnect_delay_initial_ms: Some(500),
240 reconnect_delay_max_ms: Some(5_000),
241 reconnect_backoff_factor: Some(1.5),
242 reconnect_jitter_ms: Some(250),
243 reconnect_max_attempts: None,
244 };
245
246 let mut backoff = ExponentialBackoff::new(
248 Duration::from_millis(500),
249 Duration::from_millis(5000),
250 2.0,
251 250,
252 false,
253 )
254 .map_err(|e| AxOrdersWsClientError::Transport(e.to_string()))?;
255
256 let mut last_error: String;
257 let mut attempt = 0;
258
259 let client = loop {
260 attempt += 1;
261
262 match tokio::time::timeout(
263 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
264 WebSocketClient::connect(
265 config.clone(),
266 Some(raw_handler.clone()),
267 Some(ping_handler.clone()),
268 None,
269 vec![],
270 None,
271 ),
272 )
273 .await
274 {
275 Ok(Ok(client)) => {
276 if attempt > 1 {
277 log::info!("WebSocket connection established after {attempt} attempts");
278 }
279 break client;
280 }
281 Ok(Err(e)) => {
282 last_error = e.to_string();
283 log::warn!(
284 "WebSocket connection attempt failed: attempt={attempt}, max_retries={MAX_RETRIES}, url={}, error={last_error}",
285 self.url
286 );
287 }
288 Err(_) => {
289 last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
290 log::warn!(
291 "WebSocket connection attempt timed out: attempt={attempt}, max_retries={MAX_RETRIES}, url={}",
292 self.url
293 );
294 }
295 }
296
297 if attempt >= MAX_RETRIES {
298 return Err(AxOrdersWsClientError::Transport(format!(
299 "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
300 self.url,
301 if last_error.is_empty() {
302 "unknown error"
303 } else {
304 &last_error
305 }
306 )));
307 }
308
309 let delay = backoff.next_duration();
310 log::debug!(
311 "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
312 attempt + 1
313 );
314 tokio::time::sleep(delay).await;
315 };
316
317 self.connection_mode.store(client.connection_mode_atomic());
318
319 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxOrdersWsMessage>();
320 self.out_rx = Some(Arc::new(out_rx));
321
322 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
323 *self.cmd_tx.write().await = cmd_tx.clone();
324
325 self.send_cmd(HandlerCommand::SetClient(client)).await?;
326
327 if !self.instruments_cache.is_empty() {
328 let cached_instruments: Vec<InstrumentAny> = self
329 .instruments_cache
330 .iter()
331 .map(|entry| entry.value().clone())
332 .collect();
333 self.send_cmd(HandlerCommand::InitializeInstruments(cached_instruments))
334 .await?;
335 }
336
337 self.send_cmd(HandlerCommand::Authenticate {
339 token: bearer_token.to_string(),
340 })
341 .await?;
342
343 let signal = Arc::clone(&self.signal);
344 let auth_tracker = self.auth_tracker.clone();
345
346 let stream_handle = get_runtime().spawn(async move {
347 let mut handler = FeedHandler::new(
348 signal.clone(),
349 cmd_rx,
350 raw_rx,
351 out_tx.clone(),
352 auth_tracker.clone(),
353 );
354
355 while let Some(msg) = handler.next().await {
356 if matches!(msg, AxOrdersWsMessage::Reconnected) {
357 log::info!("WebSocket reconnected");
358 }
360
361 if out_tx.send(msg).is_err() {
362 log::debug!("Output channel closed");
363 break;
364 }
365 }
366
367 log::debug!("Handler loop exited");
368 });
369
370 self.task_handle = Some(Arc::new(stream_handle));
371
372 Ok(())
373 }
374
375 #[allow(clippy::too_many_arguments)]
381 pub async fn place_order(
382 &self,
383 client_order_id: ClientOrderId,
384 symbol: Ustr,
385 side: AxOrderSide,
386 quantity: i64,
387 price: Decimal,
388 time_in_force: AxTimeInForce,
389 post_only: bool,
390 tag: Option<String>,
391 ) -> AxOrdersWsResult<i64> {
392 let request_id = self.next_request_id();
393
394 let order = AxWsPlaceOrder {
395 rid: request_id,
396 t: "p".to_string(),
397 s: symbol.to_string(),
398 d: side,
399 q: quantity,
400 p: price,
401 tif: time_in_force,
402 po: post_only,
403 tag,
404 };
405
406 let metadata = OrderMetadata {
407 client_order_id,
408 symbol,
409 };
410
411 self.send_cmd(HandlerCommand::PlaceOrder {
412 request_id,
413 order,
414 metadata,
415 })
416 .await?;
417
418 Ok(request_id)
419 }
420
421 pub async fn cancel_order(&self, order_id: &str) -> AxOrdersWsResult<i64> {
427 let request_id = self.next_request_id();
428
429 self.send_cmd(HandlerCommand::CancelOrder {
430 request_id,
431 order_id: order_id.to_string(),
432 })
433 .await?;
434
435 Ok(request_id)
436 }
437
438 pub async fn get_open_orders(&self) -> AxOrdersWsResult<i64> {
444 let request_id = self.next_request_id();
445
446 self.send_cmd(HandlerCommand::GetOpenOrders { request_id })
447 .await?;
448
449 Ok(request_id)
450 }
451
452 pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxOrdersWsMessage> + use<'_> {
458 let rx = self
459 .out_rx
460 .take()
461 .expect("Stream receiver already taken or client not connected");
462 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
463 async_stream::stream! {
464 while let Some(msg) = rx.recv().await {
465 yield msg;
466 }
467 }
468 }
469
470 pub async fn disconnect(&self) {
472 log::debug!("Disconnecting WebSocket");
473 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
474 }
475
476 pub async fn close(&mut self) {
478 log::debug!("Closing WebSocket client");
479 self.signal.store(true, Ordering::Relaxed);
480
481 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
482
483 if let Some(handle) = self.task_handle.take() {
484 const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
485
486 match tokio::time::timeout(CLOSE_TIMEOUT, async {
487 loop {
488 if Arc::strong_count(&handle) == 1 {
489 break;
490 }
491 tokio::time::sleep(Duration::from_millis(50)).await;
492 }
493 })
494 .await
495 {
496 Ok(()) => log::debug!("Handler task completed gracefully"),
497 Err(_) => {
498 log::warn!("Handler task did not complete within timeout, aborting");
499 handle.abort();
500 }
501 }
502 }
503 }
504
505 async fn send_cmd(&self, cmd: HandlerCommand) -> AxOrdersWsResult<()> {
506 let guard = self.cmd_tx.read().await;
507 guard
508 .send(cmd)
509 .map_err(|e| AxOrdersWsClientError::ChannelError(e.to_string()))
510 }
511}