1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 num::NonZeroU32,
29 sync::{
30 Arc, LazyLock,
31 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
32 },
33 time::{Duration, SystemTime},
34};
35
36use ahash::{AHashMap, AHashSet};
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::runtime::get_runtime;
40use nautilus_core::{
41 UUID4,
42 consts::NAUTILUS_USER_AGENT,
43 env::{get_env_var, get_or_env_var},
44 nanos::UnixNanos,
45 time::get_atomic_clock_realtime,
46};
47use nautilus_model::{
48 data::BarType,
49 enums::{OrderSide, OrderStatus, OrderType, PositionSide, TimeInForce, TriggerType},
50 events::{AccountState, OrderCancelRejected, OrderModifyRejected, OrderRejected},
51 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
52 instruments::{Instrument, InstrumentAny},
53 types::{Money, Price, Quantity},
54};
55use nautilus_network::{
56 RECONNECTED,
57 ratelimiter::quota::Quota,
58 retry::{RetryManager, create_websocket_retry_manager},
59 websocket::{
60 PingHandler, TEXT_PING, TEXT_PONG, WebSocketClient, WebSocketConfig,
61 channel_message_handler,
62 },
63};
64use reqwest::header::USER_AGENT;
65use serde_json::Value;
66use tokio::sync::mpsc::UnboundedReceiver;
67use tokio_tungstenite::tungstenite::{Error, Message};
68use tokio_util::sync::CancellationToken;
69use ustr::Ustr;
70
71use super::{
72 auth::{AUTHENTICATION_TIMEOUT_SECS, AuthTracker},
73 enums::{OKXSubscriptionEvent, OKXWsChannel, OKXWsOperation},
74 error::OKXWsError,
75 messages::{
76 ExecutionReport, NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg,
77 OKXSubscription, OKXSubscriptionArg, OKXWebSocketArg, OKXWebSocketError, OKXWebSocketEvent,
78 OKXWsRequest, WsAmendOrderParams, WsAmendOrderParamsBuilder, WsCancelAlgoOrderParams,
79 WsCancelAlgoOrderParamsBuilder, WsCancelOrderParams, WsCancelOrderParamsBuilder,
80 WsMassCancelParams, WsPostAlgoOrderParams, WsPostAlgoOrderParamsBuilder, WsPostOrderParams,
81 WsPostOrderParamsBuilder,
82 },
83 parse::{parse_book_msg_vec, parse_ws_message_data},
84 subscription::{SubscriptionState, topic_from_subscription_arg, topic_from_websocket_arg},
85};
86use crate::{
87 common::{
88 consts::{
89 OKX_NAUTILUS_BROKER_ID, OKX_POST_ONLY_CANCEL_REASON, OKX_POST_ONLY_CANCEL_SOURCE,
90 OKX_POST_ONLY_ERROR_CODE, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
91 OKX_TARGET_CCY_BASE, OKX_TARGET_CCY_QUOTE, OKX_WS_PUBLIC_URL, should_retry_error_code,
92 },
93 credential::Credential,
94 enums::{
95 OKXInstrumentType, OKXOrderStatus, OKXOrderType, OKXPositionSide, OKXTradeMode,
96 OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type, is_conditional_order,
97 },
98 parse::{
99 bar_spec_as_okx_channel, okx_instrument_type, parse_account_state,
100 parse_client_order_id, parse_millisecond_timestamp,
101 },
102 },
103 http::models::OKXAccount,
104 websocket::{
105 messages::{OKXAlgoOrderMsg, OKXOrderMsg},
106 parse::{parse_algo_order_msg, parse_order_msg},
107 },
108};
109
110type PlaceRequestData = (ClientOrderId, TraderId, StrategyId, InstrumentId);
111type CancelRequestData = (
112 ClientOrderId,
113 TraderId,
114 StrategyId,
115 InstrumentId,
116 Option<VenueOrderId>,
117);
118type AmendRequestData = (
119 ClientOrderId,
120 TraderId,
121 StrategyId,
122 InstrumentId,
123 Option<VenueOrderId>,
124);
125type MassCancelRequestData = InstrumentId;
126
127pub static OKX_WS_QUOTA: LazyLock<Quota> =
135 LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
136
137pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
142 LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
143
144fn should_retry_okx_error(error: &OKXWsError) -> bool {
146 match error {
147 OKXWsError::OkxError { error_code, .. } => should_retry_error_code(error_code),
148 OKXWsError::TungsteniteError(_) => true, OKXWsError::ClientError(msg) => {
150 let msg_lower = msg.to_lowercase();
152 msg_lower.contains("timeout")
153 || msg_lower.contains("timed out")
154 || msg_lower.contains("connection")
155 || msg_lower.contains("network")
156 }
157 OKXWsError::AuthenticationError(_)
158 | OKXWsError::JsonError(_)
159 | OKXWsError::ParsingError(_) => {
160 false
162 }
163 }
164}
165
166fn create_okx_timeout_error(msg: String) -> OKXWsError {
168 OKXWsError::ClientError(msg)
169}
170
171fn channel_requires_auth(channel: &OKXWsChannel) -> bool {
172 matches!(
173 channel,
174 OKXWsChannel::Account
175 | OKXWsChannel::Orders
176 | OKXWsChannel::Fills
177 | OKXWsChannel::OrdersAlgo
178 )
179}
180
/// WebSocket client for the OKX public/private streams and order-entry operations.
///
/// Most mutable state lives behind `Arc`s, so `Clone` produces handles that share the
/// connection, subscription bookkeeping and pending-request maps.
/// NOTE(review): a clone taken while `rx` is still `Some` keeps an extra `Arc` to the
/// receiver, which makes a later `stream()` call panic in `Arc::try_unwrap`.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
pub struct OKXWebSocketClient {
    // Endpoint URL; defaults to `OKX_WS_PUBLIC_URL` in `new`.
    url: String,
    // Account the message handler attributes account/execution reports to.
    account_id: AccountId,
    // Raw `OKXVipLevel` discriminant, shared atomically across clones.
    vip_level: Arc<AtomicU8>,
    // API key/secret/passphrase; `None` for an unauthenticated (public-only) client.
    credential: Option<Credential>,
    // Heartbeat setting forwarded to `WebSocketConfig::heartbeat`.
    heartbeat: Option<u64>,
    // Underlying connection; `None` until `connect` stores the client.
    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
    // Tracks the outcome of login requests (begin / fail / wait_for_result).
    auth_tracker: AuthTracker,
    // Receiver for parsed messages; taken (once) by `stream()`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    // Stop flag set by `close()` and handed to the message handler.
    signal: Arc<AtomicBool>,
    // Handle of the message-processing task spawned by `connect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    // Subscriptions recorded by `subscribe`, grouped by filter kind and replayed after
    // reconnects.
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    // Channels subscribed without any instrument filter.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    // Per-topic subscribe/unsubscribe state (pending vs confirmed).
    subscriptions_state: SubscriptionState,
    // Monotonic counter backing `generate_unique_request_id` (starts at 1).
    request_id_counter: Arc<AtomicU64>,
    // In-flight order requests, keyed by a request-id string, shared with the handler.
    pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
    pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
    pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
    pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
    // Owning trader/strategy/instrument for each live client order ID.
    active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
    // Maps one client order ID to another; presumably venue-side IDs to the original
    // ones. TODO confirm against the message handler.
    client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
    // Instruments by symbol; replaced wholesale by `initialize_instruments_cache`.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
    // Retry policy for WebSocket operations (see `should_retry_okx_error`).
    retry_manager: Arc<RetryManager<OKXWsError>>,
    // Cancelled by `cancel_all_requests`; consumers are outside this view.
    cancellation_token: CancellationToken,
}
214
215impl Default for OKXWebSocketClient {
216 fn default() -> Self {
217 Self::new(None, None, None, None, None, None).unwrap()
218 }
219}
220
221impl Debug for OKXWebSocketClient {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct(stringify!(OKXWebSocketClient))
224 .field("url", &self.url)
225 .field(
226 "credential",
227 &self.credential.as_ref().map(|_| "<redacted>"),
228 )
229 .field("heartbeat", &self.heartbeat)
230 .finish_non_exhaustive()
231 }
232}
233
234impl OKXWebSocketClient {
235 pub fn new(
241 url: Option<String>,
242 api_key: Option<String>,
243 api_secret: Option<String>,
244 api_passphrase: Option<String>,
245 account_id: Option<AccountId>,
246 heartbeat: Option<u64>,
247 ) -> anyhow::Result<Self> {
248 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
249 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
250
251 let credential = match (api_key, api_secret, api_passphrase) {
252 (Some(key), Some(secret), Some(passphrase)) => {
253 Some(Credential::new(key, secret, passphrase))
254 }
255 (None, None, None) => None,
256 _ => anyhow::bail!(
257 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
258 ),
259 };
260
261 let signal = Arc::new(AtomicBool::new(false));
262 let subscriptions_inst_type = Arc::new(DashMap::new());
263 let subscriptions_inst_family = Arc::new(DashMap::new());
264 let subscriptions_inst_id = Arc::new(DashMap::new());
265 let subscriptions_bare = Arc::new(DashMap::new());
266 let subscriptions_state = SubscriptionState::new();
267
268 Ok(Self {
269 url,
270 account_id,
271 vip_level: Arc::new(AtomicU8::new(0)), credential,
273 heartbeat,
274 inner: Arc::new(tokio::sync::RwLock::new(None)),
275 auth_tracker: AuthTracker::new(),
276 rx: None,
277 signal,
278 task_handle: None,
279 subscriptions_inst_type,
280 subscriptions_inst_family,
281 subscriptions_inst_id,
282 subscriptions_bare,
283 subscriptions_state,
284 request_id_counter: Arc::new(AtomicU64::new(1)),
285 pending_place_requests: Arc::new(DashMap::new()),
286 pending_cancel_requests: Arc::new(DashMap::new()),
287 pending_amend_requests: Arc::new(DashMap::new()),
288 pending_mass_cancel_requests: Arc::new(DashMap::new()),
289 active_client_orders: Arc::new(DashMap::new()),
290 client_id_aliases: Arc::new(DashMap::new()),
291 instruments_cache: Arc::new(AHashMap::new()),
292 retry_manager: Arc::new(create_websocket_retry_manager()?),
293 cancellation_token: CancellationToken::new(),
294 })
295 }
296
297 pub fn with_credentials(
304 url: Option<String>,
305 api_key: Option<String>,
306 api_secret: Option<String>,
307 api_passphrase: Option<String>,
308 account_id: Option<AccountId>,
309 heartbeat: Option<u64>,
310 ) -> anyhow::Result<Self> {
311 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
312 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
313 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
314 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
315
316 Self::new(
317 Some(url),
318 Some(api_key),
319 Some(api_secret),
320 Some(api_passphrase),
321 account_id,
322 heartbeat,
323 )
324 }
325
326 pub fn from_env() -> anyhow::Result<Self> {
333 let url = get_env_var("OKX_WS_URL")?;
334 let api_key = get_env_var("OKX_API_KEY")?;
335 let api_secret = get_env_var("OKX_API_SECRET")?;
336 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
337
338 Self::new(
339 Some(url),
340 Some(api_key),
341 Some(api_secret),
342 Some(api_passphrase),
343 None,
344 None,
345 )
346 }
347
348 pub fn cancel_all_requests(&self) {
350 self.cancellation_token.cancel();
351 }
352
353 pub fn cancellation_token(&self) -> &CancellationToken {
355 &self.cancellation_token
356 }
357
358 pub fn url(&self) -> &str {
360 self.url.as_str()
361 }
362
363 pub fn api_key(&self) -> Option<&str> {
365 self.credential.clone().map(|c| c.api_key.as_str())
366 }
367
368 pub fn is_active(&self) -> bool {
371 match self.inner.try_read() {
373 Ok(guard) => match &*guard {
374 Some(inner) => inner.is_active(),
375 None => false,
376 },
377 Err(_) => false, }
379 }
380
381 pub fn is_closed(&self) -> bool {
383 match self.inner.try_read() {
385 Ok(guard) => match &*guard {
386 Some(inner) => inner.is_closed(),
387 None => true,
388 },
389 Err(_) => true, }
391 }
392
393 pub fn initialize_instruments_cache(&mut self, instruments: Vec<InstrumentAny>) {
395 let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
396 for inst in instruments {
397 instruments_cache.insert(inst.symbol().inner(), inst.clone());
398 }
399
400 self.instruments_cache = Arc::new(instruments_cache)
401 }
402
403 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
407 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
408 }
409
410 pub fn vip_level(&self) -> OKXVipLevel {
412 let level = self.vip_level.load(Ordering::Relaxed);
413 OKXVipLevel::from(level)
414 }
415
416 pub async fn connect(&mut self) -> anyhow::Result<()> {
426 let (message_handler, reader) = channel_message_handler();
427
428 let inner_for_ping = self.inner.clone();
429 let ping_handler: PingHandler = Arc::new(move |payload: Vec<u8>| {
430 let inner = inner_for_ping.clone();
431
432 get_runtime().spawn(async move {
433 let len = payload.len();
434 let guard = inner.read().await;
435
436 if let Some(client) = guard.as_ref() {
437 if let Err(err) = client.send_pong(payload).await {
438 tracing::warn!(error = %err, "Failed to send pong frame");
439 } else {
440 tracing::trace!("Sent pong frame ({len} bytes)");
441 }
442 } else {
443 tracing::debug!("Ping received with no active websocket client");
444 }
445 });
446 });
447
448 let config = WebSocketConfig {
449 url: self.url.clone(),
450 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
451 heartbeat: self.heartbeat,
452 heartbeat_msg: Some(TEXT_PING.to_string()),
453 message_handler: Some(message_handler),
454 ping_handler: Some(ping_handler),
455 reconnect_timeout_ms: Some(5_000),
456 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, };
461
462 let keyed_quotas = vec![
464 ("subscription".to_string(), *OKX_WS_QUOTA),
465 ("order".to_string(), *OKX_WS_ORDER_QUOTA),
466 ("cancel".to_string(), *OKX_WS_ORDER_QUOTA),
467 ("amend".to_string(), *OKX_WS_ORDER_QUOTA),
468 ];
469
470 let client = WebSocketClient::connect(
471 config,
472 None, keyed_quotas,
474 Some(*OKX_WS_QUOTA), )
476 .await?;
477
478 {
480 let mut inner_guard = self.inner.write().await;
481 *inner_guard = Some(client);
482 }
483
484 let account_id = self.account_id;
485 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
486
487 self.rx = Some(Arc::new(rx));
488 let signal = self.signal.clone();
489 let pending_place_requests = self.pending_place_requests.clone();
490 let pending_cancel_requests = self.pending_cancel_requests.clone();
491 let pending_amend_requests = self.pending_amend_requests.clone();
492 let pending_mass_cancel_requests = self.pending_mass_cancel_requests.clone();
493 let active_client_orders = self.active_client_orders.clone();
494 let auth_tracker = self.auth_tracker.clone();
495
496 let instruments_cache = self.instruments_cache.clone();
497 let inner_client = self.inner.clone();
498 let credential_clone = self.credential.clone();
499 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
500 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
501 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
502 let subscriptions_bare = self.subscriptions_bare.clone();
503 let subscriptions_state = self.subscriptions_state.clone();
504 let client_id_aliases = self.client_id_aliases.clone();
505
506 let stream_handle = get_runtime().spawn({
507 let auth_tracker = auth_tracker.clone();
508 async move {
509 let mut handler = OKXWsMessageHandler::new(
510 account_id,
511 instruments_cache,
512 reader,
513 signal,
514 inner_client.clone(),
515 tx,
516 pending_place_requests,
517 pending_cancel_requests,
518 pending_amend_requests,
519 pending_mass_cancel_requests,
520 active_client_orders,
521 client_id_aliases,
522 auth_tracker.clone(),
523 subscriptions_state.clone(),
524 );
525
526 loop {
528 match handler.next().await {
529 Some(NautilusWsMessage::Reconnected) => {
530 tracing::info!("Handling WebSocket reconnection");
531
532 let auth_tracker_for_task = auth_tracker.clone();
533 let inner_client_for_task = inner_client.clone();
534 let subscriptions_inst_type_for_task = subscriptions_inst_type.clone();
535 let subscriptions_inst_family_for_task = subscriptions_inst_family.clone();
536 let subscriptions_inst_id_for_task = subscriptions_inst_id.clone();
537 let subscriptions_bare_for_task = subscriptions_bare.clone();
538 let subscriptions_state_for_task = subscriptions_state.clone();
539
540 let auth_wait = if let Some(cred) = &credential_clone {
541 let rx = auth_tracker.begin();
542 let inner_guard = inner_client.read().await;
543
544 if let Some(client) = &*inner_guard {
545 let timestamp = SystemTime::now()
546 .duration_since(SystemTime::UNIX_EPOCH)
547 .expect("System time should be after UNIX epoch")
548 .as_secs()
549 .to_string();
550 let signature =
551 cred.sign(×tamp, "GET", "/users/self/verify", "");
552
553 let auth_message = OKXAuthentication {
554 op: "login",
555 args: vec![OKXAuthenticationArg {
556 api_key: cred.api_key.to_string(),
557 passphrase: cred.api_passphrase.clone(),
558 timestamp,
559 sign: signature,
560 }],
561 };
562
563 if let Err(e) = client
564 .send_text(serde_json::to_string(&auth_message).unwrap(), None)
565 .await
566 {
567 tracing::error!(
568 "Failed to send re-authentication request: {e}",
569 );
570 auth_tracker.fail(e.to_string());
571 } else {
572 tracing::info!(
573 "Sent re-authentication request, waiting for response before resubscribing",
574 );
575 }
576 } else {
577 auth_tracker
578 .fail("Cannot authenticate: not connected".to_string());
579 }
580
581 drop(inner_guard);
582
583 Some(rx)
584 } else {
585 None
586 };
587
588 get_runtime().spawn(async move {
589 let auth_succeeded = match auth_wait {
590 Some(rx) => match auth_tracker_for_task
591 .wait_for_result(
592 Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS),
593 rx,
594 )
595 .await
596 {
597 Ok(()) => {
598 tracing::info!(
599 "Authentication successful after reconnect, proceeding with resubscription",
600 );
601 true
602 }
603 Err(e) => {
604 tracing::error!(
605 "Authentication after reconnect failed: {e}",
606 );
607 false
608 }
609 },
610 None => true,
611 };
612
613 let confirmed_topic_count = subscriptions_state_for_task.len();
614 if confirmed_topic_count == 0 {
615 tracing::debug!(
616 "No confirmed subscriptions recorded before reconnect; resubscribe will rely on pending topics"
617 );
618 } else {
619 tracing::debug!(confirmed_topic_count, "Confirmed subscriptions recorded before reconnect");
620 }
621 let confirmed_topics = subscriptions_state_for_task.confirmed();
622 if confirmed_topic_count <= 10 {
623 let topics: Vec<_> = confirmed_topics
624 .iter()
625 .map(|entry| entry.key().clone())
626 .collect();
627 if !topics.is_empty() {
628 tracing::trace!(topics = ?topics, "Confirmed topics before reconnect");
629 }
630 }
631 drop(confirmed_topics);
632
633 let pending_topics = subscriptions_state_for_task.pending();
634 let pending_topic_count = pending_topics.len();
635 if pending_topic_count > 0 {
636 tracing::debug!(pending_topic_count, "Pending subscriptions awaiting replay after reconnect");
637 }
638 drop(pending_topics);
639
640 let inner_guard = inner_client_for_task.read().await;
641 if let Some(client) = &*inner_guard {
642 let should_resubscribe = |channel: &OKXWsChannel| {
643 if channel_requires_auth(channel) && !auth_succeeded {
644 tracing::warn!(
645 ?channel,
646 "Skipping private channel resubscription due to missing authentication",
647 );
648 return false;
649 }
650 true
651 };
652
653 let mut inst_type_args = Vec::new();
654 for entry in subscriptions_inst_type_for_task.iter() {
655 let (channel, inst_types) = entry.pair();
656 if !should_resubscribe(channel) {
657 continue;
658 }
659 for inst_type in inst_types.iter() {
660 let arg = OKXSubscriptionArg {
661 channel: channel.clone(),
662 inst_type: Some(*inst_type),
663 inst_family: None,
664 inst_id: None,
665 };
666 let topic = topic_from_subscription_arg(&arg);
667 subscriptions_state_for_task.mark_subscribe(&topic);
668 inst_type_args.push(arg);
669 }
670 }
671 if !inst_type_args.is_empty() {
672 let sub_request = OKXSubscription {
673 op: OKXWsOperation::Subscribe,
674 args: inst_type_args,
675 };
676 if let Err(e) = client
677 .send_text(
678 serde_json::to_string(&sub_request).unwrap(),
679 None,
680 )
681 .await
682 {
683 tracing::error!(
684 "Failed to re-subscribe inst_type channels: {e}",
685 );
686 }
687 }
688
689 let mut inst_family_args = Vec::new();
690 for entry in subscriptions_inst_family_for_task.iter() {
691 let (channel, inst_families) = entry.pair();
692 if !should_resubscribe(channel) {
693 continue;
694 }
695 for inst_family in inst_families.iter() {
696 let arg = OKXSubscriptionArg {
697 channel: channel.clone(),
698 inst_type: None,
699 inst_family: Some(*inst_family),
700 inst_id: None,
701 };
702 let topic = topic_from_subscription_arg(&arg);
703 subscriptions_state_for_task.mark_subscribe(&topic);
704 inst_family_args.push(arg);
705 }
706 }
707 if !inst_family_args.is_empty() {
708 let sub_request = OKXSubscription {
709 op: OKXWsOperation::Subscribe,
710 args: inst_family_args,
711 };
712 if let Err(e) = client
713 .send_text(
714 serde_json::to_string(&sub_request).unwrap(),
715 None,
716 )
717 .await
718 {
719 tracing::error!(
720 "Failed to re-subscribe inst_family channels: {e}",
721 );
722 }
723 }
724
725 let mut inst_id_args = Vec::new();
726 for entry in subscriptions_inst_id_for_task.iter() {
727 let (channel, inst_ids) = entry.pair();
728 if !should_resubscribe(channel) {
729 continue;
730 }
731 for inst_id in inst_ids.iter() {
732 let arg = OKXSubscriptionArg {
733 channel: channel.clone(),
734 inst_type: None,
735 inst_family: None,
736 inst_id: Some(*inst_id),
737 };
738 let topic = topic_from_subscription_arg(&arg);
739 subscriptions_state_for_task.mark_subscribe(&topic);
740 inst_id_args.push(arg);
741 }
742 }
743 if !inst_id_args.is_empty() {
744 let sub_request = OKXSubscription {
745 op: OKXWsOperation::Subscribe,
746 args: inst_id_args,
747 };
748 if let Err(e) = client
749 .send_text(
750 serde_json::to_string(&sub_request).unwrap(),
751 None,
752 )
753 .await
754 {
755 tracing::error!(
756 "Failed to re-subscribe inst_id channels: {e}",
757 );
758 }
759 }
760
761 let mut bare_args = Vec::new();
762 for entry in subscriptions_bare_for_task.iter() {
763 let channel = entry.key();
764 if !should_resubscribe(channel) {
765 continue;
766 }
767 let arg = OKXSubscriptionArg {
768 channel: channel.clone(),
769 inst_type: None,
770 inst_family: None,
771 inst_id: None,
772 };
773 let topic = topic_from_subscription_arg(&arg);
774 subscriptions_state_for_task.mark_subscribe(&topic);
775 bare_args.push(arg);
776 }
777 if !bare_args.is_empty() {
778 let sub_request = OKXSubscription {
779 op: OKXWsOperation::Subscribe,
780 args: bare_args,
781 };
782 if let Err(e) = client
783 .send_text(
784 serde_json::to_string(&sub_request).unwrap(),
785 None,
786 )
787 .await
788 {
789 tracing::error!(
790 "Failed to re-subscribe bare channels: {e}",
791 );
792 }
793 }
794
795 tracing::info!("Completed re-subscription after reconnect");
796 } else {
797 tracing::warn!(
798 "Skipping resubscription after reconnect: websocket client unavailable",
799 );
800 }
801 });
802
803 continue;
804 }
805 Some(msg) => {
806 if handler.tx.send(msg).is_err() {
807 tracing::error!(
808 "Failed to send message through channel: receiver dropped",
809 );
810 break;
811 }
812 }
813 None => {
814 if handler.is_stopped() {
815 tracing::debug!(
816 "Stop signal received, ending message processing",
817 );
818 break;
819 }
820 tracing::warn!("WebSocket stream ended unexpectedly");
821 break;
822 }
823 }
824 }
825 }
826 });
827
828 self.task_handle = Some(Arc::new(stream_handle));
829
830 if self.credential.is_some() {
831 self.authenticate().await?;
832 }
833
834 Ok(())
835 }
836
837 async fn authenticate(&self) -> Result<(), Error> {
839 let credential = self.credential.as_ref().ok_or_else(|| {
840 Error::Io(std::io::Error::other(
841 "API credentials not available to authenticate",
842 ))
843 })?;
844
845 let rx = self.auth_tracker.begin();
846
847 let timestamp = SystemTime::now()
848 .duration_since(SystemTime::UNIX_EPOCH)
849 .expect("System time should be after UNIX epoch")
850 .as_secs()
851 .to_string();
852 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
853
854 let auth_message = OKXAuthentication {
855 op: "login",
856 args: vec![OKXAuthenticationArg {
857 api_key: credential.api_key.to_string(),
858 passphrase: credential.api_passphrase.clone(),
859 timestamp,
860 sign: signature,
861 }],
862 };
863
864 {
865 let inner_guard = self.inner.read().await;
866 if let Some(inner) = &*inner_guard {
867 if let Err(e) = inner
868 .send_text(serde_json::to_string(&auth_message).unwrap(), None)
869 .await
870 {
871 tracing::error!("Error sending auth message: {e:?}");
872 self.auth_tracker.fail(e.to_string());
873 return Err(Error::Io(std::io::Error::other(e.to_string())));
874 }
875 } else {
876 log::error!("Cannot authenticate: not connected");
877 self.auth_tracker
878 .fail("Cannot authenticate: not connected".to_string());
879 return Err(Error::ConnectionClosed);
880 }
881 }
882
883 match self
884 .auth_tracker
885 .wait_for_result(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
886 .await
887 {
888 Ok(()) => {
889 tracing::info!("Authentication confirmed by client");
890 Ok(())
891 }
892 Err(e) => {
893 tracing::error!("Authentication failed: {e}");
894 Err(Error::Io(std::io::Error::other(e.to_string())))
895 }
896 }
897 }
898
899 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
907 let rx = self
908 .rx
909 .take()
910 .expect("Data stream receiver already taken or not connected");
911 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
912 async_stream::stream! {
913 while let Some(data) = rx.recv().await {
914 yield data;
915 }
916 }
917 }
918
919 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
925 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
926
927 tokio::time::timeout(timeout, async {
928 while !self.is_active() {
929 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
930 }
931 })
932 .await
933 .map_err(|_| {
934 OKXWsError::ClientError(format!(
935 "WebSocket connection timeout after {timeout_secs} seconds"
936 ))
937 })?;
938
939 Ok(())
940 }
941
942 pub async fn close(&mut self) -> Result<(), Error> {
949 log::debug!("Starting close process");
950
951 self.signal.store(true, Ordering::Relaxed);
952
953 {
954 let inner_guard = self.inner.read().await;
955 if let Some(inner) = &*inner_guard {
956 log::debug!("Disconnecting websocket");
957
958 match tokio::time::timeout(Duration::from_secs(3), inner.disconnect()).await {
959 Ok(()) => log::debug!("Websocket disconnected successfully"),
960 Err(_) => {
961 log::warn!(
962 "Timeout waiting for websocket disconnect, continuing with cleanup"
963 )
964 }
965 }
966 } else {
967 log::debug!("No active connection to disconnect");
968 }
969 }
970
971 if let Some(stream_handle) = self.task_handle.take() {
973 match Arc::try_unwrap(stream_handle) {
974 Ok(handle) => {
975 log::debug!("Waiting for stream handle to complete");
976 match tokio::time::timeout(Duration::from_secs(2), handle).await {
977 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
978 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
979 Err(_) => {
980 log::warn!(
981 "Timeout waiting for stream handle, task may still be running"
982 );
983 }
985 }
986 }
987 Err(arc_handle) => {
988 log::debug!(
989 "Cannot take ownership of stream handle - other references exist, aborting task"
990 );
991 arc_handle.abort();
992 }
993 }
994 } else {
995 log::debug!("No stream handle to await");
996 }
997
998 log::debug!("Close process completed");
999
1000 Ok(())
1001 }
1002
1003 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
1005 let symbol = instrument_id.symbol.inner();
1006 let mut channels = Vec::new();
1007
1008 for entry in self.subscriptions_inst_id.iter() {
1009 let (channel, instruments) = entry.pair();
1010 if instruments.contains(&symbol) {
1011 channels.push(channel.clone());
1012 }
1013 }
1014
1015 channels
1016 }
1017
1018 fn generate_unique_request_id(&self) -> String {
1019 self.request_id_counter
1020 .fetch_add(1, Ordering::SeqCst)
1021 .to_string()
1022 }
1023
1024 #[allow(
1025 clippy::result_large_err,
1026 reason = "OKXWsError contains large tungstenite::Error variant"
1027 )]
1028 fn get_instrument_type_and_family(
1029 &self,
1030 symbol: Ustr,
1031 ) -> Result<(OKXInstrumentType, String), OKXWsError> {
1032 let instrument = self.instruments_cache.get(&symbol).ok_or_else(|| {
1034 OKXWsError::ClientError(format!("Instrument not found in cache: {symbol}"))
1035 })?;
1036
1037 let inst_type =
1038 okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1039
1040 let inst_family = match instrument {
1042 InstrumentAny::CurrencyPair(_) => symbol.as_str().to_string(),
1043 InstrumentAny::CryptoPerpetual(_) => {
1044 symbol
1046 .as_str()
1047 .strip_suffix("-SWAP")
1048 .unwrap_or(symbol.as_str())
1049 .to_string()
1050 }
1051 InstrumentAny::CryptoFuture(_) => {
1052 let parts: Vec<&str> = symbol.as_str().split('-').collect();
1054 if parts.len() >= 2 {
1055 format!("{}-{}", parts[0], parts[1])
1056 } else {
1057 return Err(OKXWsError::ClientError(format!(
1058 "Unable to parse futures instrument family from symbol: {symbol}",
1059 )));
1060 }
1061 }
1062 InstrumentAny::CryptoOption(_) => {
1063 let parts: Vec<&str> = symbol.as_str().split('-').collect();
1065 if parts.len() >= 2 {
1066 format!("{}-{}", parts[0], parts[1])
1067 } else {
1068 return Err(OKXWsError::ClientError(format!(
1069 "Unable to parse option instrument family from symbol: {symbol}",
1070 )));
1071 }
1072 }
1073 _ => {
1074 return Err(OKXWsError::ClientError(format!(
1075 "Unsupported instrument type: {instrument:?}",
1076 )));
1077 }
1078 };
1079
1080 Ok((inst_type, inst_family))
1081 }
1082
1083 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
1084 for arg in &args {
1085 let topic = topic_from_subscription_arg(arg);
1086 self.subscriptions_state.mark_subscribe(&topic);
1087
1088 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
1090 self.subscriptions_bare.insert(arg.channel.clone(), true);
1092 } else {
1093 if let Some(inst_type) = &arg.inst_type {
1095 self.subscriptions_inst_type
1096 .entry(arg.channel.clone())
1097 .or_default()
1098 .insert(*inst_type);
1099 }
1100
1101 if let Some(inst_family) = &arg.inst_family {
1103 self.subscriptions_inst_family
1104 .entry(arg.channel.clone())
1105 .or_default()
1106 .insert(*inst_family);
1107 }
1108
1109 if let Some(inst_id) = &arg.inst_id {
1111 self.subscriptions_inst_id
1112 .entry(arg.channel.clone())
1113 .or_default()
1114 .insert(*inst_id);
1115 }
1116 }
1117 }
1118
1119 let message = OKXSubscription {
1120 op: OKXWsOperation::Subscribe,
1121 args,
1122 };
1123
1124 let json_txt =
1125 serde_json::to_string(&message).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1126
1127 {
1128 let inner_guard = self.inner.read().await;
1129 if let Some(inner) = &*inner_guard {
1130 if let Err(e) = inner
1131 .send_text(json_txt, Some(vec!["subscription".to_string()]))
1132 .await
1133 {
1134 tracing::error!("Error sending message: {e:?}")
1135 }
1136 } else {
1137 return Err(OKXWsError::ClientError(
1138 "Cannot send message: not connected".to_string(),
1139 ));
1140 }
1141 }
1142
1143 Ok(())
1144 }
1145
1146 #[allow(clippy::collapsible_if, reason = "Clearer uncollapsed")]
1147 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
1148 for arg in &args {
1149 let topic = topic_from_subscription_arg(arg);
1150 self.subscriptions_state.mark_unsubscribe(&topic);
1151
1152 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
1154 self.subscriptions_bare.remove(&arg.channel);
1156 } else {
1157 if let Some(inst_type) = &arg.inst_type {
1159 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
1160 entry.remove(inst_type);
1161 if entry.is_empty() {
1162 drop(entry);
1163 self.subscriptions_inst_type.remove(&arg.channel);
1164 }
1165 }
1166 }
1167
1168 if let Some(inst_family) = &arg.inst_family {
1170 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
1171 entry.remove(inst_family);
1172 if entry.is_empty() {
1173 drop(entry);
1174 self.subscriptions_inst_family.remove(&arg.channel);
1175 }
1176 }
1177 }
1178
1179 if let Some(inst_id) = &arg.inst_id {
1181 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
1182 entry.remove(inst_id);
1183 if entry.is_empty() {
1184 drop(entry);
1185 self.subscriptions_inst_id.remove(&arg.channel);
1186 }
1187 }
1188 }
1189 }
1190 }
1191
1192 let message = OKXSubscription {
1193 op: OKXWsOperation::Unsubscribe,
1194 args,
1195 };
1196
1197 let json_txt = serde_json::to_string(&message).expect("Must be valid JSON");
1198
1199 {
1200 let inner_guard = self.inner.read().await;
1201 if let Some(inner) = &*inner_guard {
1202 if let Err(e) = inner
1203 .send_text(json_txt, Some(vec!["subscription".to_string()]))
1204 .await
1205 {
1206 tracing::error!("Error sending message: {e:?}")
1207 }
1208 } else {
1209 log::error!("Cannot send message: not connected");
1210 }
1211 }
1212
1213 Ok(())
1214 }
1215
1216 #[allow(dead_code)]
1217 async fn resubscribe_all(&self) {
1218 let mut subs_bare = Vec::new();
1220 for entry in self.subscriptions_bare.iter() {
1221 let channel = entry.key();
1222 subs_bare.push(channel.clone());
1223 }
1224
1225 let mut subs_inst_type = Vec::new();
1226 for entry in self.subscriptions_inst_type.iter() {
1227 let (channel, inst_types) = entry.pair();
1228 if !inst_types.is_empty() {
1229 subs_inst_type.push((channel.clone(), inst_types.clone()));
1230 }
1231 }
1232
1233 let mut subs_inst_family = Vec::new();
1234 for entry in self.subscriptions_inst_family.iter() {
1235 let (channel, inst_families) = entry.pair();
1236 if !inst_families.is_empty() {
1237 subs_inst_family.push((channel.clone(), inst_families.clone()));
1238 }
1239 }
1240
1241 let mut subs_inst_id = Vec::new();
1242 for entry in self.subscriptions_inst_id.iter() {
1243 let (channel, inst_ids) = entry.pair();
1244 if !inst_ids.is_empty() {
1245 subs_inst_id.push((channel.clone(), inst_ids.clone()));
1246 }
1247 }
1248
1249 for (channel, inst_types) in subs_inst_type {
1251 if inst_types.is_empty() {
1252 continue;
1253 }
1254
1255 tracing::debug!("Resubscribing: channel={channel}, instrument_types={inst_types:?}");
1256
1257 for inst_type in inst_types {
1258 let arg = OKXSubscriptionArg {
1259 channel: channel.clone(),
1260 inst_type: Some(inst_type),
1261 inst_family: None,
1262 inst_id: None,
1263 };
1264
1265 if let Err(e) = self.subscribe(vec![arg]).await {
1266 tracing::error!(
1267 "Failed to resubscribe to channel {channel} with instrument type: {e}"
1268 );
1269 }
1270 }
1271 }
1272
1273 for (channel, inst_families) in subs_inst_family {
1275 if inst_families.is_empty() {
1276 continue;
1277 }
1278
1279 tracing::debug!(
1280 "Resubscribing: channel={channel}, instrument_families={inst_families:?}"
1281 );
1282
1283 for inst_family in inst_families {
1284 let arg = OKXSubscriptionArg {
1285 channel: channel.clone(),
1286 inst_type: None,
1287 inst_family: Some(inst_family),
1288 inst_id: None,
1289 };
1290
1291 if let Err(e) = self.subscribe(vec![arg]).await {
1292 tracing::error!(
1293 "Failed to resubscribe to channel {channel} with instrument family: {e}"
1294 );
1295 }
1296 }
1297 }
1298
1299 for (channel, inst_ids) in subs_inst_id {
1301 if inst_ids.is_empty() {
1302 continue;
1303 }
1304
1305 tracing::debug!("Resubscribing: channel={channel}, instrument_ids={inst_ids:?}");
1306
1307 for inst_id in inst_ids {
1308 let arg = OKXSubscriptionArg {
1309 channel: channel.clone(),
1310 inst_type: None,
1311 inst_family: None,
1312 inst_id: Some(inst_id),
1313 };
1314
1315 if let Err(e) = self.subscribe(vec![arg]).await {
1316 tracing::error!(
1317 "Failed to resubscribe to channel {channel} with instrument ID: {e}"
1318 );
1319 }
1320 }
1321 }
1322
1323 for channel in subs_bare {
1325 tracing::debug!("Resubscribing to bare channel: {channel}");
1326
1327 let arg = OKXSubscriptionArg {
1328 channel,
1329 inst_type: None,
1330 inst_family: None,
1331 inst_id: None,
1332 };
1333
1334 if let Err(e) = self.subscribe(vec![arg]).await {
1335 tracing::error!("Failed to resubscribe to bare channel: {e}");
1336 }
1337 }
1338 }
1339
1340 pub async fn subscribe_instruments(
1352 &self,
1353 instrument_type: OKXInstrumentType,
1354 ) -> Result<(), OKXWsError> {
1355 let arg = OKXSubscriptionArg {
1356 channel: OKXWsChannel::Instruments,
1357 inst_type: Some(instrument_type),
1358 inst_family: None,
1359 inst_id: None,
1360 };
1361 self.subscribe(vec![arg]).await
1362 }
1363
1364 pub async fn subscribe_instrument(
1376 &self,
1377 instrument_id: InstrumentId,
1378 ) -> Result<(), OKXWsError> {
1379 let arg = OKXSubscriptionArg {
1380 channel: OKXWsChannel::Instruments,
1381 inst_type: None,
1382 inst_family: None,
1383 inst_id: Some(instrument_id.symbol.inner()),
1384 };
1385 self.subscribe(vec![arg]).await
1386 }
1387
1388 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1397 self.subscribe_book_with_depth(instrument_id, 0).await
1398 }
1399
1400 pub(crate) async fn subscribe_books_channel(
1402 &self,
1403 instrument_id: InstrumentId,
1404 ) -> Result<(), OKXWsError> {
1405 let arg = OKXSubscriptionArg {
1406 channel: OKXWsChannel::Books,
1407 inst_type: None,
1408 inst_family: None,
1409 inst_id: Some(instrument_id.symbol.inner()),
1410 };
1411 self.subscribe(vec![arg]).await
1412 }
1413
1414 pub async fn subscribe_book_depth5(
1426 &self,
1427 instrument_id: InstrumentId,
1428 ) -> Result<(), OKXWsError> {
1429 let arg = OKXSubscriptionArg {
1430 channel: OKXWsChannel::Books5,
1431 inst_type: None,
1432 inst_family: None,
1433 inst_id: Some(instrument_id.symbol.inner()),
1434 };
1435 self.subscribe(vec![arg]).await
1436 }
1437
1438 pub async fn subscribe_book50_l2_tbt(
1450 &self,
1451 instrument_id: InstrumentId,
1452 ) -> Result<(), OKXWsError> {
1453 let arg = OKXSubscriptionArg {
1454 channel: OKXWsChannel::Books50Tbt,
1455 inst_type: None,
1456 inst_family: None,
1457 inst_id: Some(instrument_id.symbol.inner()),
1458 };
1459 self.subscribe(vec![arg]).await
1460 }
1461
1462 pub async fn subscribe_book_l2_tbt(
1474 &self,
1475 instrument_id: InstrumentId,
1476 ) -> Result<(), OKXWsError> {
1477 let arg = OKXSubscriptionArg {
1478 channel: OKXWsChannel::BooksTbt,
1479 inst_type: None,
1480 inst_family: None,
1481 inst_id: Some(instrument_id.symbol.inner()),
1482 };
1483 self.subscribe(vec![arg]).await
1484 }
1485
1486 pub async fn subscribe_book_with_depth(
1500 &self,
1501 instrument_id: InstrumentId,
1502 depth: u16,
1503 ) -> anyhow::Result<()> {
1504 let vip = self.vip_level();
1505
1506 match depth {
1507 50 => {
1508 if vip < OKXVipLevel::Vip4 {
1509 anyhow::bail!(
1510 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1511 );
1512 }
1513 self.subscribe_book50_l2_tbt(instrument_id)
1514 .await
1515 .map_err(|e| anyhow::anyhow!(e))
1516 }
1517 0 | 400 => {
1518 if vip >= OKXVipLevel::Vip5 {
1519 self.subscribe_book_l2_tbt(instrument_id)
1520 .await
1521 .map_err(|e| anyhow::anyhow!(e))
1522 } else {
1523 self.subscribe_books_channel(instrument_id)
1524 .await
1525 .map_err(|e| anyhow::anyhow!(e))
1526 }
1527 }
1528 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1529 }
1530 }
1531
1532 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1547 let (inst_type, _) = self.get_instrument_type_and_family(instrument_id.symbol.inner())?;
1548
1549 let channel = if inst_type == OKXInstrumentType::Spot {
1552 OKXWsChannel::Tickers
1553 } else {
1554 OKXWsChannel::BboTbt
1555 };
1556
1557 let arg = OKXSubscriptionArg {
1558 channel,
1559 inst_type: None,
1560 inst_family: None,
1561 inst_id: Some(instrument_id.symbol.inner()),
1562 };
1563 self.subscribe(vec![arg]).await
1564 }
1565
1566 pub async fn subscribe_trades(
1576 &self,
1577 instrument_id: InstrumentId,
1578 _aggregated: bool, ) -> Result<(), OKXWsError> {
1580 let channel = OKXWsChannel::Trades;
1585
1586 let arg = OKXSubscriptionArg {
1587 channel,
1588 inst_type: None,
1589 inst_family: None,
1590 inst_id: Some(instrument_id.symbol.inner()),
1591 };
1592 self.subscribe(vec![arg]).await
1593 }
1594
1595 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1607 let arg = OKXSubscriptionArg {
1608 channel: OKXWsChannel::Tickers,
1609 inst_type: None,
1610 inst_family: None,
1611 inst_id: Some(instrument_id.symbol.inner()),
1612 };
1613 self.subscribe(vec![arg]).await
1614 }
1615
1616 pub async fn subscribe_mark_prices(
1628 &self,
1629 instrument_id: InstrumentId,
1630 ) -> Result<(), OKXWsError> {
1631 let arg = OKXSubscriptionArg {
1632 channel: OKXWsChannel::MarkPrice,
1633 inst_type: None,
1634 inst_family: None,
1635 inst_id: Some(instrument_id.symbol.inner()),
1636 };
1637 self.subscribe(vec![arg]).await
1638 }
1639
1640 pub async fn subscribe_index_prices(
1652 &self,
1653 instrument_id: InstrumentId,
1654 ) -> Result<(), OKXWsError> {
1655 let arg = OKXSubscriptionArg {
1656 channel: OKXWsChannel::IndexTickers,
1657 inst_type: None,
1658 inst_family: None,
1659 inst_id: Some(instrument_id.symbol.inner()),
1660 };
1661 self.subscribe(vec![arg]).await
1662 }
1663
1664 pub async fn subscribe_funding_rates(
1676 &self,
1677 instrument_id: InstrumentId,
1678 ) -> Result<(), OKXWsError> {
1679 let arg = OKXSubscriptionArg {
1680 channel: OKXWsChannel::FundingRate,
1681 inst_type: None,
1682 inst_family: None,
1683 inst_id: Some(instrument_id.symbol.inner()),
1684 };
1685 self.subscribe(vec![arg]).await
1686 }
1687
1688 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1700 let channel = bar_spec_as_okx_channel(bar_type.spec())
1702 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1703
1704 let arg = OKXSubscriptionArg {
1705 channel,
1706 inst_type: None,
1707 inst_family: None,
1708 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1709 };
1710 self.subscribe(vec![arg]).await
1711 }
1712
1713 pub async fn unsubscribe_instruments(
1719 &self,
1720 instrument_type: OKXInstrumentType,
1721 ) -> Result<(), OKXWsError> {
1722 let arg = OKXSubscriptionArg {
1723 channel: OKXWsChannel::Instruments,
1724 inst_type: Some(instrument_type),
1725 inst_family: None,
1726 inst_id: None,
1727 };
1728 self.unsubscribe(vec![arg]).await
1729 }
1730
1731 pub async fn unsubscribe_instrument(
1737 &self,
1738 instrument_id: InstrumentId,
1739 ) -> Result<(), OKXWsError> {
1740 let arg = OKXSubscriptionArg {
1741 channel: OKXWsChannel::Instruments,
1742 inst_type: None,
1743 inst_family: None,
1744 inst_id: Some(instrument_id.symbol.inner()),
1745 };
1746 self.unsubscribe(vec![arg]).await
1747 }
1748
1749 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1755 let arg = OKXSubscriptionArg {
1756 channel: OKXWsChannel::Books,
1757 inst_type: None,
1758 inst_family: None,
1759 inst_id: Some(instrument_id.symbol.inner()),
1760 };
1761 self.unsubscribe(vec![arg]).await
1762 }
1763
1764 pub async fn unsubscribe_book_depth5(
1770 &self,
1771 instrument_id: InstrumentId,
1772 ) -> Result<(), OKXWsError> {
1773 let arg = OKXSubscriptionArg {
1774 channel: OKXWsChannel::Books5,
1775 inst_type: None,
1776 inst_family: None,
1777 inst_id: Some(instrument_id.symbol.inner()),
1778 };
1779 self.unsubscribe(vec![arg]).await
1780 }
1781
1782 pub async fn unsubscribe_book50_l2_tbt(
1788 &self,
1789 instrument_id: InstrumentId,
1790 ) -> Result<(), OKXWsError> {
1791 let arg = OKXSubscriptionArg {
1792 channel: OKXWsChannel::Books50Tbt,
1793 inst_type: None,
1794 inst_family: None,
1795 inst_id: Some(instrument_id.symbol.inner()),
1796 };
1797 self.unsubscribe(vec![arg]).await
1798 }
1799
1800 pub async fn unsubscribe_book_l2_tbt(
1806 &self,
1807 instrument_id: InstrumentId,
1808 ) -> Result<(), OKXWsError> {
1809 let arg = OKXSubscriptionArg {
1810 channel: OKXWsChannel::BooksTbt,
1811 inst_type: None,
1812 inst_family: None,
1813 inst_id: Some(instrument_id.symbol.inner()),
1814 };
1815 self.unsubscribe(vec![arg]).await
1816 }
1817
1818 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1824 let (inst_type, _) = self.get_instrument_type_and_family(instrument_id.symbol.inner())?;
1825
1826 let channel = if inst_type == OKXInstrumentType::Spot {
1829 OKXWsChannel::Tickers
1830 } else {
1831 OKXWsChannel::BboTbt
1832 };
1833
1834 let arg = OKXSubscriptionArg {
1835 channel,
1836 inst_type: None,
1837 inst_family: None,
1838 inst_id: Some(instrument_id.symbol.inner()),
1839 };
1840 self.unsubscribe(vec![arg]).await
1841 }
1842
1843 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1849 let arg = OKXSubscriptionArg {
1850 channel: OKXWsChannel::Tickers,
1851 inst_type: None,
1852 inst_family: None,
1853 inst_id: Some(instrument_id.symbol.inner()),
1854 };
1855 self.unsubscribe(vec![arg]).await
1856 }
1857
1858 pub async fn unsubscribe_mark_prices(
1864 &self,
1865 instrument_id: InstrumentId,
1866 ) -> Result<(), OKXWsError> {
1867 let arg = OKXSubscriptionArg {
1868 channel: OKXWsChannel::MarkPrice,
1869 inst_type: None,
1870 inst_family: None,
1871 inst_id: Some(instrument_id.symbol.inner()),
1872 };
1873 self.unsubscribe(vec![arg]).await
1874 }
1875
1876 pub async fn unsubscribe_index_prices(
1882 &self,
1883 instrument_id: InstrumentId,
1884 ) -> Result<(), OKXWsError> {
1885 let arg = OKXSubscriptionArg {
1886 channel: OKXWsChannel::IndexTickers,
1887 inst_type: None,
1888 inst_family: None,
1889 inst_id: Some(instrument_id.symbol.inner()),
1890 };
1891 self.unsubscribe(vec![arg]).await
1892 }
1893
1894 pub async fn unsubscribe_funding_rates(
1900 &self,
1901 instrument_id: InstrumentId,
1902 ) -> Result<(), OKXWsError> {
1903 let arg = OKXSubscriptionArg {
1904 channel: OKXWsChannel::FundingRate,
1905 inst_type: None,
1906 inst_family: None,
1907 inst_id: Some(instrument_id.symbol.inner()),
1908 };
1909 self.unsubscribe(vec![arg]).await
1910 }
1911
1912 pub async fn unsubscribe_trades(
1918 &self,
1919 instrument_id: InstrumentId,
1920 _aggregated: bool,
1921 ) -> Result<(), OKXWsError> {
1922 let channel = OKXWsChannel::Trades;
1924
1925 let arg = OKXSubscriptionArg {
1926 channel,
1927 inst_type: None,
1928 inst_family: None,
1929 inst_id: Some(instrument_id.symbol.inner()),
1930 };
1931 self.unsubscribe(vec![arg]).await
1932 }
1933
1934 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1940 let channel = bar_spec_as_okx_channel(bar_type.spec())
1942 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1943
1944 let arg = OKXSubscriptionArg {
1945 channel,
1946 inst_type: None,
1947 inst_family: None,
1948 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1949 };
1950 self.unsubscribe(vec![arg]).await
1951 }
1952
1953 pub async fn subscribe_orders(
1959 &self,
1960 instrument_type: OKXInstrumentType,
1961 ) -> Result<(), OKXWsError> {
1962 let arg = OKXSubscriptionArg {
1963 channel: OKXWsChannel::Orders,
1964 inst_type: Some(instrument_type),
1965 inst_family: None,
1966 inst_id: None,
1967 };
1968 self.subscribe(vec![arg]).await
1969 }
1970
1971 pub async fn unsubscribe_orders(
1977 &self,
1978 instrument_type: OKXInstrumentType,
1979 ) -> Result<(), OKXWsError> {
1980 let arg = OKXSubscriptionArg {
1981 channel: OKXWsChannel::Orders,
1982 inst_type: Some(instrument_type),
1983 inst_family: None,
1984 inst_id: None,
1985 };
1986 self.unsubscribe(vec![arg]).await
1987 }
1988
1989 pub async fn subscribe_orders_algo(
1995 &self,
1996 instrument_type: OKXInstrumentType,
1997 ) -> Result<(), OKXWsError> {
1998 let arg = OKXSubscriptionArg {
1999 channel: OKXWsChannel::OrdersAlgo,
2000 inst_type: Some(instrument_type),
2001 inst_family: None,
2002 inst_id: None,
2003 };
2004 self.subscribe(vec![arg]).await
2005 }
2006
2007 pub async fn unsubscribe_orders_algo(
2013 &self,
2014 instrument_type: OKXInstrumentType,
2015 ) -> Result<(), OKXWsError> {
2016 let arg = OKXSubscriptionArg {
2017 channel: OKXWsChannel::OrdersAlgo,
2018 inst_type: Some(instrument_type),
2019 inst_family: None,
2020 inst_id: None,
2021 };
2022 self.unsubscribe(vec![arg]).await
2023 }
2024
2025 pub async fn subscribe_fills(
2031 &self,
2032 instrument_type: OKXInstrumentType,
2033 ) -> Result<(), OKXWsError> {
2034 let arg = OKXSubscriptionArg {
2035 channel: OKXWsChannel::Fills,
2036 inst_type: Some(instrument_type),
2037 inst_family: None,
2038 inst_id: None,
2039 };
2040 self.subscribe(vec![arg]).await
2041 }
2042
2043 pub async fn unsubscribe_fills(
2049 &self,
2050 instrument_type: OKXInstrumentType,
2051 ) -> Result<(), OKXWsError> {
2052 let arg = OKXSubscriptionArg {
2053 channel: OKXWsChannel::Fills,
2054 inst_type: Some(instrument_type),
2055 inst_family: None,
2056 inst_id: None,
2057 };
2058 self.unsubscribe(vec![arg]).await
2059 }
2060
2061 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
2067 let arg = OKXSubscriptionArg {
2068 channel: OKXWsChannel::Account,
2069 inst_type: None,
2070 inst_family: None,
2071 inst_id: None,
2072 };
2073 self.subscribe(vec![arg]).await
2074 }
2075
2076 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
2082 let arg = OKXSubscriptionArg {
2083 channel: OKXWsChannel::Account,
2084 inst_type: None,
2085 inst_family: None,
2086 inst_id: None,
2087 };
2088 self.unsubscribe(vec![arg]).await
2089 }
2090
2091 async fn ws_cancel_order(
2097 &self,
2098 params: WsCancelOrderParams,
2099 request_id: Option<String>,
2100 ) -> Result<(), OKXWsError> {
2101 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
2102
2103 let req = OKXWsRequest {
2104 id: Some(request_id),
2105 op: OKXWsOperation::CancelOrder,
2106 args: vec![params],
2107 exp_time: None,
2108 };
2109
2110 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2111
2112 {
2113 let inner_guard = self.inner.read().await;
2114 if let Some(inner) = &*inner_guard {
2115 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
2116 tracing::error!("Error sending message: {e:?}");
2117 }
2118 Ok(())
2119 } else {
2120 Err(OKXWsError::ClientError("Not connected".to_string()))
2121 }
2122 }
2123 }
2124
2125 async fn ws_mass_cancel_with_id(
2131 &self,
2132 args: Vec<Value>,
2133 request_id: String,
2134 ) -> Result<(), OKXWsError> {
2135 let req = OKXWsRequest {
2136 id: Some(request_id),
2137 op: OKXWsOperation::MassCancel,
2138 args,
2139 exp_time: None,
2140 };
2141
2142 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2143
2144 {
2145 let inner_guard = self.inner.read().await;
2146 if let Some(inner) = &*inner_guard {
2147 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
2148 tracing::error!("Error sending message: {e:?}");
2149 }
2150 Ok(())
2151 } else {
2152 Err(OKXWsError::ClientError("Not connected".to_string()))
2153 }
2154 }
2155 }
2156
2157 async fn ws_amend_order(
2163 &self,
2164 params: WsAmendOrderParams,
2165 request_id: Option<String>,
2166 ) -> Result<(), OKXWsError> {
2167 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
2168
2169 let req = OKXWsRequest {
2170 id: Some(request_id),
2171 op: OKXWsOperation::AmendOrder,
2172 args: vec![params],
2173 exp_time: None,
2174 };
2175
2176 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2177
2178 {
2179 let inner_guard = self.inner.read().await;
2180 if let Some(inner) = &*inner_guard {
2181 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
2182 tracing::error!("Error sending message: {e:?}");
2183 }
2184 Ok(())
2185 } else {
2186 Err(OKXWsError::ClientError("Not connected".to_string()))
2187 }
2188 }
2189 }
2190
2191 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2197 let request_id = self.generate_unique_request_id();
2198
2199 let req = OKXWsRequest {
2200 id: Some(request_id),
2201 op: OKXWsOperation::BatchOrders,
2202 args,
2203 exp_time: None,
2204 };
2205
2206 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2207
2208 {
2209 let inner_guard = self.inner.read().await;
2210 if let Some(inner) = &*inner_guard {
2211 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
2212 tracing::error!("Error sending message: {e:?}");
2213 }
2214 Ok(())
2215 } else {
2216 Err(OKXWsError::ClientError("Not connected".to_string()))
2217 }
2218 }
2219 }
2220
2221 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2227 let request_id = self.generate_unique_request_id();
2228
2229 let req = OKXWsRequest {
2230 id: Some(request_id),
2231 op: OKXWsOperation::BatchCancelOrders,
2232 args,
2233 exp_time: None,
2234 };
2235
2236 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2237
2238 {
2239 let inner_guard = self.inner.read().await;
2240 if let Some(inner) = &*inner_guard {
2241 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
2242 tracing::error!("Error sending message: {e:?}");
2243 }
2244 Ok(())
2245 } else {
2246 Err(OKXWsError::ClientError("Not connected".to_string()))
2247 }
2248 }
2249 }
2250
2251 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2257 let request_id = self.generate_unique_request_id();
2258
2259 let req = OKXWsRequest {
2260 id: Some(request_id),
2261 op: OKXWsOperation::BatchAmendOrders,
2262 args,
2263 exp_time: None,
2264 };
2265
2266 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2267
2268 {
2269 let inner_guard = self.inner.read().await;
2270 if let Some(inner) = &*inner_guard {
2271 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
2272 tracing::error!("Error sending message: {e:?}");
2273 }
2274 Ok(())
2275 } else {
2276 Err(OKXWsError::ClientError("Not connected".to_string()))
2277 }
2278 }
2279 }
2280
2281 #[allow(clippy::too_many_arguments)]
2293 pub async fn submit_order(
2294 &self,
2295 trader_id: TraderId,
2296 strategy_id: StrategyId,
2297 instrument_id: InstrumentId,
2298 td_mode: OKXTradeMode,
2299 client_order_id: ClientOrderId,
2300 order_side: OrderSide,
2301 order_type: OrderType,
2302 quantity: Quantity,
2303 time_in_force: Option<TimeInForce>,
2304 price: Option<Price>,
2305 trigger_price: Option<Price>,
2306 post_only: Option<bool>,
2307 reduce_only: Option<bool>,
2308 quote_quantity: Option<bool>,
2309 position_side: Option<PositionSide>,
2310 ) -> Result<(), OKXWsError> {
2311 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
2312 return Err(OKXWsError::ClientError(format!(
2313 "Unsupported order type: {order_type:?}",
2314 )));
2315 }
2316
2317 if let Some(tif) = time_in_force
2318 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
2319 {
2320 return Err(OKXWsError::ClientError(format!(
2321 "Unsupported time in force: {tif:?}",
2322 )));
2323 }
2324
2325 let mut builder = WsPostOrderParamsBuilder::default();
2326
2327 builder.inst_id(instrument_id.symbol.as_str());
2328 builder.td_mode(td_mode);
2329 builder.cl_ord_id(client_order_id.as_str());
2330
2331 let instrument = self
2332 .instruments_cache
2333 .get(&instrument_id.symbol.inner())
2334 .ok_or_else(|| {
2335 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
2336 })?;
2337
2338 let instrument_type =
2339 okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
2340 let quote_currency = instrument.quote_currency();
2341
2342 match instrument_type {
2343 OKXInstrumentType::Spot => {
2344 builder.ccy(quote_currency.to_string());
2346 }
2347 OKXInstrumentType::Margin => {
2348 builder.ccy(quote_currency.to_string());
2350
2351 if let Some(ro) = reduce_only
2353 && ro
2354 {
2355 builder.reduce_only(ro);
2356 }
2357 }
2358 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
2359 builder.ccy(quote_currency.to_string());
2361 }
2362 _ => {
2363 builder.ccy(quote_currency.to_string());
2365
2366 if let Some(ro) = reduce_only
2368 && ro
2369 {
2370 builder.reduce_only(ro);
2371 }
2372 }
2373 };
2374
2375 if instrument_type == OKXInstrumentType::Spot && order_type == OrderType::Market {
2376 match quote_quantity {
2381 Some(true) => {
2382 builder.tgt_ccy(OKX_TARGET_CCY_QUOTE.to_string());
2383 }
2384 Some(false) => {
2385 if order_side == OrderSide::Buy {
2386 builder.tgt_ccy(OKX_TARGET_CCY_BASE.to_string());
2388 }
2389 }
2391 None => {
2392 }
2394 }
2395 }
2396
2397 builder.side(order_side);
2398
2399 if let Some(pos_side) = position_side {
2400 builder.pos_side(pos_side);
2401 };
2402
2403 let okx_ord_type = if post_only.unwrap_or(false) {
2405 OKXOrderType::PostOnly
2406 } else {
2407 OKXOrderType::from(order_type)
2408 };
2409
2410 log::debug!(
2411 "Order type mapping: order_type={:?}, time_in_force={:?}, post_only={:?} -> okx_ord_type={:?}",
2412 order_type,
2413 time_in_force,
2414 post_only,
2415 okx_ord_type
2416 );
2417
2418 builder.ord_type(okx_ord_type);
2419 builder.sz(quantity.to_string());
2420
2421 if let Some(tp) = trigger_price {
2422 builder.px(tp.to_string());
2423 } else if let Some(p) = price {
2424 builder.px(p.to_string());
2425 }
2426
2427 builder.tag(OKX_NAUTILUS_BROKER_ID);
2428
2429 let params = builder
2430 .build()
2431 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2432
2433 log::debug!("Sending order params to OKX: {:?}", params);
2435
2436 let request_id = self.generate_unique_request_id();
2437
2438 self.pending_place_requests.insert(
2439 request_id.clone(),
2440 (client_order_id, trader_id, strategy_id, instrument_id),
2441 );
2442
2443 self.active_client_orders
2444 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2445
2446 self.retry_manager
2447 .execute_with_retry_with_cancel(
2448 "submit_order",
2449 || {
2450 let params = params.clone();
2451 let request_id = request_id.clone();
2452 async move { self.ws_place_order(params, Some(request_id)).await }
2453 },
2454 should_retry_okx_error,
2455 create_okx_timeout_error,
2456 &self.cancellation_token,
2457 )
2458 .await
2459 }
2460
2461 #[allow(clippy::too_many_arguments)]
2472 pub async fn cancel_order(
2473 &self,
2474 trader_id: TraderId,
2475 strategy_id: StrategyId,
2476 instrument_id: InstrumentId,
2477 client_order_id: Option<ClientOrderId>,
2478 venue_order_id: Option<VenueOrderId>,
2479 ) -> Result<(), OKXWsError> {
2480 let mut builder = WsCancelOrderParamsBuilder::default();
2481 builder.inst_id(instrument_id.symbol.as_str());
2484
2485 if let Some(venue_order_id) = venue_order_id {
2486 builder.ord_id(venue_order_id.as_str());
2487 }
2488
2489 if let Some(client_order_id) = client_order_id {
2491 builder.cl_ord_id(client_order_id.as_str());
2492 }
2493
2494 let params = builder
2495 .build()
2496 .map_err(|e| OKXWsError::ClientError(format!("Build cancel params error: {e}")))?;
2497
2498 let request_id = self.generate_unique_request_id();
2499
2500 if let Some(client_order_id) = client_order_id {
2503 self.pending_cancel_requests.insert(
2504 request_id.clone(),
2505 (
2506 client_order_id,
2507 trader_id,
2508 strategy_id,
2509 instrument_id,
2510 venue_order_id,
2511 ),
2512 );
2513 }
2514
2515 self.retry_manager
2516 .execute_with_retry_with_cancel(
2517 "cancel_order",
2518 || {
2519 let params = params.clone();
2520 let request_id = request_id.clone();
2521 async move { self.ws_cancel_order(params, Some(request_id)).await }
2522 },
2523 should_retry_okx_error,
2524 create_okx_timeout_error,
2525 &self.cancellation_token,
2526 )
2527 .await
2528 }
2529
2530 async fn ws_place_order(
2536 &self,
2537 params: WsPostOrderParams,
2538 request_id: Option<String>,
2539 ) -> Result<(), OKXWsError> {
2540 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
2541
2542 let req = OKXWsRequest {
2543 id: Some(request_id),
2544 op: OKXWsOperation::Order,
2545 exp_time: None,
2546 args: vec![params],
2547 };
2548
2549 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2550
2551 {
2552 let inner_guard = self.inner.read().await;
2553 if let Some(inner) = &*inner_guard {
2554 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
2555 tracing::error!("Error sending message: {e:?}");
2556 }
2557 Ok(())
2558 } else {
2559 Err(OKXWsError::ClientError("Not connected".to_string()))
2560 }
2561 }
2562 }
2563
2564 #[allow(clippy::too_many_arguments)]
2575 pub async fn modify_order(
2576 &self,
2577 trader_id: TraderId,
2578 strategy_id: StrategyId,
2579 instrument_id: InstrumentId,
2580 client_order_id: Option<ClientOrderId>,
2581 price: Option<Price>,
2582 quantity: Option<Quantity>,
2583 venue_order_id: Option<VenueOrderId>,
2584 ) -> Result<(), OKXWsError> {
2585 let mut builder = WsAmendOrderParamsBuilder::default();
2586
2587 builder.inst_id(instrument_id.symbol.as_str());
2588
2589 if let Some(venue_order_id) = venue_order_id {
2590 builder.ord_id(venue_order_id.as_str());
2591 }
2592
2593 if let Some(client_order_id) = client_order_id {
2594 builder.cl_ord_id(client_order_id.as_str());
2595 }
2596
2597 if let Some(price) = price {
2598 builder.new_px(price.to_string());
2599 }
2600
2601 if let Some(quantity) = quantity {
2602 builder.new_sz(quantity.to_string());
2603 }
2604
2605 let params = builder
2606 .build()
2607 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2608
2609 let request_id = self
2611 .request_id_counter
2612 .fetch_add(1, Ordering::SeqCst)
2613 .to_string();
2614
2615 if let Some(client_order_id) = client_order_id {
2618 self.pending_amend_requests.insert(
2619 request_id.clone(),
2620 (
2621 client_order_id,
2622 trader_id,
2623 strategy_id,
2624 instrument_id,
2625 venue_order_id,
2626 ),
2627 );
2628 }
2629
2630 self.retry_manager
2631 .execute_with_retry_with_cancel(
2632 "modify_order",
2633 || {
2634 let params = params.clone();
2635 let request_id = request_id.clone();
2636 async move { self.ws_amend_order(params, Some(request_id)).await }
2637 },
2638 should_retry_okx_error,
2639 create_okx_timeout_error,
2640 &self.cancellation_token,
2641 )
2642 .await
2643 }
2644
2645 #[allow(clippy::type_complexity)]
2652 #[allow(clippy::too_many_arguments)]
2653 pub async fn batch_submit_orders(
2654 &self,
2655 orders: Vec<(
2656 OKXInstrumentType,
2657 InstrumentId,
2658 OKXTradeMode,
2659 ClientOrderId,
2660 OrderSide,
2661 Option<PositionSide>,
2662 OrderType,
2663 Quantity,
2664 Option<Price>,
2665 Option<Price>,
2666 Option<bool>,
2667 Option<bool>,
2668 )>,
2669 ) -> Result<(), OKXWsError> {
2670 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2671 for (
2672 inst_type,
2673 inst_id,
2674 td_mode,
2675 cl_ord_id,
2676 ord_side,
2677 pos_side,
2678 ord_type,
2679 qty,
2680 pr,
2681 tp,
2682 post_only,
2683 reduce_only,
2684 ) in orders
2685 {
2686 let mut builder = WsPostOrderParamsBuilder::default();
2687 builder.inst_type(inst_type);
2688 builder.inst_id(inst_id.symbol.inner());
2689 builder.td_mode(td_mode);
2690 builder.cl_ord_id(cl_ord_id.as_str());
2691 builder.side(ord_side);
2692
2693 if let Some(ps) = pos_side {
2694 builder.pos_side(OKXPositionSide::from(ps));
2695 }
2696
2697 let okx_ord_type = if post_only.unwrap_or(false) {
2698 OKXOrderType::PostOnly
2699 } else {
2700 OKXOrderType::from(ord_type)
2701 };
2702
2703 builder.ord_type(okx_ord_type);
2704 builder.sz(qty.to_string());
2705
2706 if let Some(p) = pr {
2707 builder.px(p.to_string());
2708 } else if let Some(p) = tp {
2709 builder.px(p.to_string());
2710 }
2711
2712 if let Some(ro) = reduce_only {
2713 builder.reduce_only(ro);
2714 }
2715
2716 builder.tag(OKX_NAUTILUS_BROKER_ID);
2717
2718 let params = builder
2719 .build()
2720 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2721 let val =
2722 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2723 args.push(val);
2724 }
2725
2726 self.ws_batch_place_orders(args).await
2727 }
2728
2729 #[allow(clippy::type_complexity)]
2736 pub async fn batch_cancel_orders(
2737 &self,
2738 orders: Vec<(
2739 OKXInstrumentType,
2740 InstrumentId,
2741 Option<ClientOrderId>,
2742 Option<String>,
2743 )>,
2744 ) -> Result<(), OKXWsError> {
2745 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2746 for (_inst_type, inst_id, cl_ord_id, ord_id) in orders {
2747 let mut builder = WsCancelOrderParamsBuilder::default();
2748 builder.inst_id(inst_id.symbol.inner());
2750
2751 if let Some(c) = cl_ord_id {
2752 builder.cl_ord_id(c.as_str());
2753 }
2754
2755 if let Some(o) = ord_id {
2756 builder.ord_id(o);
2757 }
2758
2759 let params = builder.build().map_err(|e| {
2760 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2761 })?;
2762 let val =
2763 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2764 args.push(val);
2765 }
2766
2767 self.ws_batch_cancel_orders(args).await
2768 }
2769
2770 pub async fn mass_cancel_orders(&self, inst_id: InstrumentId) -> Result<(), OKXWsError> {
2784 let (inst_type, inst_family) =
2785 self.get_instrument_type_and_family(inst_id.symbol.inner())?;
2786
2787 let params = WsMassCancelParams {
2788 inst_type,
2789 inst_family: Ustr::from(&inst_family),
2790 };
2791
2792 let args =
2793 vec![serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?];
2794
2795 let request_id = self.generate_unique_request_id();
2796
2797 self.pending_mass_cancel_requests
2798 .insert(request_id.clone(), inst_id);
2799
2800 self.retry_manager
2801 .execute_with_retry_with_cancel(
2802 "mass_cancel_orders",
2803 || {
2804 let args = args.clone();
2805 let request_id = request_id.clone();
2806 async move { self.ws_mass_cancel_with_id(args, request_id).await }
2807 },
2808 should_retry_okx_error,
2809 create_okx_timeout_error,
2810 &self.cancellation_token,
2811 )
2812 .await
2813 }
2814
2815 #[allow(clippy::type_complexity)]
2822 #[allow(clippy::too_many_arguments)]
2823 pub async fn batch_modify_orders(
2824 &self,
2825 orders: Vec<(
2826 OKXInstrumentType,
2827 InstrumentId,
2828 ClientOrderId,
2829 ClientOrderId,
2830 Option<Price>,
2831 Option<Quantity>,
2832 )>,
2833 ) -> Result<(), OKXWsError> {
2834 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2835 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2836 let mut builder = WsAmendOrderParamsBuilder::default();
2837 builder.inst_id(inst_id.symbol.inner());
2839 builder.cl_ord_id(cl_ord_id.as_str());
2840 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2841
2842 if let Some(p) = pr {
2843 builder.new_px(p.to_string());
2844 }
2845
2846 if let Some(q) = sz {
2847 builder.new_sz(q.to_string());
2848 }
2849
2850 let params = builder.build().map_err(|e| {
2851 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2852 })?;
2853 let val =
2854 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2855 args.push(val);
2856 }
2857
2858 self.ws_batch_amend_orders(args).await
2859 }
2860
2861 #[allow(clippy::too_many_arguments)]
2872 pub async fn submit_algo_order(
2873 &self,
2874 trader_id: TraderId,
2875 strategy_id: StrategyId,
2876 instrument_id: InstrumentId,
2877 td_mode: OKXTradeMode,
2878 client_order_id: ClientOrderId,
2879 order_side: OrderSide,
2880 order_type: OrderType,
2881 quantity: Quantity,
2882 trigger_price: Price,
2883 trigger_type: Option<TriggerType>,
2884 limit_price: Option<Price>,
2885 reduce_only: Option<bool>,
2886 ) -> Result<(), OKXWsError> {
2887 if !is_conditional_order(order_type) {
2888 return Err(OKXWsError::ClientError(format!(
2889 "Order type {order_type:?} is not a conditional order"
2890 )));
2891 }
2892
2893 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2894 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2895 return Err(OKXWsError::ClientError(
2896 "Invalid order side for OKX".to_string(),
2897 ));
2898 }
2899
2900 builder.inst_id(instrument_id.symbol.inner());
2901 builder.td_mode(td_mode);
2902 builder.cl_ord_id(client_order_id.as_str());
2903 builder.side(order_side);
2904 builder.ord_type(
2905 conditional_order_to_algo_type(order_type)
2906 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2907 );
2908 builder.sz(quantity.to_string());
2909 builder.trigger_px(trigger_price.to_string());
2910
2911 let okx_trigger_type = trigger_type.map(Into::into).unwrap_or(OKXTriggerType::Last);
2913 builder.trigger_px_type(okx_trigger_type);
2914
2915 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2917 && let Some(price) = limit_price
2918 {
2919 builder.order_px(price.to_string());
2920 }
2921
2922 if let Some(reduce) = reduce_only {
2923 builder.reduce_only(reduce);
2924 }
2925
2926 builder.tag(OKX_NAUTILUS_BROKER_ID);
2927
2928 let params = builder
2929 .build()
2930 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2931
2932 let request_id = self.generate_unique_request_id();
2933
2934 self.pending_place_requests.insert(
2935 request_id.clone(),
2936 (client_order_id, trader_id, strategy_id, instrument_id),
2937 );
2938
2939 self.retry_manager
2940 .execute_with_retry_with_cancel(
2941 "submit_algo_order",
2942 || {
2943 let params = params.clone();
2944 let request_id = request_id.clone();
2945 async move { self.ws_place_algo_order(params, Some(request_id)).await }
2946 },
2947 should_retry_okx_error,
2948 create_okx_timeout_error,
2949 &self.cancellation_token,
2950 )
2951 .await
2952 }
2953
2954 pub async fn cancel_algo_order(
2965 &self,
2966 trader_id: TraderId,
2967 strategy_id: StrategyId,
2968 instrument_id: InstrumentId,
2969 client_order_id: Option<ClientOrderId>,
2970 algo_order_id: Option<String>,
2971 ) -> Result<(), OKXWsError> {
2972 let mut builder = WsCancelAlgoOrderParamsBuilder::default();
2973 builder.inst_id(instrument_id.symbol.inner());
2974
2975 if let Some(client_order_id) = client_order_id {
2976 builder.algo_cl_ord_id(client_order_id.as_str());
2977 }
2978
2979 if let Some(algo_id) = algo_order_id {
2980 builder.algo_id(algo_id);
2981 }
2982
2983 let params = builder
2984 .build()
2985 .map_err(|e| OKXWsError::ClientError(format!("Build cancel algo params error: {e}")))?;
2986
2987 let request_id = self.generate_unique_request_id();
2988
2989 if let Some(client_order_id) = client_order_id {
2991 self.pending_cancel_requests.insert(
2992 request_id.clone(),
2993 (client_order_id, trader_id, strategy_id, instrument_id, None),
2994 );
2995 }
2996
2997 self.retry_manager
2998 .execute_with_retry_with_cancel(
2999 "cancel_algo_order",
3000 || {
3001 let params = params.clone();
3002 let request_id = request_id.clone();
3003 async move { self.ws_cancel_algo_order(params, Some(request_id)).await }
3004 },
3005 should_retry_okx_error,
3006 create_okx_timeout_error,
3007 &self.cancellation_token,
3008 )
3009 .await
3010 }
3011
3012 async fn ws_place_algo_order(
3014 &self,
3015 params: WsPostAlgoOrderParams,
3016 request_id: Option<String>,
3017 ) -> Result<(), OKXWsError> {
3018 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
3019
3020 let req = OKXWsRequest {
3021 id: Some(request_id),
3022 op: OKXWsOperation::OrderAlgo,
3023 exp_time: None,
3024 args: vec![params],
3025 };
3026
3027 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
3028
3029 {
3030 let inner_guard = self.inner.read().await;
3031 if let Some(inner) = &*inner_guard {
3032 if let Err(e) = inner
3033 .send_text(txt, Some(vec!["orders-algo".to_string()]))
3034 .await
3035 {
3036 tracing::error!("Error sending algo order message: {e:?}");
3037 }
3038 Ok(())
3039 } else {
3040 Err(OKXWsError::ClientError("Not connected".to_string()))
3041 }
3042 }
3043 }
3044
3045 async fn ws_cancel_algo_order(
3047 &self,
3048 params: WsCancelAlgoOrderParams,
3049 request_id: Option<String>,
3050 ) -> Result<(), OKXWsError> {
3051 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
3052
3053 let req = OKXWsRequest {
3054 id: Some(request_id),
3055 op: OKXWsOperation::CancelAlgos,
3056 exp_time: None,
3057 args: vec![params],
3058 };
3059
3060 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
3061
3062 {
3063 let inner_guard = self.inner.read().await;
3064 if let Some(inner) = &*inner_guard {
3065 if let Err(e) = inner
3066 .send_text(txt, Some(vec!["cancel-algos".to_string()]))
3067 .await
3068 {
3069 tracing::error!("Error sending cancel algo message: {e:?}");
3070 }
3071 Ok(())
3072 } else {
3073 Err(OKXWsError::ClientError("Not connected".to_string()))
3074 }
3075 }
3076 }
3077}
3078
3079struct OKXFeedHandler {
3080 receiver: UnboundedReceiver<Message>,
3081 signal: Arc<AtomicBool>,
3082}
3083
3084impl OKXFeedHandler {
3085 pub fn new(receiver: UnboundedReceiver<Message>, signal: Arc<AtomicBool>) -> Self {
3087 Self { receiver, signal }
3088 }
3089
3090 async fn next(&mut self) -> Option<OKXWebSocketEvent> {
3092 loop {
3093 tokio::select! {
3094 msg = self.receiver.recv() => match msg {
3095 Some(msg) => match msg {
3096 Message::Text(text) => {
3097 if text == TEXT_PONG {
3099 tracing::trace!("Received pong from OKX");
3100 continue;
3101 }
3102 if text == TEXT_PING {
3103 tracing::trace!("Received ping from OKX (text)");
3104 return Some(OKXWebSocketEvent::Ping);
3105 }
3106
3107 if text == RECONNECTED {
3109 tracing::info!("Received WebSocket reconnection signal");
3110 return Some(OKXWebSocketEvent::Reconnected);
3111 }
3112 tracing::trace!("Received WebSocket message: {text}");
3113
3114 match serde_json::from_str(&text) {
3115 Ok(ws_event) => match &ws_event {
3116 OKXWebSocketEvent::Error { code, msg } => {
3117 tracing::error!("WebSocket error: {code} - {msg}");
3118 return Some(ws_event);
3119 }
3120 OKXWebSocketEvent::Login {
3121 event,
3122 code,
3123 msg,
3124 conn_id,
3125 } => {
3126 if code == "0" {
3127 tracing::info!(
3128 "Successfully authenticated with OKX WebSocket, conn_id={conn_id}"
3129 );
3130 } else {
3131 tracing::error!(
3132 "Authentication failed: {event} {code} - {msg}"
3133 );
3134 }
3135 return Some(ws_event);
3136 }
3137 OKXWebSocketEvent::Subscription {
3138 event,
3139 arg,
3140 conn_id, .. } => {
3141 let channel_str = serde_json::to_string(&arg.channel)
3142 .expect("Invalid OKX websocket channel")
3143 .trim_matches('"')
3144 .to_string();
3145 tracing::debug!(
3146 "{event}d: channel={channel_str}, conn_id={conn_id}"
3147 );
3148 continue;
3149 }
3150 OKXWebSocketEvent::ChannelConnCount {
3151 event: _,
3152 channel,
3153 conn_count,
3154 conn_id,
3155 } => {
3156 let channel_str = serde_json::to_string(&channel)
3157 .expect("Invalid OKX websocket channel")
3158 .trim_matches('"')
3159 .to_string();
3160 tracing::debug!(
3161 "Channel connection status: channel={channel_str}, connections={conn_count}, conn_id={conn_id}",
3162 );
3163 continue;
3164 }
3165 OKXWebSocketEvent::Ping => {
3166 tracing::trace!("Ignoring ping event parsed from text payload");
3167 continue;
3168 }
3169 OKXWebSocketEvent::Data { .. } => return Some(ws_event),
3170 OKXWebSocketEvent::BookData { .. } => return Some(ws_event),
3171 OKXWebSocketEvent::OrderResponse {
3172 id,
3173 op,
3174 code,
3175 msg,
3176 data,
3177 } => {
3178 if code == "0" {
3179 tracing::debug!(
3180 "Order operation successful: id={:?}, op={op}, code={code}",
3181 id
3182 );
3183
3184 if let Some(order_data) = data.first() {
3186 let success_msg = order_data
3187 .get("sMsg")
3188 .and_then(|s| s.as_str())
3189 .unwrap_or("Order operation successful");
3190 tracing::debug!("Order success details: {success_msg}");
3191 }
3192 } else {
3193 let error_msg = data
3195 .first()
3196 .and_then(|d| d.get("sMsg"))
3197 .and_then(|s| s.as_str())
3198 .unwrap_or(msg.as_str());
3199 tracing::error!(
3200 "Order operation failed: id={id:?}, op={op}, code={code}, error={error_msg}",
3201 );
3202 }
3203 return Some(ws_event);
3204 }
3205 OKXWebSocketEvent::Reconnected => {
3206 tracing::warn!("Unexpected Reconnected event from deserialization");
3208 continue;
3209 }
3210 },
3211 Err(e) => {
3212 tracing::error!("Failed to parse message: {e}: {text}");
3213 return None;
3214 }
3215 }
3216 }
3217 Message::Ping(payload) => {
3218 tracing::trace!("Received ping frame from OKX ({} bytes)", payload.len());
3219 continue;
3220 }
3221 Message::Pong(payload) => {
3222 tracing::trace!("Received pong frame from OKX ({} bytes)", payload.len());
3223 continue;
3224 }
3225 Message::Binary(msg) => {
3226 tracing::debug!("Raw binary: {msg:?}");
3227 }
3228 Message::Close(_) => {
3229 tracing::debug!("Received close message");
3230 return None;
3231 }
3232 msg => {
3233 tracing::warn!("Unexpected message: {msg}");
3234 }
3235 }
3236 None => {
3237 tracing::info!("WebSocket stream closed");
3238 return None;
3239 }
3240 },
3241 _ = tokio::time::sleep(Duration::from_millis(1)) => {
3242 if self.signal.load(std::sync::atomic::Ordering::Relaxed) {
3243 tracing::debug!("Stop signal received");
3244 return None;
3245 }
3246 }
3247 }
3248 }
3249 }
3250}
3251
3252struct OKXWsMessageHandler {
3253 account_id: AccountId,
3254 inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
3255 handler: OKXFeedHandler,
3256 #[allow(dead_code)]
3257 tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
3258 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
3259 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
3260 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
3261 pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
3262 active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
3263 client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
3264 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
3265 last_account_state: Option<AccountState>,
3266 fee_cache: AHashMap<Ustr, Money>, funding_rate_cache: AHashMap<Ustr, (Ustr, u64)>, auth_tracker: AuthTracker,
3269 pending_messages: VecDeque<NautilusWsMessage>,
3270 subscriptions_state: SubscriptionState,
3271}
3272
3273impl OKXWsMessageHandler {
3274 fn schedule_text_pong(&self) {
3275 let inner = self.inner.clone();
3276 get_runtime().spawn(async move {
3277 let guard = inner.read().await;
3278
3279 if let Some(client) = guard.as_ref() {
3280 if let Err(err) = client.send_text(TEXT_PONG.to_string(), None).await {
3281 tracing::warn!(error = %err, "Failed to send pong response to OKX text ping");
3282 } else {
3283 tracing::trace!("Sent pong response to OKX text ping");
3284 }
3285 } else {
3286 tracing::debug!("Received text ping with no active websocket client");
3287 }
3288 });
3289 }
3290
3291 fn try_handle_post_only_auto_cancel(
3292 &mut self,
3293 msg: &OKXOrderMsg,
3294 ts_init: UnixNanos,
3295 exec_reports: &mut Vec<ExecutionReport>,
3296 ) -> bool {
3297 if !Self::is_post_only_auto_cancel(msg) {
3298 return false;
3299 }
3300
3301 let Some(client_order_id) = parse_client_order_id(&msg.cl_ord_id) else {
3302 return false;
3303 };
3304
3305 let Some((_, (trader_id, strategy_id, instrument_id))) =
3306 self.active_client_orders.remove(&client_order_id)
3307 else {
3308 return false;
3309 };
3310
3311 self.client_id_aliases.remove(&client_order_id);
3312
3313 if !exec_reports.is_empty() {
3314 let reports = std::mem::take(exec_reports);
3315 self.pending_messages
3316 .push_back(NautilusWsMessage::ExecutionReports(reports));
3317 }
3318
3319 let reason = msg
3320 .cancel_source_reason
3321 .as_ref()
3322 .filter(|reason| !reason.is_empty())
3323 .map(|reason| Ustr::from(reason.as_str()))
3324 .unwrap_or_else(|| Ustr::from(OKX_POST_ONLY_CANCEL_REASON));
3325
3326 let ts_event = parse_millisecond_timestamp(msg.u_time);
3327 let rejected = OrderRejected::new(
3328 trader_id,
3329 strategy_id,
3330 instrument_id,
3331 client_order_id,
3332 self.account_id,
3333 reason,
3334 UUID4::new(),
3335 ts_event,
3336 ts_init,
3337 false,
3338 true,
3339 );
3340
3341 self.pending_messages
3342 .push_back(NautilusWsMessage::OrderRejected(rejected));
3343
3344 true
3345 }
3346
3347 fn is_post_only_auto_cancel(msg: &OKXOrderMsg) -> bool {
3348 if msg.state != OKXOrderStatus::Canceled {
3349 return false;
3350 }
3351
3352 let cancel_source_matches = matches!(
3353 msg.cancel_source.as_deref(),
3354 Some(source) if source == OKX_POST_ONLY_CANCEL_SOURCE
3355 );
3356
3357 let reason_matches = matches!(
3358 msg.cancel_source_reason.as_deref(),
3359 Some(reason) if reason.contains("POST_ONLY")
3360 );
3361
3362 if !(cancel_source_matches || reason_matches) {
3363 return false;
3364 }
3365
3366 msg.acc_fill_sz
3367 .as_ref()
3368 .map(|filled| filled == "0" || filled.is_empty())
3369 .unwrap_or(true)
3370 }
3371
3372 fn register_client_order_aliases(
3373 &self,
3374 raw_child: &Option<ClientOrderId>,
3375 parent_from_msg: &Option<ClientOrderId>,
3376 ) -> Option<ClientOrderId> {
3377 if let Some(parent) = parent_from_msg {
3378 self.client_id_aliases.insert(*parent, *parent);
3379 if let Some(child) = raw_child.as_ref().filter(|child| **child != *parent) {
3380 self.client_id_aliases.insert(*child, *parent);
3381 }
3382 Some(*parent)
3383 } else if let Some(child) = raw_child.as_ref() {
3384 if let Some(mapped) = self.client_id_aliases.get(child) {
3385 Some(*mapped.value())
3386 } else {
3387 self.client_id_aliases.insert(*child, *child);
3388 Some(*child)
3389 }
3390 } else {
3391 None
3392 }
3393 }
3394
3395 fn adjust_execution_report(
3396 &self,
3397 report: ExecutionReport,
3398 effective_client_id: &Option<ClientOrderId>,
3399 raw_child: &Option<ClientOrderId>,
3400 ) -> ExecutionReport {
3401 match report {
3402 ExecutionReport::Order(status_report) => {
3403 let mut adjusted = status_report;
3404 let mut final_id = *effective_client_id;
3405
3406 if final_id.is_none() {
3407 final_id = adjusted.client_order_id;
3408 }
3409
3410 if final_id.is_none()
3411 && let Some(child) = raw_child.as_ref()
3412 && let Some(mapped) = self.client_id_aliases.get(child)
3413 {
3414 final_id = Some(*mapped.value());
3415 }
3416
3417 if let Some(final_id_value) = final_id {
3418 if adjusted.client_order_id != Some(final_id_value) {
3419 adjusted = adjusted.with_client_order_id(final_id_value);
3420 }
3421 self.client_id_aliases
3422 .insert(final_id_value, final_id_value);
3423
3424 if let Some(child) =
3425 raw_child.as_ref().filter(|child| **child != final_id_value)
3426 {
3427 adjusted = adjusted.with_linked_order_ids(vec![*child]);
3428 }
3429 }
3430
3431 ExecutionReport::Order(adjusted)
3432 }
3433 ExecutionReport::Fill(mut fill_report) => {
3434 let mut final_id = *effective_client_id;
3435 if final_id.is_none() {
3436 final_id = fill_report.client_order_id;
3437 }
3438 if final_id.is_none()
3439 && let Some(child) = raw_child.as_ref()
3440 && let Some(mapped) = self.client_id_aliases.get(child)
3441 {
3442 final_id = Some(*mapped.value());
3443 }
3444
3445 if let Some(final_id_value) = final_id {
3446 fill_report.client_order_id = Some(final_id_value);
3447 self.client_id_aliases
3448 .insert(final_id_value, final_id_value);
3449 }
3450
3451 ExecutionReport::Fill(fill_report)
3452 }
3453 }
3454 }
3455
3456 fn update_caches_with_report(&mut self, report: &ExecutionReport) {
3457 match report {
3458 ExecutionReport::Fill(fill_report) => {
3459 let order_id = fill_report.venue_order_id.inner();
3460 let current_fee = self
3461 .fee_cache
3462 .get(&order_id)
3463 .copied()
3464 .unwrap_or_else(|| Money::new(0.0, fill_report.commission.currency));
3465 let total_fee = current_fee + fill_report.commission;
3466 self.fee_cache.insert(order_id, total_fee);
3467 }
3468 ExecutionReport::Order(status_report) => {
3469 if matches!(status_report.order_status, OrderStatus::Filled) {
3470 self.fee_cache.remove(&status_report.venue_order_id.inner());
3471 }
3472
3473 if matches!(
3474 status_report.order_status,
3475 OrderStatus::Canceled
3476 | OrderStatus::Expired
3477 | OrderStatus::Filled
3478 | OrderStatus::Rejected,
3479 ) {
3480 if let Some(client_order_id) = status_report.client_order_id {
3481 self.active_client_orders.remove(&client_order_id);
3482 self.client_id_aliases.remove(&client_order_id);
3483 }
3484 if let Some(linked) = &status_report.linked_order_ids {
3485 for child in linked {
3486 self.client_id_aliases.remove(child);
3487 }
3488 }
3489 }
3490 }
3491 }
3492 }
3493
3494 #[allow(clippy::too_many_arguments)]
3496 pub fn new(
3497 account_id: AccountId,
3498 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
3499 reader: UnboundedReceiver<Message>,
3500 signal: Arc<AtomicBool>,
3501 inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
3502 tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
3503 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
3504 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
3505 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
3506 pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
3507 active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
3508 client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
3509 auth_tracker: AuthTracker,
3510 subscriptions_state: SubscriptionState,
3511 ) -> Self {
3512 Self {
3513 account_id,
3514 inner,
3515 handler: OKXFeedHandler::new(reader, signal),
3516 tx,
3517 pending_place_requests,
3518 pending_cancel_requests,
3519 pending_amend_requests,
3520 pending_mass_cancel_requests,
3521 active_client_orders,
3522 client_id_aliases,
3523 instruments_cache,
3524 last_account_state: None,
3525 fee_cache: AHashMap::new(),
3526 funding_rate_cache: AHashMap::new(),
3527 auth_tracker,
3528 pending_messages: VecDeque::new(),
3529 subscriptions_state,
3530 }
3531 }
3532
3533 fn is_stopped(&self) -> bool {
3534 self.handler
3535 .signal
3536 .load(std::sync::atomic::Ordering::Relaxed)
3537 }
3538
3539 #[allow(dead_code)]
3540 async fn run(&mut self) {
3541 while let Some(data) = self.next().await {
3542 if let Err(e) = self.tx.send(data) {
3543 tracing::error!("Error sending data: {e}");
3544 break; }
3546 }
3547 }
3548
3549 async fn next(&mut self) -> Option<NautilusWsMessage> {
3550 if let Some(message) = self.pending_messages.pop_front() {
3551 return Some(message);
3552 }
3553
3554 let clock = get_atomic_clock_realtime();
3555
3556 while let Some(event) = self.handler.next().await {
3557 let ts_init = clock.get_time_ns();
3558
3559 match event {
3560 OKXWebSocketEvent::Ping => {
3561 self.schedule_text_pong();
3562 continue;
3563 }
3564 OKXWebSocketEvent::Login {
3565 code, msg, conn_id, ..
3566 } => {
3567 if code == "0" {
3568 self.auth_tracker.succeed();
3569 continue;
3570 }
3571
3572 tracing::error!("Authentication failed: {msg}");
3573 self.auth_tracker.fail(msg.clone());
3574
3575 let error = OKXWebSocketError {
3576 code,
3577 message: msg,
3578 conn_id: Some(conn_id),
3579 timestamp: clock.get_time_ns().as_u64(),
3580 };
3581 self.pending_messages
3582 .push_back(NautilusWsMessage::Error(error));
3583 continue;
3584 }
3585 OKXWebSocketEvent::BookData { arg, action, data } => {
3586 let Some(inst_id) = arg.inst_id else {
3587 tracing::error!("Instrument ID missing for book data event");
3588 continue;
3589 };
3590
3591 let Some(inst) = self.instruments_cache.get(&inst_id) else {
3592 continue;
3593 };
3594
3595 let instrument_id = inst.id();
3596 let price_precision = inst.price_precision();
3597 let size_precision = inst.size_precision();
3598
3599 match parse_book_msg_vec(
3600 data,
3601 &instrument_id,
3602 price_precision,
3603 size_precision,
3604 action,
3605 ts_init,
3606 ) {
3607 Ok(payloads) => return Some(NautilusWsMessage::Data(payloads)),
3608 Err(e) => {
3609 tracing::error!("Failed to parse book message: {e}");
3610 continue;
3611 }
3612 }
3613 }
3614 OKXWebSocketEvent::OrderResponse {
3615 id,
3616 op,
3617 code,
3618 msg,
3619 data,
3620 } => {
3621 if code == "0" {
3622 tracing::debug!(
3623 "Order operation successful: id={id:?} op={op} code={code}"
3624 );
3625
3626 if op == OKXWsOperation::MassCancel
3627 && let Some(request_id) = &id
3628 && let Some((_, instrument_id)) =
3629 self.pending_mass_cancel_requests.remove(request_id)
3630 {
3631 tracing::info!(
3632 "Mass cancel operation successful for instrument: {}",
3633 instrument_id
3634 );
3635 }
3636
3637 if let Some(first) = data.first()
3638 && let Some(success_msg) =
3639 first.get("sMsg").and_then(|value| value.as_str())
3640 {
3641 tracing::debug!("Order details: {success_msg}");
3642 }
3643
3644 continue;
3645 }
3646
3647 let error_msg = data
3648 .first()
3649 .and_then(|d| d.get("sMsg"))
3650 .and_then(|s| s.as_str())
3651 .unwrap_or(&msg)
3652 .to_string();
3653
3654 if let Some(first) = data.first() {
3655 tracing::debug!(
3656 "Error data fields: {}",
3657 serde_json::to_string_pretty(first)
3658 .unwrap_or_else(|_| "unable to serialize".to_string())
3659 );
3660 }
3661
3662 tracing::error!(
3663 "Order operation failed: id={id:?} op={op} code={code} msg={msg}"
3664 );
3665
3666 if let Some(request_id) = &id {
3667 match op {
3668 OKXWsOperation::Order => {
3669 if let Some((
3670 _,
3671 (client_order_id, trader_id, strategy_id, instrument_id),
3672 )) = self.pending_place_requests.remove(request_id)
3673 {
3674 let ts_event = clock.get_time_ns();
3675 let due_post_only =
3676 is_post_only_rejection(code.as_str(), &data);
3677 let rejected = OrderRejected::new(
3678 trader_id,
3679 strategy_id,
3680 instrument_id,
3681 client_order_id,
3682 self.account_id,
3683 Ustr::from(error_msg.as_str()),
3684 UUID4::new(),
3685 ts_event,
3686 ts_init,
3687 false, due_post_only,
3689 );
3690
3691 return Some(NautilusWsMessage::OrderRejected(rejected));
3692 }
3693 }
3694 OKXWsOperation::CancelOrder => {
3695 if let Some((
3696 _,
3697 (
3698 client_order_id,
3699 trader_id,
3700 strategy_id,
3701 instrument_id,
3702 venue_order_id,
3703 ),
3704 )) = self.pending_cancel_requests.remove(request_id)
3705 {
3706 let ts_event = clock.get_time_ns();
3707 let rejected = OrderCancelRejected::new(
3708 trader_id,
3709 strategy_id,
3710 instrument_id,
3711 client_order_id,
3712 Ustr::from(error_msg.as_str()),
3713 UUID4::new(),
3714 ts_event,
3715 ts_init,
3716 false, venue_order_id,
3718 Some(self.account_id),
3719 );
3720
3721 return Some(NautilusWsMessage::OrderCancelRejected(rejected));
3722 }
3723 }
3724 OKXWsOperation::AmendOrder => {
3725 if let Some((
3726 _,
3727 (
3728 client_order_id,
3729 trader_id,
3730 strategy_id,
3731 instrument_id,
3732 venue_order_id,
3733 ),
3734 )) = self.pending_amend_requests.remove(request_id)
3735 {
3736 let ts_event = clock.get_time_ns();
3737 let rejected = OrderModifyRejected::new(
3738 trader_id,
3739 strategy_id,
3740 instrument_id,
3741 client_order_id,
3742 Ustr::from(error_msg.as_str()),
3743 UUID4::new(),
3744 ts_event,
3745 ts_init,
3746 false, venue_order_id,
3748 Some(self.account_id),
3749 );
3750
3751 return Some(NautilusWsMessage::OrderModifyRejected(rejected));
3752 }
3753 }
3754 OKXWsOperation::MassCancel => {
3755 if let Some((_, instrument_id)) =
3756 self.pending_mass_cancel_requests.remove(request_id)
3757 {
3758 tracing::error!(
3759 "Mass cancel operation failed for {}: code={code} msg={error_msg}",
3760 instrument_id
3761 );
3762 let error = OKXWebSocketError {
3763 code,
3764 message: format!(
3765 "Mass cancel failed for {}: {}",
3766 instrument_id, error_msg
3767 ),
3768 conn_id: None,
3769 timestamp: clock.get_time_ns().as_u64(),
3770 };
3771 return Some(NautilusWsMessage::Error(error));
3772 } else {
3773 tracing::error!(
3774 "Mass cancel operation failed: code={code} msg={error_msg}"
3775 );
3776 }
3777 }
3778 _ => tracing::warn!("Unhandled operation type for rejection: {op}"),
3779 }
3780 }
3781
3782 let error = OKXWebSocketError {
3783 code,
3784 message: error_msg,
3785 conn_id: None,
3786 timestamp: clock.get_time_ns().as_u64(),
3787 };
3788 return Some(NautilusWsMessage::Error(error));
3789 }
3790 OKXWebSocketEvent::Data { arg, data } => {
3791 let OKXWebSocketArg {
3792 channel, inst_id, ..
3793 } = arg;
3794
3795 match channel {
3796 OKXWsChannel::Account => {
3797 match serde_json::from_value::<Vec<OKXAccount>>(data) {
3798 Ok(accounts) => {
3799 if let Some(account) = accounts.first() {
3800 match parse_account_state(account, self.account_id, ts_init)
3801 {
3802 Ok(account_state) => {
3803 if let Some(last_account_state) =
3804 &self.last_account_state
3805 && account_state.has_same_balances_and_margins(
3806 last_account_state,
3807 )
3808 {
3809 continue;
3810 }
3811 self.last_account_state =
3812 Some(account_state.clone());
3813 return Some(NautilusWsMessage::AccountUpdate(
3814 account_state,
3815 ));
3816 }
3817 Err(e) => tracing::error!(
3818 "Failed to parse account state: {e}"
3819 ),
3820 }
3821 }
3822 }
3823 Err(e) => tracing::error!("Failed to parse account data: {e}"),
3824 }
3825 continue;
3826 }
3827 OKXWsChannel::Orders => {
3828 let orders: Vec<OKXOrderMsg> = match serde_json::from_value(data) {
3829 Ok(orders) => orders,
3830 Err(e) => {
3831 tracing::error!(
3832 "Failed to deserialize orders channel payload: {e}"
3833 );
3834 continue;
3835 }
3836 };
3837
3838 tracing::debug!(
3839 "Received {} order message(s) from orders channel",
3840 orders.len()
3841 );
3842
3843 let mut exec_reports: Vec<ExecutionReport> =
3844 Vec::with_capacity(orders.len());
3845
3846 for msg in orders {
3847 tracing::debug!(
3848 "Processing order message: inst_id={}, cl_ord_id={}, state={:?}, exec_type={:?}",
3849 msg.inst_id,
3850 msg.cl_ord_id,
3851 msg.state,
3852 msg.exec_type
3853 );
3854
3855 if self.try_handle_post_only_auto_cancel(
3856 &msg,
3857 ts_init,
3858 &mut exec_reports,
3859 ) {
3860 continue;
3861 }
3862
3863 let raw_child = parse_client_order_id(&msg.cl_ord_id);
3864 let parent_from_msg = msg
3865 .algo_cl_ord_id
3866 .as_ref()
3867 .filter(|value| !value.is_empty())
3868 .map(ClientOrderId::new);
3869 let effective_client_id = self
3870 .register_client_order_aliases(&raw_child, &parent_from_msg);
3871
3872 match parse_order_msg(
3873 &msg,
3874 self.account_id,
3875 &self.instruments_cache,
3876 &self.fee_cache,
3877 ts_init,
3878 ) {
3879 Ok(report) => {
3880 tracing::debug!(
3881 "Successfully parsed execution report: {:?}",
3882 report
3883 );
3884 let adjusted = self.adjust_execution_report(
3885 report,
3886 &effective_client_id,
3887 &raw_child,
3888 );
3889 self.update_caches_with_report(&adjusted);
3890 exec_reports.push(adjusted);
3891 }
3892 Err(e) => tracing::error!("Failed to parse order message: {e}"),
3893 }
3894 }
3895
3896 if !exec_reports.is_empty() {
3897 tracing::debug!(
3898 "Pushing {} execution report(s) to message queue",
3899 exec_reports.len()
3900 );
3901 self.pending_messages
3902 .push_back(NautilusWsMessage::ExecutionReports(exec_reports));
3903 } else {
3904 tracing::debug!(
3905 "No execution reports generated from order messages"
3906 );
3907 }
3908
3909 if let Some(message) = self.pending_messages.pop_front() {
3910 return Some(message);
3911 }
3912
3913 continue;
3914 }
3915 OKXWsChannel::OrdersAlgo => {
3916 let orders: Vec<OKXAlgoOrderMsg> = match serde_json::from_value(data) {
3917 Ok(orders) => orders,
3918 Err(e) => {
3919 tracing::error!(
3920 "Failed to deserialize algo orders payload: {e}"
3921 );
3922 continue;
3923 }
3924 };
3925
3926 let mut exec_reports: Vec<ExecutionReport> =
3927 Vec::with_capacity(orders.len());
3928
3929 for msg in orders {
3930 let raw_child = parse_client_order_id(&msg.cl_ord_id);
3931 let parent_from_msg = parse_client_order_id(&msg.algo_cl_ord_id);
3932 let effective_client_id = self
3933 .register_client_order_aliases(&raw_child, &parent_from_msg);
3934
3935 match parse_algo_order_msg(
3936 msg,
3937 self.account_id,
3938 &self.instruments_cache,
3939 ts_init,
3940 ) {
3941 Ok(report) => {
3942 let adjusted = self.adjust_execution_report(
3943 report,
3944 &effective_client_id,
3945 &raw_child,
3946 );
3947 self.update_caches_with_report(&adjusted);
3948 exec_reports.push(adjusted);
3949 }
3950 Err(e) => {
3951 tracing::error!("Failed to parse algo order message: {e}")
3952 }
3953 }
3954 }
3955
3956 if !exec_reports.is_empty() {
3957 return Some(NautilusWsMessage::ExecutionReports(exec_reports));
3958 }
3959
3960 continue;
3961 }
3962 _ => {
3963 let Some(inst_id) = inst_id else {
3964 tracing::error!("No instrument for channel {:?}", channel);
3965 continue;
3966 };
3967
3968 let Some(instrument) = self.instruments_cache.get(&inst_id) else {
3969 tracing::error!(
3970 "No instrument for channel {:?}, inst_id {:?}",
3971 channel,
3972 inst_id
3973 );
3974 continue;
3975 };
3976
3977 let instrument_id = instrument.id();
3978 let price_precision = instrument.price_precision();
3979 let size_precision = instrument.size_precision();
3980
3981 match parse_ws_message_data(
3982 &channel,
3983 data,
3984 &instrument_id,
3985 price_precision,
3986 size_precision,
3987 ts_init,
3988 &mut self.funding_rate_cache,
3989 ) {
3990 Ok(Some(msg)) => return Some(msg),
3991 Ok(None) => continue,
3992 Err(e) => {
3993 tracing::error!(
3994 "Error parsing message for channel {:?}: {e}",
3995 channel
3996 );
3997 continue;
3998 }
3999 }
4000 }
4001 }
4002 }
4003 OKXWebSocketEvent::Error { code, msg } => {
4004 let error = OKXWebSocketError {
4005 code,
4006 message: msg,
4007 conn_id: None,
4008 timestamp: clock.get_time_ns().as_u64(),
4009 };
4010 return Some(NautilusWsMessage::Error(error));
4011 }
4012 OKXWebSocketEvent::Reconnected => {
4013 return Some(NautilusWsMessage::Reconnected);
4014 }
4015 OKXWebSocketEvent::Subscription {
4016 event,
4017 arg,
4018 code,
4019 msg,
4020 ..
4021 } => {
4022 let topic = topic_from_websocket_arg(&arg);
4023 let success = code.as_deref().map(|c| c == "0").unwrap_or(true);
4024
4025 match event {
4026 OKXSubscriptionEvent::Subscribe => {
4027 if success {
4028 self.subscriptions_state.confirm(&topic);
4029 } else {
4030 tracing::warn!(?topic, error = ?msg, code = ?code, "Subscription failed");
4031 self.subscriptions_state.mark_failure(&topic);
4032 }
4033 }
4034 OKXSubscriptionEvent::Unsubscribe => {
4035 if success {
4036 self.subscriptions_state.clear_pending(&topic);
4037 } else {
4038 tracing::warn!(?topic, error = ?msg, code = ?code, "Unsubscription failed");
4039 self.subscriptions_state.mark_failure(&topic);
4040 }
4041 }
4042 }
4043
4044 continue;
4045 }
4046 OKXWebSocketEvent::ChannelConnCount { .. } => continue,
4047 }
4048 }
4049
4050 None
4051 }
4052}
4053
4054pub fn is_post_only_rejection(code: &str, data: &[Value]) -> bool {
4056 if code == OKX_POST_ONLY_ERROR_CODE {
4057 return true;
4058 }
4059
4060 for entry in data {
4061 if let Some(s_code) = entry.get("sCode").and_then(|value| value.as_str())
4062 && s_code == OKX_POST_ONLY_ERROR_CODE
4063 {
4064 return true;
4065 }
4066
4067 if let Some(inner_code) = entry.get("code").and_then(|value| value.as_str())
4068 && inner_code == OKX_POST_ONLY_ERROR_CODE
4069 {
4070 return true;
4071 }
4072 }
4073
4074 false
4075}
4076
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::common::enums::{OKXExecType, OKXSide};

    #[rstest]
    fn test_timestamp_format_for_websocket_auth() {
        // Pins the auth timestamp format to whole Unix seconds (10 digits for
        // present-day dates).
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
            .to_string();

        assert!(timestamp.parse::<u64>().is_ok());
        assert_eq!(timestamp.len(), 10);
        assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
    }

    #[rstest]
    fn test_new_without_credentials() {
        let client = OKXWebSocketClient::default();
        assert!(client.credential.is_none());
        assert_eq!(client.api_key(), None);
    }

    #[rstest]
    fn test_new_with_credentials() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            None,
            None,
        )
        .unwrap();
        assert!(client.credential.is_some());
        assert_eq!(client.api_key(), Some("test_key"));
    }

    #[rstest]
    fn test_new_partial_credentials_fails() {
        // API secret missing while key and passphrase are present
        let result = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            None,
            Some("test_passphrase".to_string()),
            None,
            None,
        );
        assert!(result.is_err());
    }

    #[rstest]
    fn test_request_id_generation() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);

        let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, initial_counter);
        assert_eq!(id2, initial_counter + 1);
        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 2
        );
    }

    #[rstest]
    fn test_client_state_management() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
    }

    #[rstest]
    fn test_request_cache_operations() {
        let client = OKXWebSocketClient::default();

        assert_eq!(client.pending_place_requests.len(), 0);
        assert_eq!(client.pending_cancel_requests.len(), 0);
        assert_eq!(client.pending_amend_requests.len(), 0);

        let client_order_id = ClientOrderId::from("test-order-123");
        let trader_id = TraderId::from("test-trader-001");
        let strategy_id = StrategyId::from("test-strategy-001");
        let instrument_id = InstrumentId::from("BTC-USDT.OKX");

        client.pending_place_requests.insert(
            "place-123".to_string(),
            (client_order_id, trader_id, strategy_id, instrument_id),
        );

        assert_eq!(client.pending_place_requests.len(), 1);
        assert!(client.pending_place_requests.contains_key("place-123"));

        let removed = client.pending_place_requests.remove("place-123");
        assert!(removed.is_some());
        assert_eq!(client.pending_place_requests.len(), 0);
    }

    #[rstest]
    fn test_websocket_error_handling() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error = OKXWebSocketError {
            code: "60012".to_string(),
            message: "Invalid request".to_string(),
            conn_id: None,
            timestamp: ts,
        };

        assert_eq!(error.code, "60012");
        assert_eq!(error.message, "Invalid request");
        assert_eq!(error.timestamp, ts);

        let nautilus_msg = NautilusWsMessage::Error(error);
        match nautilus_msg {
            NautilusWsMessage::Error(err) => {
                assert_eq!(err.code, "60012");
                assert_eq!(err.message, "Invalid request");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    #[rstest]
    fn test_request_id_generation_sequence() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
        let ids: Vec<_> = (0..10)
            .map(|_| client.request_id_counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        for (i, &id) in ids.iter().enumerate() {
            assert_eq!(id, initial_counter + i as u64);
        }

        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 10
        );
    }

    #[rstest]
    fn test_client_state_transitions() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);

        let account_id = AccountId::from("test-account-123");
        let client_with_account =
            OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();

        assert_eq!(client_with_account.account_id, account_id);
    }

    #[tokio::test]
    async fn test_concurrent_request_handling() {
        let client = Arc::new(OKXWebSocketClient::default());

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
        let mut handles = Vec::with_capacity(10);

        for i in 0..10 {
            let client_clone = Arc::clone(&client);
            let handle = tokio::spawn(async move {
                let client_order_id = ClientOrderId::from(format!("order-{i}").as_str());
                let trader_id = TraderId::from("trader-001");
                let strategy_id = StrategyId::from("strategy-001");
                let instrument_id = InstrumentId::from("BTC-USDT.OKX");

                let request_id = client_clone
                    .request_id_counter
                    .fetch_add(1, Ordering::SeqCst);
                let request_id_str = request_id.to_string();

                client_clone.pending_place_requests.insert(
                    request_id_str.clone(),
                    (client_order_id, trader_id, strategy_id, instrument_id),
                );

                // Yield so the spawned tasks interleave their insert/remove pairs
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

                let removed = client_clone.pending_place_requests.remove(&request_id_str);
                assert!(removed.is_some());

                request_id
            });
            handles.push(handle);
        }

        let results = futures_util::future::join_all(handles).await;
        assert_eq!(results.len(), 10);

        // Every task must have been handed a distinct ID from the shared counter,
        // and all IDs must come from the range allocated during this test.
        let request_ids: std::collections::HashSet<_> = results
            .into_iter()
            .map(|result| result.expect("spawned task should complete without panicking"))
            .collect();
        assert_eq!(request_ids.len(), 10);
        assert!(
            request_ids
                .iter()
                .all(|id| (initial_counter..initial_counter + 10).contains(id))
        );

        assert_eq!(client.pending_place_requests.len(), 0);

        let final_counter = client.request_id_counter.load(Ordering::SeqCst);
        assert_eq!(final_counter, initial_counter + 10);
    }

    #[rstest]
    fn test_websocket_error_scenarios() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error_scenarios = [
            ("60012", "Invalid request", None),
            ("60009", "Invalid API key", Some("conn-123".to_string())),
            ("60014", "Too many requests", None),
            ("50001", "Order not found", None),
        ];

        for (code, message, conn_id) in error_scenarios {
            let error = OKXWebSocketError {
                code: code.to_string(),
                message: message.to_string(),
                conn_id: conn_id.clone(),
                timestamp: ts,
            };

            assert_eq!(error.code, code);
            assert_eq!(error.message, message);
            assert_eq!(error.conn_id, conn_id);
            assert_eq!(error.timestamp, ts);

            let nautilus_msg = NautilusWsMessage::Error(error);
            match nautilus_msg {
                NautilusWsMessage::Error(err) => {
                    assert_eq!(err.code, code);
                    assert_eq!(err.message, message);
                    assert_eq!(err.conn_id, conn_id);
                }
                _ => panic!("Expected Error variant"),
            }
        }
    }

    #[tokio::test]
    async fn test_feed_handler_reconnection_detection() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        tx.send(Message::Text(RECONNECTED.to_string().into()))
            .unwrap();

        let result = handler.next().await;
        assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
    }

    #[tokio::test]
    async fn test_feed_handler_normal_message_processing() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        tx.send(Message::Text(TEXT_PING.to_string().into())).unwrap();

        let sub_msg = r#"{
            "event": "subscribe",
            "arg": {
                "channel": "tickers",
                "instType": "SPOT"
            },
            "connId": "a4d3ae55"
        }"#;

        tx.send(Message::Text(sub_msg.to_string().into())).unwrap();

        let first = handler.next().await;
        assert!(matches!(first, Some(OKXWebSocketEvent::Ping)));

        // Raising the stop signal ends the stream even though a message is still queued
        signal.store(true, Ordering::Relaxed);

        let result = handler.next().await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_feed_handler_stop_signal() {
        // `_tx` (unlike `_`) keeps the sender alive, so the channel stays open
        // and only the stop signal can end the stream.
        let (_tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(true));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        let result = handler.next().await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_feed_handler_close_message() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        tx.send(Message::Close(None)).unwrap();

        let result = handler.next().await;
        assert!(result.is_none());
    }

    #[rstest]
    fn test_reconnection_message_constant() {
        assert_eq!(RECONNECTED, "__RECONNECTED__");
    }

    #[tokio::test]
    async fn test_multiple_reconnection_signals() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        for _ in 0..3 {
            tx.send(Message::Text(RECONNECTED.to_string().into()))
                .unwrap();

            let result = handler.next().await;
            assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
        }
    }

    #[tokio::test]
    async fn test_wait_until_active_timeout() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            Some(AccountId::from("test-account")),
            None,
        )
        .unwrap();

        // Never connected, so waiting must time out rather than succeed
        let result = client.wait_until_active(0.1).await;

        assert!(result.is_err());
        assert!(!client.is_active());
    }

    // Minimal canceled limit order on ETH-USDT-SWAP with no post-only markers;
    // individual tests set only the fields they exercise.
    fn sample_canceled_order_msg() -> OKXOrderMsg {
        OKXOrderMsg {
            acc_fill_sz: Some("0".to_string()),
            avg_px: "0".to_string(),
            c_time: 0,
            cancel_source: None,
            cancel_source_reason: None,
            category: ustr::Ustr::from("normal"),
            ccy: ustr::Ustr::from("USDT"),
            cl_ord_id: "order-1".to_string(),
            algo_cl_ord_id: None,
            fee: None,
            fee_ccy: ustr::Ustr::from("USDT"),
            fill_px: "0".to_string(),
            fill_sz: "0".to_string(),
            fill_time: 0,
            inst_id: ustr::Ustr::from("ETH-USDT-SWAP"),
            inst_type: OKXInstrumentType::Swap,
            lever: "1".to_string(),
            ord_id: ustr::Ustr::from("123456"),
            ord_type: OKXOrderType::Limit,
            pnl: "0".to_string(),
            pos_side: OKXPositionSide::Net,
            px: "0".to_string(),
            reduce_only: "false".to_string(),
            side: OKXSide::Buy,
            state: OKXOrderStatus::Canceled,
            exec_type: OKXExecType::None,
            sz: "1".to_string(),
            td_mode: OKXTradeMode::Cross,
            trade_id: String::new(),
            u_time: 0,
        }
    }

    #[rstest]
    fn test_is_post_only_auto_cancel_detects_cancel_source() {
        let mut msg = sample_canceled_order_msg();
        msg.cancel_source = Some(super::OKX_POST_ONLY_CANCEL_SOURCE.to_string());

        assert!(OKXWsMessageHandler::is_post_only_auto_cancel(&msg));
    }

    #[rstest]
    fn test_is_post_only_auto_cancel_detects_reason() {
        let mut msg = sample_canceled_order_msg();
        msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());

        assert!(OKXWsMessageHandler::is_post_only_auto_cancel(&msg));
    }

    #[rstest]
    fn test_is_post_only_auto_cancel_false_without_markers() {
        let msg = sample_canceled_order_msg();

        assert!(!OKXWsMessageHandler::is_post_only_auto_cancel(&msg));
    }

    #[rstest]
    fn test_is_post_only_auto_cancel_false_for_order_type_only() {
        // A post-only order type alone is not evidence of an auto-cancel
        let mut msg = sample_canceled_order_msg();
        msg.ord_type = OKXOrderType::PostOnly;

        assert!(!OKXWsMessageHandler::is_post_only_auto_cancel(&msg));
    }

    #[rstest]
    fn test_is_post_only_rejection_detects_by_code() {
        assert!(super::is_post_only_rejection("51019", &[]));
    }

    #[rstest]
    fn test_is_post_only_rejection_detects_by_inner_code() {
        let data = [serde_json::json!({
            "sCode": "51019"
        })];
        assert!(super::is_post_only_rejection("50000", &data));
    }

    #[rstest]
    fn test_is_post_only_rejection_detects_by_nested_code_field() {
        // Entries may carry the error under `code` rather than `sCode`
        let data = [serde_json::json!({
            "code": "51019"
        })];
        assert!(super::is_post_only_rejection("50000", &data));
    }

    #[rstest]
    fn test_is_post_only_rejection_false_for_unrelated_error() {
        let data = [serde_json::json!({
            "sMsg": "Insufficient balance"
        })];
        assert!(!super::is_post_only_rejection("50000", &data));
        assert!(!super::is_post_only_rejection("50000", &[]));
    }
}