1use std::{
26 fmt::Debug,
27 num::NonZeroU32,
28 sync::{
29 Arc, LazyLock,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32 time::{Duration, SystemTime},
33};
34
35use ahash::AHashSet;
36use arc_swap::ArcSwap;
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::runtime::get_runtime;
40use nautilus_core::{
41 consts::NAUTILUS_USER_AGENT,
42 env::{get_env_var, get_or_env_var},
43};
44use nautilus_model::{
45 data::BarType,
46 enums::{OrderSide, OrderType, PositionSide, TimeInForce, TriggerType},
47 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
48 instruments::{Instrument, InstrumentAny},
49 types::{Price, Quantity},
50};
51use nautilus_network::{
52 mode::ConnectionMode,
53 ratelimiter::quota::Quota,
54 websocket::{
55 AUTHENTICATION_TIMEOUT_SECS, AuthTracker, PingHandler, SubscriptionState, TEXT_PING,
56 WebSocketClient, WebSocketConfig, channel_message_handler,
57 },
58};
59use reqwest::header::USER_AGENT;
60use serde_json::Value;
61use tokio_tungstenite::tungstenite::Error;
62use tokio_util::sync::CancellationToken;
63use ustr::Ustr;
64
65use super::{
66 enums::OKXWsChannel,
67 error::OKXWsError,
68 handler::{HandlerCommand, OKXWsFeedHandler},
69 messages::{
70 NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg, OKXSubscriptionArg,
71 WsAmendOrderParamsBuilder, WsCancelOrderParamsBuilder, WsPostAlgoOrderParamsBuilder,
72 WsPostOrderParamsBuilder,
73 },
74 subscription::topic_from_subscription_arg,
75};
76use crate::common::{
77 consts::{
78 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
79 OKX_WS_PUBLIC_URL, OKX_WS_TOPIC_DELIMITER,
80 },
81 credential::Credential,
82 enums::{
83 OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXTargetCurrency, OKXTradeMode,
84 OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type, is_conditional_order,
85 },
86 parse::{bar_spec_as_okx_channel, okx_instrument_type, okx_instrument_type_from_symbol},
87};
88
89pub static OKX_WS_CONNECTION_QUOTA: LazyLock<Quota> =
93 LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
94
95pub static OKX_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
100 LazyLock::new(|| Quota::per_hour(NonZeroU32::new(480).unwrap()));
101
102pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
107 LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
108
/// Rate-limit key for subscription requests; paired with `OKX_WS_SUBSCRIPTION_QUOTA` in `connect`.
pub const OKX_RATE_LIMIT_KEY_SUBSCRIPTION: &str = "subscription";

/// Rate-limit key for order requests; paired with `OKX_WS_ORDER_QUOTA` in `connect`.
pub const OKX_RATE_LIMIT_KEY_ORDER: &str = "order";

/// Rate-limit key for cancel requests; also paired with `OKX_WS_ORDER_QUOTA` in `connect`.
pub const OKX_RATE_LIMIT_KEY_CANCEL: &str = "cancel";

/// Rate-limit key for amend requests; also paired with `OKX_WS_ORDER_QUOTA` in `connect`.
pub const OKX_RATE_LIMIT_KEY_AMEND: &str = "amend";
132
/// WebSocket client for OKX public and private (authenticated) streams.
///
/// Most state is held behind `Arc`, so clones of the client share subscriptions,
/// caches and the connection mode.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
pub struct OKXWebSocketClient {
    // WebSocket endpoint URL used by `connect`.
    url: String,
    // Account identifier handed to the feed handler.
    account_id: AccountId,
    // VIP tier stored as its `u8` value; selects the order book channel.
    vip_level: Arc<AtomicU8>,
    // API credentials; when present, `connect` also logs in.
    credential: Option<Credential>,
    // Heartbeat setting forwarded to `WebSocketConfig::heartbeat`.
    heartbeat: Option<u64>,
    // Tracks the outcome of login requests.
    auth_tracker: AuthTracker,
    // Stop flag raised by `close`; consulted by `is_active` / `is_closed`.
    signal: Arc<AtomicBool>,
    // Connection-mode atomic; swapped for the live client's atomic on `connect`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    // Sender for commands to the handler task; replaced on every `connect`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    // Receiver of handler output; taken once by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    // Handle of the spawned handler task.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    // Tracked subscriptions keyed by channel (replayed after reconnects): by instrument type.
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    // Tracked subscriptions by instrument family.
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    // Tracked subscriptions by instrument ID.
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    // Channels subscribed without any instrument qualifier.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    // Per-topic subscription state (pending / confirmed).
    subscriptions_state: SubscriptionState,
    // Incrementing counter backing `generate_unique_request_id`.
    request_id_counter: Arc<AtomicU64>,
    // Order context shared with the handler (presumably open client orders — confirm).
    active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
    // Shared with the handler. NOTE(review): appears to dedupe accepted events per venue order — confirm.
    emitted_order_accepted: Arc<DashMap<VenueOrderId, ()>>,
    // Client order ID aliases shared with the handler.
    client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
    // Instruments keyed by raw symbol; replayed to the handler on `connect`.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    // Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
163
164impl Default for OKXWebSocketClient {
165 fn default() -> Self {
166 Self::new(None, None, None, None, None, None).unwrap()
167 }
168}
169
170impl Debug for OKXWebSocketClient {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct(stringify!(OKXWebSocketClient))
173 .field("url", &self.url)
174 .field(
175 "credential",
176 &self.credential.as_ref().map(|_| "<redacted>"),
177 )
178 .field("heartbeat", &self.heartbeat)
179 .finish_non_exhaustive()
180 }
181}
182
183impl OKXWebSocketClient {
184 pub fn new(
190 url: Option<String>,
191 api_key: Option<String>,
192 api_secret: Option<String>,
193 api_passphrase: Option<String>,
194 account_id: Option<AccountId>,
195 heartbeat: Option<u64>,
196 ) -> anyhow::Result<Self> {
197 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
198 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
199
200 let credential = match (api_key, api_secret, api_passphrase) {
201 (Some(key), Some(secret), Some(passphrase)) => {
202 Some(Credential::new(key, secret, passphrase))
203 }
204 (None, None, None) => None,
205 _ => anyhow::bail!(
206 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
207 ),
208 };
209
210 let signal = Arc::new(AtomicBool::new(false));
211 let subscriptions_inst_type = Arc::new(DashMap::new());
212 let subscriptions_inst_family = Arc::new(DashMap::new());
213 let subscriptions_inst_id = Arc::new(DashMap::new());
214 let subscriptions_bare = Arc::new(DashMap::new());
215 let subscriptions_state = SubscriptionState::new(OKX_WS_TOPIC_DELIMITER);
216
217 Ok(Self {
218 url,
219 account_id,
220 vip_level: Arc::new(AtomicU8::new(0)), credential,
222 heartbeat,
223 auth_tracker: AuthTracker::new(),
224 signal,
225 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
226 ConnectionMode::Closed.as_u8(),
227 ))),
228 cmd_tx: {
229 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
231 Arc::new(tokio::sync::RwLock::new(tx))
232 },
233 out_rx: None,
234 task_handle: None,
235 subscriptions_inst_type,
236 subscriptions_inst_family,
237 subscriptions_inst_id,
238 subscriptions_bare,
239 subscriptions_state,
240 request_id_counter: Arc::new(AtomicU64::new(1)),
241 active_client_orders: Arc::new(DashMap::new()),
242 emitted_order_accepted: Arc::new(DashMap::new()),
243 client_id_aliases: Arc::new(DashMap::new()),
244 instruments_cache: Arc::new(DashMap::new()),
245 cancellation_token: CancellationToken::new(),
246 })
247 }
248
249 pub fn with_credentials(
256 url: Option<String>,
257 api_key: Option<String>,
258 api_secret: Option<String>,
259 api_passphrase: Option<String>,
260 account_id: Option<AccountId>,
261 heartbeat: Option<u64>,
262 ) -> anyhow::Result<Self> {
263 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
264 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
265 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
266 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
267
268 Self::new(
269 Some(url),
270 Some(api_key),
271 Some(api_secret),
272 Some(api_passphrase),
273 account_id,
274 heartbeat,
275 )
276 }
277
278 pub fn from_env() -> anyhow::Result<Self> {
285 let url = get_env_var("OKX_WS_URL")?;
286 let api_key = get_env_var("OKX_API_KEY")?;
287 let api_secret = get_env_var("OKX_API_SECRET")?;
288 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
289
290 Self::new(
291 Some(url),
292 Some(api_key),
293 Some(api_secret),
294 Some(api_passphrase),
295 None,
296 None,
297 )
298 }
299
300 pub fn cancel_all_requests(&self) {
302 self.cancellation_token.cancel();
303 }
304
305 pub fn cancellation_token(&self) -> &CancellationToken {
307 &self.cancellation_token
308 }
309
310 pub fn url(&self) -> &str {
312 self.url.as_str()
313 }
314
315 pub fn api_key(&self) -> Option<&str> {
317 self.credential.clone().map(|c| c.api_key.as_str())
318 }
319
320 #[must_use]
322 pub fn api_key_masked(&self) -> Option<String> {
323 self.credential.clone().map(|c| c.api_key_masked())
324 }
325
326 pub fn is_active(&self) -> bool {
328 let connection_mode_arc = self.connection_mode.load();
329 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
330 && !self.signal.load(Ordering::Relaxed)
331 }
332
333 pub fn is_closed(&self) -> bool {
335 let connection_mode_arc = self.connection_mode.load();
336 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
337 || self.signal.load(Ordering::Relaxed)
338 }
339
340 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
344 for inst in &instruments {
345 self.instruments_cache
346 .insert(inst.symbol().inner(), inst.clone());
347 }
348
349 if !instruments.is_empty()
352 && let Ok(cmd_tx) = self.cmd_tx.try_read()
353 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
354 {
355 log::debug!("Failed to send bulk instrument update to handler: {e}");
356 }
357 }
358
359 pub fn cache_instrument(&self, instrument: InstrumentAny) {
363 self.instruments_cache
364 .insert(instrument.symbol().inner(), instrument.clone());
365
366 if let Ok(cmd_tx) = self.cmd_tx.try_read()
369 && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
370 {
371 log::debug!("Failed to send instrument update to handler: {e}");
372 }
373 }
374
375 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
379 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
380 }
381
382 pub fn vip_level(&self) -> OKXVipLevel {
384 let level = self.vip_level.load(Ordering::Relaxed);
385 OKXVipLevel::from(level)
386 }
387
388 pub async fn connect(&mut self) -> anyhow::Result<()> {
398 let (message_handler, raw_rx) = channel_message_handler();
399
400 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
403 });
405
406 let config = WebSocketConfig {
407 url: self.url.clone(),
408 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
409 heartbeat: self.heartbeat,
410 heartbeat_msg: Some(TEXT_PING.to_string()),
411 message_handler: Some(message_handler),
412 ping_handler: Some(ping_handler),
413 reconnect_timeout_ms: Some(5_000),
414 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, reconnect_max_attempts: None,
419 };
420
421 let keyed_quotas = vec![
423 (
424 OKX_RATE_LIMIT_KEY_SUBSCRIPTION.to_string(),
425 *OKX_WS_SUBSCRIPTION_QUOTA,
426 ),
427 (OKX_RATE_LIMIT_KEY_ORDER.to_string(), *OKX_WS_ORDER_QUOTA),
428 (OKX_RATE_LIMIT_KEY_CANCEL.to_string(), *OKX_WS_ORDER_QUOTA),
429 (OKX_RATE_LIMIT_KEY_AMEND.to_string(), *OKX_WS_ORDER_QUOTA),
430 ];
431
432 let client = WebSocketClient::connect(
433 config,
434 None, keyed_quotas,
436 Some(*OKX_WS_CONNECTION_QUOTA), )
438 .await?;
439
440 self.connection_mode.store(client.connection_mode_atomic());
442
443 let account_id = self.account_id;
444 let (msg_tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
445
446 self.out_rx = Some(Arc::new(rx));
447
448 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
450 *self.cmd_tx.write().await = cmd_tx.clone();
451
452 if !self.instruments_cache.is_empty() {
454 let cached_instruments: Vec<InstrumentAny> = self
455 .instruments_cache
456 .iter()
457 .map(|entry| entry.value().clone())
458 .collect();
459 if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(cached_instruments)) {
460 tracing::error!("Failed to replay instruments to handler: {e}");
461 }
462 }
463
464 let signal = self.signal.clone();
465 let active_client_orders = self.active_client_orders.clone();
466 let emitted_order_accepted = self.emitted_order_accepted.clone();
467 let auth_tracker = self.auth_tracker.clone();
468 let subscriptions_state = self.subscriptions_state.clone();
469 let client_id_aliases = self.client_id_aliases.clone();
470
471 let stream_handle = get_runtime().spawn({
472 let auth_tracker = auth_tracker.clone();
473 let signal = signal.clone();
474 let credential = self.credential.clone();
475 let cmd_tx_for_reconnect = cmd_tx.clone();
476 let subscriptions_bare = self.subscriptions_bare.clone();
477 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
478 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
479 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
480 let mut has_reconnected = false;
481
482 async move {
483 let mut handler = OKXWsFeedHandler::new(
484 account_id,
485 signal.clone(),
486 cmd_rx,
487 raw_rx,
488 msg_tx,
489 active_client_orders,
490 client_id_aliases,
491 emitted_order_accepted,
492 auth_tracker.clone(),
493 subscriptions_state.clone(),
494 );
495
496 let resubscribe_all = || {
498 for entry in subscriptions_inst_id.iter() {
499 let (channel, inst_ids) = entry.pair();
500 for inst_id in inst_ids.iter() {
501 let arg = OKXSubscriptionArg {
502 channel: channel.clone(),
503 inst_type: None,
504 inst_family: None,
505 inst_id: Some(*inst_id),
506 };
507 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
508 tracing::error!(error = %e, "Failed to send resubscribe command");
509 }
510 }
511 }
512
513 for entry in subscriptions_bare.iter() {
514 let channel = entry.key();
515 let arg = OKXSubscriptionArg {
516 channel: channel.clone(),
517 inst_type: None,
518 inst_family: None,
519 inst_id: None,
520 };
521 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
522 tracing::error!(error = %e, "Failed to send resubscribe command");
523 }
524 }
525
526 for entry in subscriptions_inst_type.iter() {
527 let (channel, inst_types) = entry.pair();
528 for inst_type in inst_types.iter() {
529 let arg = OKXSubscriptionArg {
530 channel: channel.clone(),
531 inst_type: Some(*inst_type),
532 inst_family: None,
533 inst_id: None,
534 };
535 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
536 tracing::error!(error = %e, "Failed to send resubscribe command");
537 }
538 }
539 }
540
541 for entry in subscriptions_inst_family.iter() {
542 let (channel, inst_families) = entry.pair();
543 for inst_family in inst_families.iter() {
544 let arg = OKXSubscriptionArg {
545 channel: channel.clone(),
546 inst_type: None,
547 inst_family: Some(*inst_family),
548 inst_id: None,
549 };
550 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
551 tracing::error!(error = %e, "Failed to send resubscribe command");
552 }
553 }
554 }
555 };
556
557 loop {
559 match handler.next().await {
560 Some(NautilusWsMessage::Reconnected) => {
561 if signal.load(Ordering::Relaxed) {
562 continue;
563 }
564
565 has_reconnected = true;
566
567 let confirmed_topics_vec: Vec<String> = {
569 let confirmed = subscriptions_state.confirmed();
570 let mut topics = Vec::new();
571 for entry in confirmed.iter() {
572 let channel = entry.key();
573 for symbol in entry.value().iter() {
574 if symbol.as_str() == "#" {
575 topics.push(channel.to_string());
576 } else {
577 topics.push(format!("{}{}{}", channel, OKX_WS_TOPIC_DELIMITER, symbol));
578 }
579 }
580 }
581 topics
582 };
583
584 if !confirmed_topics_vec.is_empty() {
585 tracing::debug!(count = confirmed_topics_vec.len(), "Marking confirmed subscriptions as pending for replay");
586 for topic in confirmed_topics_vec {
587 subscriptions_state.mark_failure(&topic);
588 }
589 }
590
591 if let Some(cred) = &credential {
592 tracing::debug!("Re-authenticating after reconnection");
593 let timestamp = std::time::SystemTime::now()
594 .duration_since(std::time::SystemTime::UNIX_EPOCH)
595 .expect("System time should be after UNIX epoch")
596 .as_secs()
597 .to_string();
598 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
599
600 let auth_message = super::messages::OKXAuthentication {
601 op: "login",
602 args: vec![super::messages::OKXAuthenticationArg {
603 api_key: cred.api_key.to_string(),
604 passphrase: cred.api_passphrase.clone(),
605 timestamp,
606 sign: signature,
607 }],
608 };
609
610 if let Ok(payload) = serde_json::to_string(&auth_message) {
611 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Authenticate { payload }) {
612 tracing::error!(error = %e, "Failed to send reconnection auth command");
613 }
614 } else {
615 tracing::error!("Failed to serialize reconnection auth message");
616 }
617 }
618
619 if credential.is_none() {
622 tracing::debug!("No authentication required, resubscribing immediately");
623 resubscribe_all();
624 }
625
626 continue;
631 }
632 Some(NautilusWsMessage::Authenticated) => {
633 if has_reconnected {
634 resubscribe_all();
635 }
636
637 continue;
642 }
643 Some(msg) => {
644 if handler.send(msg).is_err() {
645 tracing::error!(
646 "Failed to send message through channel: receiver dropped",
647 );
648 break;
649 }
650 }
651 None => {
652 if handler.is_stopped() {
653 tracing::debug!(
654 "Stop signal received, ending message processing",
655 );
656 break;
657 }
658 tracing::warn!("WebSocket stream ended unexpectedly");
659 break;
660 }
661 }
662 }
663
664 tracing::debug!("Handler task exiting");
665 }
666 });
667
668 self.task_handle = Some(Arc::new(stream_handle));
669
670 self.cmd_tx
671 .read()
672 .await
673 .send(HandlerCommand::SetClient(client))
674 .map_err(|e| {
675 OKXWsError::ClientError(format!("Failed to send WebSocket client to handler: {e}"))
676 })?;
677 tracing::debug!("Sent WebSocket client to handler");
678
679 if self.credential.is_some()
680 && let Err(e) = self.authenticate().await
681 {
682 anyhow::bail!("Authentication failed: {e}");
683 }
684
685 Ok(())
686 }
687
688 async fn authenticate(&self) -> Result<(), Error> {
690 let credential = self.credential.as_ref().ok_or_else(|| {
691 Error::Io(std::io::Error::other(
692 "API credentials not available to authenticate",
693 ))
694 })?;
695
696 let rx = self.auth_tracker.begin();
697
698 let timestamp = SystemTime::now()
699 .duration_since(SystemTime::UNIX_EPOCH)
700 .expect("System time should be after UNIX epoch")
701 .as_secs()
702 .to_string();
703 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
704
705 let auth_message = OKXAuthentication {
706 op: "login",
707 args: vec![OKXAuthenticationArg {
708 api_key: credential.api_key.to_string(),
709 passphrase: credential.api_passphrase.clone(),
710 timestamp,
711 sign: signature,
712 }],
713 };
714
715 let payload = serde_json::to_string(&auth_message).map_err(|e| {
716 Error::Io(std::io::Error::other(format!(
717 "Failed to serialize auth message: {e}"
718 )))
719 })?;
720
721 self.cmd_tx
722 .read()
723 .await
724 .send(HandlerCommand::Authenticate { payload })
725 .map_err(|e| {
726 Error::Io(std::io::Error::other(format!(
727 "Failed to send authenticate command: {e}"
728 )))
729 })?;
730
731 match self
732 .auth_tracker
733 .wait_for_result::<OKXWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
734 .await
735 {
736 Ok(()) => {
737 tracing::info!("WebSocket authenticated");
738 Ok(())
739 }
740 Err(e) => {
741 tracing::error!(error = %e, "WebSocket authentication failed");
742 Err(Error::Io(std::io::Error::other(e.to_string())))
743 }
744 }
745 }
746
747 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
755 let rx = self
756 .out_rx
757 .take()
758 .expect("Data stream receiver already taken or not connected");
759 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
760 async_stream::stream! {
761 while let Some(data) = rx.recv().await {
762 yield data;
763 }
764 }
765 }
766
767 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
773 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
774
775 tokio::time::timeout(timeout, async {
776 while !self.is_active() {
777 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
778 }
779 })
780 .await
781 .map_err(|_| {
782 OKXWsError::ClientError(format!(
783 "WebSocket connection timeout after {timeout_secs} seconds"
784 ))
785 })?;
786
787 Ok(())
788 }
789
790 pub async fn close(&mut self) -> Result<(), Error> {
797 log::debug!("Starting close process");
798
799 self.signal.store(true, Ordering::Relaxed);
800
801 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
802 log::warn!("Failed to send disconnect command to handler: {e}");
803 } else {
804 log::debug!("Sent disconnect command to handler");
805 }
806
807 {
809 if false {
810 log::debug!("No active connection to disconnect");
811 }
812 }
813
814 if let Some(stream_handle) = self.task_handle.take() {
816 match Arc::try_unwrap(stream_handle) {
817 Ok(handle) => {
818 log::debug!("Waiting for stream handle to complete");
819 match tokio::time::timeout(Duration::from_secs(2), handle).await {
820 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
821 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
822 Err(_) => {
823 log::warn!(
824 "Timeout waiting for stream handle, task may still be running"
825 );
826 }
828 }
829 }
830 Err(arc_handle) => {
831 log::debug!(
832 "Cannot take ownership of stream handle - other references exist, aborting task"
833 );
834 arc_handle.abort();
835 }
836 }
837 } else {
838 log::debug!("No stream handle to await");
839 }
840
841 log::debug!("Close process completed");
842
843 Ok(())
844 }
845
846 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
848 let symbol = instrument_id.symbol.inner();
849 let mut channels = Vec::new();
850
851 for entry in self.subscriptions_inst_id.iter() {
852 let (channel, instruments) = entry.pair();
853 if instruments.contains(&symbol) {
854 channels.push(channel.clone());
855 }
856 }
857
858 channels
859 }
860
861 fn generate_unique_request_id(&self) -> String {
862 self.request_id_counter
863 .fetch_add(1, Ordering::SeqCst)
864 .to_string()
865 }
866
867 #[allow(
868 clippy::result_large_err,
869 reason = "OKXWsError contains large tungstenite::Error variant"
870 )]
871 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
872 for arg in &args {
873 let topic = topic_from_subscription_arg(arg);
874 self.subscriptions_state.mark_subscribe(&topic);
875
876 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
878 self.subscriptions_bare.insert(arg.channel.clone(), true);
880 } else {
881 if let Some(inst_type) = &arg.inst_type {
883 self.subscriptions_inst_type
884 .entry(arg.channel.clone())
885 .or_default()
886 .insert(*inst_type);
887 }
888
889 if let Some(inst_family) = &arg.inst_family {
891 self.subscriptions_inst_family
892 .entry(arg.channel.clone())
893 .or_default()
894 .insert(*inst_family);
895 }
896
897 if let Some(inst_id) = &arg.inst_id {
899 self.subscriptions_inst_id
900 .entry(arg.channel.clone())
901 .or_default()
902 .insert(*inst_id);
903 }
904 }
905 }
906
907 self.cmd_tx
908 .read()
909 .await
910 .send(HandlerCommand::Subscribe { args })
911 .map_err(|e| OKXWsError::ClientError(format!("Failed to send subscribe command: {e}")))
912 }
913
914 #[allow(clippy::collapsible_if)]
915 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
916 for arg in &args {
917 let topic = topic_from_subscription_arg(arg);
918 self.subscriptions_state.mark_unsubscribe(&topic);
919
920 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
922 self.subscriptions_bare.remove(&arg.channel);
924 } else {
925 if let Some(inst_type) = &arg.inst_type {
927 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
928 entry.remove(inst_type);
929 if entry.is_empty() {
930 drop(entry);
931 self.subscriptions_inst_type.remove(&arg.channel);
932 }
933 }
934 }
935
936 if let Some(inst_family) = &arg.inst_family {
938 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
939 entry.remove(inst_family);
940 if entry.is_empty() {
941 drop(entry);
942 self.subscriptions_inst_family.remove(&arg.channel);
943 }
944 }
945 }
946
947 if let Some(inst_id) = &arg.inst_id {
949 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
950 entry.remove(inst_id);
951 if entry.is_empty() {
952 drop(entry);
953 self.subscriptions_inst_id.remove(&arg.channel);
954 }
955 }
956 }
957 }
958 }
959
960 self.cmd_tx
961 .read()
962 .await
963 .send(HandlerCommand::Unsubscribe { args })
964 .map_err(|e| {
965 OKXWsError::ClientError(format!("Failed to send unsubscribe command: {e}"))
966 })
967 }
968
969 pub async fn unsubscribe_all(&self) -> Result<(), OKXWsError> {
978 let mut all_args = Vec::new();
979
980 for entry in self.subscriptions_inst_type.iter() {
981 let (channel, inst_types) = entry.pair();
982 for inst_type in inst_types.iter() {
983 all_args.push(OKXSubscriptionArg {
984 channel: channel.clone(),
985 inst_type: Some(*inst_type),
986 inst_family: None,
987 inst_id: None,
988 });
989 }
990 }
991
992 for entry in self.subscriptions_inst_family.iter() {
993 let (channel, inst_families) = entry.pair();
994 for inst_family in inst_families.iter() {
995 all_args.push(OKXSubscriptionArg {
996 channel: channel.clone(),
997 inst_type: None,
998 inst_family: Some(*inst_family),
999 inst_id: None,
1000 });
1001 }
1002 }
1003
1004 for entry in self.subscriptions_inst_id.iter() {
1005 let (channel, inst_ids) = entry.pair();
1006 for inst_id in inst_ids.iter() {
1007 all_args.push(OKXSubscriptionArg {
1008 channel: channel.clone(),
1009 inst_type: None,
1010 inst_family: None,
1011 inst_id: Some(*inst_id),
1012 });
1013 }
1014 }
1015
1016 for entry in self.subscriptions_bare.iter() {
1017 let channel = entry.key();
1018 all_args.push(OKXSubscriptionArg {
1019 channel: channel.clone(),
1020 inst_type: None,
1021 inst_family: None,
1022 inst_id: None,
1023 });
1024 }
1025
1026 if all_args.is_empty() {
1027 tracing::debug!("No active subscriptions to unsubscribe from");
1028 return Ok(());
1029 }
1030
1031 tracing::debug!("Batched unsubscribe from {} channels", all_args.len());
1032
1033 const BATCH_SIZE: usize = 256;
1034
1035 for chunk in all_args.chunks(BATCH_SIZE) {
1036 self.unsubscribe(chunk.to_vec()).await?;
1037 }
1038
1039 Ok(())
1040 }
1041
1042 pub async fn subscribe_instruments(
1054 &self,
1055 instrument_type: OKXInstrumentType,
1056 ) -> Result<(), OKXWsError> {
1057 let arg = OKXSubscriptionArg {
1058 channel: OKXWsChannel::Instruments,
1059 inst_type: Some(instrument_type),
1060 inst_family: None,
1061 inst_id: None,
1062 };
1063 self.subscribe(vec![arg]).await
1064 }
1065
1066 pub async fn subscribe_instrument(
1079 &self,
1080 instrument_id: InstrumentId,
1081 ) -> Result<(), OKXWsError> {
1082 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1083
1084 let already_subscribed = self
1085 .subscriptions_inst_type
1086 .get(&OKXWsChannel::Instruments)
1087 .is_some_and(|types| types.contains(&inst_type));
1088
1089 if already_subscribed {
1090 tracing::debug!(
1091 "Already subscribed to instrument type {inst_type:?} for {instrument_id}"
1092 );
1093 return Ok(());
1094 }
1095
1096 tracing::info!("Subscribing to instrument type {inst_type:?} for {instrument_id}");
1097 self.subscribe_instruments(inst_type).await
1098 }
1099
1100 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1109 self.subscribe_book_with_depth(instrument_id, 0).await
1110 }
1111
1112 pub(crate) async fn subscribe_books_channel(
1114 &self,
1115 instrument_id: InstrumentId,
1116 ) -> Result<(), OKXWsError> {
1117 let arg = OKXSubscriptionArg {
1118 channel: OKXWsChannel::Books,
1119 inst_type: None,
1120 inst_family: None,
1121 inst_id: Some(instrument_id.symbol.inner()),
1122 };
1123 self.subscribe(vec![arg]).await
1124 }
1125
1126 pub async fn subscribe_book_depth5(
1138 &self,
1139 instrument_id: InstrumentId,
1140 ) -> Result<(), OKXWsError> {
1141 let arg = OKXSubscriptionArg {
1142 channel: OKXWsChannel::Books5,
1143 inst_type: None,
1144 inst_family: None,
1145 inst_id: Some(instrument_id.symbol.inner()),
1146 };
1147 self.subscribe(vec![arg]).await
1148 }
1149
1150 pub async fn subscribe_book50_l2_tbt(
1162 &self,
1163 instrument_id: InstrumentId,
1164 ) -> Result<(), OKXWsError> {
1165 let arg = OKXSubscriptionArg {
1166 channel: OKXWsChannel::Books50Tbt,
1167 inst_type: None,
1168 inst_family: None,
1169 inst_id: Some(instrument_id.symbol.inner()),
1170 };
1171 self.subscribe(vec![arg]).await
1172 }
1173
1174 pub async fn subscribe_book_l2_tbt(
1186 &self,
1187 instrument_id: InstrumentId,
1188 ) -> Result<(), OKXWsError> {
1189 let arg = OKXSubscriptionArg {
1190 channel: OKXWsChannel::BooksTbt,
1191 inst_type: None,
1192 inst_family: None,
1193 inst_id: Some(instrument_id.symbol.inner()),
1194 };
1195 self.subscribe(vec![arg]).await
1196 }
1197
1198 pub async fn subscribe_book_with_depth(
1212 &self,
1213 instrument_id: InstrumentId,
1214 depth: u16,
1215 ) -> anyhow::Result<()> {
1216 let vip = self.vip_level();
1217
1218 match depth {
1219 50 => {
1220 if vip < OKXVipLevel::Vip4 {
1221 anyhow::bail!(
1222 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1223 );
1224 }
1225 self.subscribe_book50_l2_tbt(instrument_id)
1226 .await
1227 .map_err(|e| anyhow::anyhow!(e))
1228 }
1229 0 | 400 => {
1230 if vip >= OKXVipLevel::Vip5 {
1231 self.subscribe_book_l2_tbt(instrument_id)
1232 .await
1233 .map_err(|e| anyhow::anyhow!(e))
1234 } else {
1235 self.subscribe_books_channel(instrument_id)
1236 .await
1237 .map_err(|e| anyhow::anyhow!(e))
1238 }
1239 }
1240 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1241 }
1242 }
1243
1244 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1257 let arg = OKXSubscriptionArg {
1258 channel: OKXWsChannel::BboTbt,
1259 inst_type: None,
1260 inst_family: None,
1261 inst_id: Some(instrument_id.symbol.inner()),
1262 };
1263 self.subscribe(vec![arg]).await
1264 }
1265
1266 pub async fn subscribe_trades(
1280 &self,
1281 instrument_id: InstrumentId,
1282 aggregated: bool,
1283 ) -> Result<(), OKXWsError> {
1284 let channel = if aggregated {
1285 OKXWsChannel::TradesAll
1286 } else {
1287 OKXWsChannel::Trades
1288 };
1289
1290 let arg = OKXSubscriptionArg {
1291 channel,
1292 inst_type: None,
1293 inst_family: None,
1294 inst_id: Some(instrument_id.symbol.inner()),
1295 };
1296 self.subscribe(vec![arg]).await
1297 }
1298
1299 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1311 let arg = OKXSubscriptionArg {
1312 channel: OKXWsChannel::Tickers,
1313 inst_type: None,
1314 inst_family: None,
1315 inst_id: Some(instrument_id.symbol.inner()),
1316 };
1317 self.subscribe(vec![arg]).await
1318 }
1319
1320 pub async fn subscribe_mark_prices(
1332 &self,
1333 instrument_id: InstrumentId,
1334 ) -> Result<(), OKXWsError> {
1335 let arg = OKXSubscriptionArg {
1336 channel: OKXWsChannel::MarkPrice,
1337 inst_type: None,
1338 inst_family: None,
1339 inst_id: Some(instrument_id.symbol.inner()),
1340 };
1341 self.subscribe(vec![arg]).await
1342 }
1343
1344 pub async fn subscribe_index_prices(
1356 &self,
1357 instrument_id: InstrumentId,
1358 ) -> Result<(), OKXWsError> {
1359 let arg = OKXSubscriptionArg {
1360 channel: OKXWsChannel::IndexTickers,
1361 inst_type: None,
1362 inst_family: None,
1363 inst_id: Some(instrument_id.symbol.inner()),
1364 };
1365 self.subscribe(vec![arg]).await
1366 }
1367
1368 pub async fn subscribe_funding_rates(
1380 &self,
1381 instrument_id: InstrumentId,
1382 ) -> Result<(), OKXWsError> {
1383 let arg = OKXSubscriptionArg {
1384 channel: OKXWsChannel::FundingRate,
1385 inst_type: None,
1386 inst_family: None,
1387 inst_id: Some(instrument_id.symbol.inner()),
1388 };
1389 self.subscribe(vec![arg]).await
1390 }
1391
1392 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1404 let channel = bar_spec_as_okx_channel(bar_type.spec())
1406 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1407
1408 let arg = OKXSubscriptionArg {
1409 channel,
1410 inst_type: None,
1411 inst_family: None,
1412 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1413 };
1414 self.subscribe(vec![arg]).await
1415 }
1416
1417 pub async fn unsubscribe_instruments(
1423 &self,
1424 instrument_type: OKXInstrumentType,
1425 ) -> Result<(), OKXWsError> {
1426 let arg = OKXSubscriptionArg {
1427 channel: OKXWsChannel::Instruments,
1428 inst_type: Some(instrument_type),
1429 inst_family: None,
1430 inst_id: None,
1431 };
1432 self.unsubscribe(vec![arg]).await
1433 }
1434
1435 pub async fn unsubscribe_instrument(
1441 &self,
1442 instrument_id: InstrumentId,
1443 ) -> Result<(), OKXWsError> {
1444 let arg = OKXSubscriptionArg {
1445 channel: OKXWsChannel::Instruments,
1446 inst_type: None,
1447 inst_family: None,
1448 inst_id: Some(instrument_id.symbol.inner()),
1449 };
1450 self.unsubscribe(vec![arg]).await
1451 }
1452
1453 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1459 let arg = OKXSubscriptionArg {
1460 channel: OKXWsChannel::Books,
1461 inst_type: None,
1462 inst_family: None,
1463 inst_id: Some(instrument_id.symbol.inner()),
1464 };
1465 self.unsubscribe(vec![arg]).await
1466 }
1467
1468 pub async fn unsubscribe_book_depth5(
1474 &self,
1475 instrument_id: InstrumentId,
1476 ) -> Result<(), OKXWsError> {
1477 let arg = OKXSubscriptionArg {
1478 channel: OKXWsChannel::Books5,
1479 inst_type: None,
1480 inst_family: None,
1481 inst_id: Some(instrument_id.symbol.inner()),
1482 };
1483 self.unsubscribe(vec![arg]).await
1484 }
1485
1486 pub async fn unsubscribe_book50_l2_tbt(
1492 &self,
1493 instrument_id: InstrumentId,
1494 ) -> Result<(), OKXWsError> {
1495 let arg = OKXSubscriptionArg {
1496 channel: OKXWsChannel::Books50Tbt,
1497 inst_type: None,
1498 inst_family: None,
1499 inst_id: Some(instrument_id.symbol.inner()),
1500 };
1501 self.unsubscribe(vec![arg]).await
1502 }
1503
1504 pub async fn unsubscribe_book_l2_tbt(
1510 &self,
1511 instrument_id: InstrumentId,
1512 ) -> Result<(), OKXWsError> {
1513 let arg = OKXSubscriptionArg {
1514 channel: OKXWsChannel::BooksTbt,
1515 inst_type: None,
1516 inst_family: None,
1517 inst_id: Some(instrument_id.symbol.inner()),
1518 };
1519 self.unsubscribe(vec![arg]).await
1520 }
1521
1522 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1528 let arg = OKXSubscriptionArg {
1529 channel: OKXWsChannel::BboTbt,
1530 inst_type: None,
1531 inst_family: None,
1532 inst_id: Some(instrument_id.symbol.inner()),
1533 };
1534 self.unsubscribe(vec![arg]).await
1535 }
1536
1537 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1543 let arg = OKXSubscriptionArg {
1544 channel: OKXWsChannel::Tickers,
1545 inst_type: None,
1546 inst_family: None,
1547 inst_id: Some(instrument_id.symbol.inner()),
1548 };
1549 self.unsubscribe(vec![arg]).await
1550 }
1551
1552 pub async fn unsubscribe_mark_prices(
1558 &self,
1559 instrument_id: InstrumentId,
1560 ) -> Result<(), OKXWsError> {
1561 let arg = OKXSubscriptionArg {
1562 channel: OKXWsChannel::MarkPrice,
1563 inst_type: None,
1564 inst_family: None,
1565 inst_id: Some(instrument_id.symbol.inner()),
1566 };
1567 self.unsubscribe(vec![arg]).await
1568 }
1569
1570 pub async fn unsubscribe_index_prices(
1576 &self,
1577 instrument_id: InstrumentId,
1578 ) -> Result<(), OKXWsError> {
1579 let arg = OKXSubscriptionArg {
1580 channel: OKXWsChannel::IndexTickers,
1581 inst_type: None,
1582 inst_family: None,
1583 inst_id: Some(instrument_id.symbol.inner()),
1584 };
1585 self.unsubscribe(vec![arg]).await
1586 }
1587
1588 pub async fn unsubscribe_funding_rates(
1594 &self,
1595 instrument_id: InstrumentId,
1596 ) -> Result<(), OKXWsError> {
1597 let arg = OKXSubscriptionArg {
1598 channel: OKXWsChannel::FundingRate,
1599 inst_type: None,
1600 inst_family: None,
1601 inst_id: Some(instrument_id.symbol.inner()),
1602 };
1603 self.unsubscribe(vec![arg]).await
1604 }
1605
1606 pub async fn unsubscribe_trades(
1612 &self,
1613 instrument_id: InstrumentId,
1614 aggregated: bool,
1615 ) -> Result<(), OKXWsError> {
1616 let channel = if aggregated {
1617 OKXWsChannel::TradesAll
1618 } else {
1619 OKXWsChannel::Trades
1620 };
1621
1622 let arg = OKXSubscriptionArg {
1623 channel,
1624 inst_type: None,
1625 inst_family: None,
1626 inst_id: Some(instrument_id.symbol.inner()),
1627 };
1628 self.unsubscribe(vec![arg]).await
1629 }
1630
1631 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1637 let channel = bar_spec_as_okx_channel(bar_type.spec())
1639 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1640
1641 let arg = OKXSubscriptionArg {
1642 channel,
1643 inst_type: None,
1644 inst_family: None,
1645 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1646 };
1647 self.unsubscribe(vec![arg]).await
1648 }
1649
1650 pub async fn subscribe_orders(
1656 &self,
1657 instrument_type: OKXInstrumentType,
1658 ) -> Result<(), OKXWsError> {
1659 let arg = OKXSubscriptionArg {
1660 channel: OKXWsChannel::Orders,
1661 inst_type: Some(instrument_type),
1662 inst_family: None,
1663 inst_id: None,
1664 };
1665 self.subscribe(vec![arg]).await
1666 }
1667
1668 pub async fn unsubscribe_orders(
1674 &self,
1675 instrument_type: OKXInstrumentType,
1676 ) -> Result<(), OKXWsError> {
1677 let arg = OKXSubscriptionArg {
1678 channel: OKXWsChannel::Orders,
1679 inst_type: Some(instrument_type),
1680 inst_family: None,
1681 inst_id: None,
1682 };
1683 self.unsubscribe(vec![arg]).await
1684 }
1685
1686 pub async fn subscribe_orders_algo(
1692 &self,
1693 instrument_type: OKXInstrumentType,
1694 ) -> Result<(), OKXWsError> {
1695 let arg = OKXSubscriptionArg {
1696 channel: OKXWsChannel::OrdersAlgo,
1697 inst_type: Some(instrument_type),
1698 inst_family: None,
1699 inst_id: None,
1700 };
1701 self.subscribe(vec![arg]).await
1702 }
1703
1704 pub async fn unsubscribe_orders_algo(
1710 &self,
1711 instrument_type: OKXInstrumentType,
1712 ) -> Result<(), OKXWsError> {
1713 let arg = OKXSubscriptionArg {
1714 channel: OKXWsChannel::OrdersAlgo,
1715 inst_type: Some(instrument_type),
1716 inst_family: None,
1717 inst_id: None,
1718 };
1719 self.unsubscribe(vec![arg]).await
1720 }
1721
1722 pub async fn subscribe_fills(
1728 &self,
1729 instrument_type: OKXInstrumentType,
1730 ) -> Result<(), OKXWsError> {
1731 let arg = OKXSubscriptionArg {
1732 channel: OKXWsChannel::Fills,
1733 inst_type: Some(instrument_type),
1734 inst_family: None,
1735 inst_id: None,
1736 };
1737 self.subscribe(vec![arg]).await
1738 }
1739
1740 pub async fn unsubscribe_fills(
1746 &self,
1747 instrument_type: OKXInstrumentType,
1748 ) -> Result<(), OKXWsError> {
1749 let arg = OKXSubscriptionArg {
1750 channel: OKXWsChannel::Fills,
1751 inst_type: Some(instrument_type),
1752 inst_family: None,
1753 inst_id: None,
1754 };
1755 self.unsubscribe(vec![arg]).await
1756 }
1757
1758 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1764 let arg = OKXSubscriptionArg {
1765 channel: OKXWsChannel::Account,
1766 inst_type: None,
1767 inst_family: None,
1768 inst_id: None,
1769 };
1770 self.subscribe(vec![arg]).await
1771 }
1772
1773 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1779 let arg = OKXSubscriptionArg {
1780 channel: OKXWsChannel::Account,
1781 inst_type: None,
1782 inst_family: None,
1783 inst_id: None,
1784 };
1785 self.unsubscribe(vec![arg]).await
1786 }
1787
1788 pub async fn subscribe_positions(
1798 &self,
1799 inst_type: OKXInstrumentType,
1800 ) -> Result<(), OKXWsError> {
1801 let arg = OKXSubscriptionArg {
1802 channel: OKXWsChannel::Positions,
1803 inst_type: Some(inst_type),
1804 inst_family: None,
1805 inst_id: None,
1806 };
1807 self.subscribe(vec![arg]).await
1808 }
1809
1810 pub async fn unsubscribe_positions(
1816 &self,
1817 inst_type: OKXInstrumentType,
1818 ) -> Result<(), OKXWsError> {
1819 let arg = OKXSubscriptionArg {
1820 channel: OKXWsChannel::Positions,
1821 inst_type: Some(inst_type),
1822 inst_family: None,
1823 inst_id: None,
1824 };
1825 self.unsubscribe(vec![arg]).await
1826 }
1827
1828 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1834 let request_id = self.generate_unique_request_id();
1835 let cmd = HandlerCommand::BatchPlaceOrders { args, request_id };
1836
1837 self.send_cmd(cmd).await
1838 }
1839
1840 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1846 let request_id = self.generate_unique_request_id();
1847 let cmd = HandlerCommand::BatchCancelOrders { args, request_id };
1848
1849 self.send_cmd(cmd).await
1850 }
1851
1852 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1858 let request_id = self.generate_unique_request_id();
1859 let cmd = HandlerCommand::BatchAmendOrders { args, request_id };
1860
1861 self.send_cmd(cmd).await
1862 }
1863
1864 #[allow(clippy::too_many_arguments)]
1876 pub async fn submit_order(
1877 &self,
1878 trader_id: TraderId,
1879 strategy_id: StrategyId,
1880 instrument_id: InstrumentId,
1881 td_mode: OKXTradeMode,
1882 client_order_id: ClientOrderId,
1883 order_side: OrderSide,
1884 order_type: OrderType,
1885 quantity: Quantity,
1886 time_in_force: Option<TimeInForce>,
1887 price: Option<Price>,
1888 trigger_price: Option<Price>,
1889 post_only: Option<bool>,
1890 reduce_only: Option<bool>,
1891 quote_quantity: Option<bool>,
1892 position_side: Option<PositionSide>,
1893 ) -> Result<(), OKXWsError> {
1894 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1895 return Err(OKXWsError::ClientError(format!(
1896 "Unsupported order type: {order_type:?}",
1897 )));
1898 }
1899
1900 if let Some(tif) = time_in_force
1901 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1902 {
1903 return Err(OKXWsError::ClientError(format!(
1904 "Unsupported time in force: {tif:?}",
1905 )));
1906 }
1907
1908 let mut builder = WsPostOrderParamsBuilder::default();
1909
1910 builder.inst_id(instrument_id.symbol.as_str());
1911 builder.td_mode(td_mode);
1912 builder.cl_ord_id(client_order_id.as_str());
1913
1914 let instrument = self
1915 .instruments_cache
1916 .get(&instrument_id.symbol.inner())
1917 .ok_or_else(|| {
1918 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1919 })?;
1920
1921 let instrument_type =
1922 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1923 let quote_currency = instrument.quote_currency();
1924
1925 match instrument_type {
1926 OKXInstrumentType::Spot => {
1927 builder.ccy(quote_currency.to_string());
1929 }
1930 OKXInstrumentType::Margin => {
1931 builder.ccy(quote_currency.to_string());
1932
1933 if let Some(ro) = reduce_only
1934 && ro
1935 {
1936 builder.reduce_only(ro);
1937 }
1938 }
1939 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1940 builder.ccy(quote_currency.to_string());
1942
1943 if position_side.is_none() {
1946 builder.pos_side(OKXPositionSide::Net);
1947 }
1948 }
1949 _ => {
1950 builder.ccy(quote_currency.to_string());
1951
1952 if position_side.is_none() {
1954 builder.pos_side(OKXPositionSide::Net);
1955 }
1956
1957 if let Some(ro) = reduce_only
1958 && ro
1959 {
1960 builder.reduce_only(ro);
1961 }
1962 }
1963 };
1964
1965 if instrument_type == OKXInstrumentType::Spot
1972 && order_type == OrderType::Market
1973 && td_mode == OKXTradeMode::Cash
1974 {
1975 match quote_quantity {
1976 Some(true) => {
1977 builder.tgt_ccy(OKXTargetCurrency::QuoteCcy);
1979 }
1980 Some(false) => {
1981 if order_side == OrderSide::Buy {
1982 builder.tgt_ccy(OKXTargetCurrency::BaseCcy);
1984 }
1985 }
1987 None => {
1988 }
1990 }
1991 }
1992
1993 builder.side(order_side);
1994
1995 if let Some(pos_side) = position_side {
1996 builder.pos_side(pos_side);
1997 };
1998
1999 let (okx_ord_type, price) = if post_only.unwrap_or(false) {
2002 (OKXOrderType::PostOnly, price)
2003 } else if let Some(tif) = time_in_force {
2004 match (order_type, tif) {
2005 (OrderType::Market, TimeInForce::Fok) => {
2006 return Err(OKXWsError::ClientError(
2007 "Market orders with FOK time-in-force are not supported by OKX. Use Limit order with FOK instead.".to_string()
2008 ));
2009 }
2010 (OrderType::Market, TimeInForce::Ioc) => (OKXOrderType::OptimalLimitIoc, price),
2011 (OrderType::Limit, TimeInForce::Fok) => (OKXOrderType::Fok, price),
2012 (OrderType::Limit, TimeInForce::Ioc) => (OKXOrderType::Ioc, price),
2013 _ => (OKXOrderType::from(order_type), price),
2014 }
2015 } else {
2016 (OKXOrderType::from(order_type), price)
2017 };
2018
2019 log::debug!(
2020 "Order type mapping: order_type={:?}, time_in_force={:?}, post_only={:?} -> okx_ord_type={:?}",
2021 order_type,
2022 time_in_force,
2023 post_only,
2024 okx_ord_type
2025 );
2026
2027 builder.ord_type(okx_ord_type);
2028 builder.sz(quantity.to_string());
2029
2030 if let Some(tp) = trigger_price {
2031 builder.px(tp.to_string());
2032 } else if let Some(p) = price {
2033 builder.px(p.to_string());
2034 }
2035
2036 builder.tag(OKX_NAUTILUS_BROKER_ID);
2037
2038 let params = builder
2039 .build()
2040 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2041
2042 self.active_client_orders
2043 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2044
2045 let cmd = HandlerCommand::PlaceOrder {
2046 params,
2047 client_order_id,
2048 trader_id,
2049 strategy_id,
2050 instrument_id,
2051 };
2052
2053 self.send_cmd(cmd).await
2054 }
2055
2056 #[allow(clippy::too_many_arguments)]
2072 pub async fn modify_order(
2073 &self,
2074 trader_id: TraderId,
2075 strategy_id: StrategyId,
2076 instrument_id: InstrumentId,
2077 client_order_id: Option<ClientOrderId>,
2078 price: Option<Price>,
2079 quantity: Option<Quantity>,
2080 venue_order_id: Option<VenueOrderId>,
2081 ) -> Result<(), OKXWsError> {
2082 let mut builder = WsAmendOrderParamsBuilder::default();
2083
2084 builder.inst_id(instrument_id.symbol.as_str());
2085
2086 if let Some(venue_order_id) = venue_order_id {
2087 builder.ord_id(venue_order_id.as_str());
2088 }
2089
2090 if let Some(client_order_id) = client_order_id {
2091 builder.cl_ord_id(client_order_id.as_str());
2092 }
2093
2094 if let Some(price) = price {
2095 builder.new_px(price.to_string());
2096 }
2097
2098 if let Some(quantity) = quantity {
2099 builder.new_sz(quantity.to_string());
2100 }
2101
2102 let params = builder
2103 .build()
2104 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2105
2106 if let Some(client_order_id) = client_order_id {
2109 let cmd = HandlerCommand::AmendOrder {
2110 params,
2111 client_order_id,
2112 trader_id,
2113 strategy_id,
2114 instrument_id,
2115 venue_order_id,
2116 };
2117
2118 self.send_cmd(cmd).await
2119 } else {
2120 Err(OKXWsError::ClientError(
2122 "Cannot amend order without client_order_id".to_string(),
2123 ))
2124 }
2125 }
2126
2127 #[allow(clippy::too_many_arguments)]
2138 pub async fn cancel_order(
2139 &self,
2140 trader_id: TraderId,
2141 strategy_id: StrategyId,
2142 instrument_id: InstrumentId,
2143 client_order_id: Option<ClientOrderId>,
2144 venue_order_id: Option<VenueOrderId>,
2145 ) -> Result<(), OKXWsError> {
2146 let cmd = HandlerCommand::CancelOrder {
2147 client_order_id,
2148 venue_order_id,
2149 instrument_id,
2150 trader_id,
2151 strategy_id,
2152 };
2153
2154 self.send_cmd(cmd).await
2155 }
2156
2157 pub async fn mass_cancel_orders(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
2167 let cmd = HandlerCommand::MassCancel { instrument_id };
2168
2169 self.send_cmd(cmd).await
2170 }
2171
2172 #[allow(clippy::type_complexity)]
2179 #[allow(clippy::too_many_arguments)]
2180 pub async fn batch_submit_orders(
2181 &self,
2182 orders: Vec<(
2183 OKXInstrumentType,
2184 InstrumentId,
2185 OKXTradeMode,
2186 ClientOrderId,
2187 OrderSide,
2188 Option<PositionSide>,
2189 OrderType,
2190 Quantity,
2191 Option<Price>,
2192 Option<Price>,
2193 Option<bool>,
2194 Option<bool>,
2195 )>,
2196 ) -> Result<(), OKXWsError> {
2197 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2198 for (
2199 inst_type,
2200 inst_id,
2201 td_mode,
2202 cl_ord_id,
2203 ord_side,
2204 pos_side,
2205 ord_type,
2206 qty,
2207 pr,
2208 tp,
2209 post_only,
2210 reduce_only,
2211 ) in orders
2212 {
2213 let mut builder = WsPostOrderParamsBuilder::default();
2214 builder.inst_type(inst_type);
2215 builder.inst_id(inst_id.symbol.inner());
2216 builder.td_mode(td_mode);
2217 builder.cl_ord_id(cl_ord_id.as_str());
2218 builder.side(ord_side);
2219
2220 if let Some(ps) = pos_side {
2221 builder.pos_side(OKXPositionSide::from(ps));
2222 }
2223
2224 let okx_ord_type = if post_only.unwrap_or(false) {
2225 OKXOrderType::PostOnly
2226 } else {
2227 OKXOrderType::from(ord_type)
2228 };
2229
2230 builder.ord_type(okx_ord_type);
2231 builder.sz(qty.to_string());
2232
2233 if let Some(p) = pr {
2234 builder.px(p.to_string());
2235 } else if let Some(p) = tp {
2236 builder.px(p.to_string());
2237 }
2238
2239 if let Some(ro) = reduce_only {
2240 builder.reduce_only(ro);
2241 }
2242
2243 builder.tag(OKX_NAUTILUS_BROKER_ID);
2244
2245 let params = builder
2246 .build()
2247 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2248 let val =
2249 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2250 args.push(val);
2251 }
2252
2253 self.ws_batch_place_orders(args).await
2254 }
2255
2256 #[allow(clippy::type_complexity)]
2263 #[allow(clippy::too_many_arguments)]
2264 pub async fn batch_modify_orders(
2265 &self,
2266 orders: Vec<(
2267 OKXInstrumentType,
2268 InstrumentId,
2269 ClientOrderId,
2270 ClientOrderId,
2271 Option<Price>,
2272 Option<Quantity>,
2273 )>,
2274 ) -> Result<(), OKXWsError> {
2275 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2276 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2277 let mut builder = WsAmendOrderParamsBuilder::default();
2278 builder.inst_id(inst_id.symbol.inner());
2280 builder.cl_ord_id(cl_ord_id.as_str());
2281 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2282
2283 if let Some(p) = pr {
2284 builder.new_px(p.to_string());
2285 }
2286
2287 if let Some(q) = sz {
2288 builder.new_sz(q.to_string());
2289 }
2290
2291 let params = builder.build().map_err(|e| {
2292 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2293 })?;
2294 let val =
2295 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2296 args.push(val);
2297 }
2298
2299 self.ws_batch_amend_orders(args).await
2300 }
2301
2302 #[allow(clippy::type_complexity)]
2315 pub async fn batch_cancel_orders(
2316 &self,
2317 orders: Vec<(InstrumentId, Option<ClientOrderId>, Option<VenueOrderId>)>,
2318 ) -> Result<(), OKXWsError> {
2319 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2320 for (inst_id, cl_ord_id, ord_id) in orders {
2321 let mut builder = WsCancelOrderParamsBuilder::default();
2322 builder.inst_id(inst_id.symbol.inner());
2324
2325 if let Some(c) = cl_ord_id {
2326 builder.cl_ord_id(c.as_str());
2327 }
2328
2329 if let Some(o) = ord_id {
2330 builder.ord_id(o.as_str());
2331 }
2332
2333 let params = builder.build().map_err(|e| {
2334 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2335 })?;
2336 let val =
2337 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2338 args.push(val);
2339 }
2340
2341 self.ws_batch_cancel_orders(args).await
2342 }
2343
2344 #[allow(clippy::too_many_arguments)]
2355 pub async fn submit_algo_order(
2356 &self,
2357 trader_id: TraderId,
2358 strategy_id: StrategyId,
2359 instrument_id: InstrumentId,
2360 td_mode: OKXTradeMode,
2361 client_order_id: ClientOrderId,
2362 order_side: OrderSide,
2363 order_type: OrderType,
2364 quantity: Quantity,
2365 trigger_price: Price,
2366 trigger_type: Option<TriggerType>,
2367 limit_price: Option<Price>,
2368 reduce_only: Option<bool>,
2369 ) -> Result<(), OKXWsError> {
2370 if !is_conditional_order(order_type) {
2371 return Err(OKXWsError::ClientError(format!(
2372 "Order type {order_type:?} is not a conditional order"
2373 )));
2374 }
2375
2376 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2377 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2378 return Err(OKXWsError::ClientError(
2379 "Invalid order side for OKX".to_string(),
2380 ));
2381 }
2382
2383 builder.inst_id(instrument_id.symbol.inner());
2384 builder.td_mode(td_mode);
2385 builder.cl_ord_id(client_order_id.as_str());
2386 builder.side(order_side);
2387 builder.ord_type(
2388 conditional_order_to_algo_type(order_type)
2389 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2390 );
2391 builder.sz(quantity.to_string());
2392 builder.trigger_px(trigger_price.to_string());
2393
2394 let okx_trigger_type = trigger_type.map_or(OKXTriggerType::Last, Into::into);
2396 builder.trigger_px_type(okx_trigger_type);
2397
2398 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2400 && let Some(price) = limit_price
2401 {
2402 builder.order_px(price.to_string());
2403 }
2404
2405 if let Some(reduce) = reduce_only {
2406 builder.reduce_only(reduce);
2407 }
2408
2409 builder.tag(OKX_NAUTILUS_BROKER_ID);
2410
2411 let params = builder
2412 .build()
2413 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2414
2415 self.active_client_orders
2416 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2417
2418 let cmd = HandlerCommand::PlaceAlgoOrder {
2419 params,
2420 client_order_id,
2421 trader_id,
2422 strategy_id,
2423 instrument_id,
2424 };
2425
2426 self.send_cmd(cmd).await
2427 }
2428
2429 pub async fn cancel_algo_order(
2440 &self,
2441 trader_id: TraderId,
2442 strategy_id: StrategyId,
2443 instrument_id: InstrumentId,
2444 client_order_id: Option<ClientOrderId>,
2445 algo_order_id: Option<String>,
2446 ) -> Result<(), OKXWsError> {
2447 let cmd = HandlerCommand::CancelAlgoOrder {
2448 client_order_id,
2449 algo_order_id: algo_order_id.map(|id| VenueOrderId::from(id.as_str())),
2450 instrument_id,
2451 trader_id,
2452 strategy_id,
2453 };
2454
2455 self.send_cmd(cmd).await
2456 }
2457
2458 async fn send_cmd(&self, cmd: HandlerCommand) -> Result<(), OKXWsError> {
2460 self.cmd_tx
2461 .read()
2462 .await
2463 .send(cmd)
2464 .map_err(|e| OKXWsError::ClientError(format!("Handler not available: {e}")))
2465 }
2466}
2467
2468#[cfg(test)]
2473mod tests {
2474 use nautilus_core::time::get_atomic_clock_realtime;
2475 use nautilus_network::RECONNECTED;
2476 use rstest::rstest;
2477 use tokio_tungstenite::tungstenite::Message;
2478
2479 use super::*;
2480 use crate::{
2481 common::{
2482 consts::OKX_POST_ONLY_CANCEL_SOURCE,
2483 enums::{OKXExecType, OKXOrderCategory, OKXOrderStatus, OKXSide},
2484 },
2485 websocket::{
2486 handler::OKXWsFeedHandler,
2487 messages::{OKXOrderMsg, OKXWebSocketError, OKXWsMessage},
2488 },
2489 };
2490
2491 #[rstest]
2492 fn test_timestamp_format_for_websocket_auth() {
2493 let timestamp = SystemTime::now()
2494 .duration_since(SystemTime::UNIX_EPOCH)
2495 .expect("System time should be after UNIX epoch")
2496 .as_secs()
2497 .to_string();
2498
2499 assert!(timestamp.parse::<u64>().is_ok());
2500 assert_eq!(timestamp.len(), 10);
2501 assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
2502 }
2503
2504 #[rstest]
2505 fn test_new_without_credentials() {
2506 let client = OKXWebSocketClient::default();
2507 assert!(client.credential.is_none());
2508 assert_eq!(client.api_key(), None);
2509 }
2510
2511 #[rstest]
2512 fn test_new_with_credentials() {
2513 let client = OKXWebSocketClient::new(
2514 None,
2515 Some("test_key".to_string()),
2516 Some("test_secret".to_string()),
2517 Some("test_passphrase".to_string()),
2518 None,
2519 None,
2520 )
2521 .unwrap();
2522 assert!(client.credential.is_some());
2523 assert_eq!(client.api_key(), Some("test_key"));
2524 }
2525
2526 #[rstest]
2527 fn test_new_partial_credentials_fails() {
2528 let result = OKXWebSocketClient::new(
2529 None,
2530 Some("test_key".to_string()),
2531 None,
2532 Some("test_passphrase".to_string()),
2533 None,
2534 None,
2535 );
2536 assert!(result.is_err());
2537 }
2538
2539 #[rstest]
2540 fn test_request_id_generation() {
2541 let client = OKXWebSocketClient::default();
2542
2543 let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
2544
2545 let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2546 let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2547
2548 assert_eq!(id1, initial_counter);
2549 assert_eq!(id2, initial_counter + 1);
2550 assert_eq!(
2551 client.request_id_counter.load(Ordering::SeqCst),
2552 initial_counter + 2
2553 );
2554 }
2555
2556 #[rstest]
2557 fn test_client_state_management() {
2558 let client = OKXWebSocketClient::default();
2559
2560 assert!(client.is_closed());
2561 assert!(!client.is_active());
2562
2563 let client_with_heartbeat =
2564 OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();
2565
2566 assert!(client_with_heartbeat.heartbeat.is_some());
2567 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2568 }
2569
2570 #[rstest]
2575 fn test_websocket_error_handling() {
2576 let clock = get_atomic_clock_realtime();
2577 let ts = clock.get_time_ns().as_u64();
2578
2579 let error = OKXWebSocketError {
2580 code: "60012".to_string(),
2581 message: "Invalid request".to_string(),
2582 conn_id: None,
2583 timestamp: ts,
2584 };
2585
2586 assert_eq!(error.code, "60012");
2587 assert_eq!(error.message, "Invalid request");
2588 assert_eq!(error.timestamp, ts);
2589
2590 let nautilus_msg = NautilusWsMessage::Error(error);
2591 match nautilus_msg {
2592 NautilusWsMessage::Error(e) => {
2593 assert_eq!(e.code, "60012");
2594 assert_eq!(e.message, "Invalid request");
2595 }
2596 _ => panic!("Expected Error variant"),
2597 }
2598 }
2599
2600 #[rstest]
2601 fn test_request_id_generation_sequence() {
2602 let client = OKXWebSocketClient::default();
2603
2604 let initial_counter = client
2605 .request_id_counter
2606 .load(std::sync::atomic::Ordering::SeqCst);
2607 let mut ids = Vec::new();
2608 for _ in 0..10 {
2609 let id = client
2610 .request_id_counter
2611 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2612 ids.push(id);
2613 }
2614
2615 for (i, &id) in ids.iter().enumerate() {
2616 assert_eq!(id, initial_counter + i as u64);
2617 }
2618
2619 assert_eq!(
2620 client
2621 .request_id_counter
2622 .load(std::sync::atomic::Ordering::SeqCst),
2623 initial_counter + 10
2624 );
2625 }
2626
2627 #[rstest]
2628 fn test_client_state_transitions() {
2629 let client = OKXWebSocketClient::default();
2630
2631 assert!(client.is_closed());
2632 assert!(!client.is_active());
2633
2634 let client_with_heartbeat = OKXWebSocketClient::new(
2635 None,
2636 None,
2637 None,
2638 None,
2639 None,
2640 Some(30), )
2642 .unwrap();
2643
2644 assert!(client_with_heartbeat.heartbeat.is_some());
2645 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2646
2647 let account_id = AccountId::from("test-account-123");
2648 let client_with_account =
2649 OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();
2650
2651 assert_eq!(client_with_account.account_id, account_id);
2652 }
2653
2654 #[rstest]
2655 fn test_websocket_error_scenarios() {
2656 let clock = get_atomic_clock_realtime();
2657 let ts = clock.get_time_ns().as_u64();
2658
2659 let error_scenarios = vec![
2660 ("60012", "Invalid request", None),
2661 ("60009", "Invalid API key", Some("conn-123".to_string())),
2662 ("60014", "Too many requests", None),
2663 ("50001", "Order not found", None),
2664 ];
2665
2666 for (code, message, conn_id) in error_scenarios {
2667 let error = OKXWebSocketError {
2668 code: code.to_string(),
2669 message: message.to_string(),
2670 conn_id: conn_id.clone(),
2671 timestamp: ts,
2672 };
2673
2674 assert_eq!(error.code, code);
2675 assert_eq!(error.message, message);
2676 assert_eq!(error.conn_id, conn_id);
2677 assert_eq!(error.timestamp, ts);
2678
2679 let nautilus_msg = NautilusWsMessage::Error(error);
2680 match nautilus_msg {
2681 NautilusWsMessage::Error(e) => {
2682 assert_eq!(e.code, code);
2683 assert_eq!(e.message, message);
2684 assert_eq!(e.conn_id, conn_id);
2685 }
2686 _ => panic!("Expected Error variant"),
2687 }
2688 }
2689 }
2690
2691 #[rstest]
2692 fn test_feed_handler_reconnection_detection() {
2693 let msg = Message::Text(RECONNECTED.to_string().into());
2694 let result = OKXWsFeedHandler::parse_raw_message(msg);
2695 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2696 }
2697
2698 #[rstest]
2699 fn test_feed_handler_normal_message_processing() {
2700 let ping_msg = Message::Text(TEXT_PING.to_string().into());
2702 let result = OKXWsFeedHandler::parse_raw_message(ping_msg);
2703 assert!(matches!(result, Some(OKXWsMessage::Ping)));
2704
2705 let sub_msg = r#"{
2707 "event": "subscribe",
2708 "arg": {
2709 "channel": "tickers",
2710 "instType": "SPOT"
2711 },
2712 "connId": "a4d3ae55"
2713 }"#;
2714
2715 let sub_result =
2716 OKXWsFeedHandler::parse_raw_message(Message::Text(sub_msg.to_string().into()));
2717 assert!(matches!(
2718 sub_result,
2719 Some(OKXWsMessage::Subscription { .. })
2720 ));
2721 }
2722
2723 #[rstest]
2724 fn test_feed_handler_close_message() {
2725 let result = OKXWsFeedHandler::parse_raw_message(Message::Close(None));
2727 assert!(result.is_none());
2728 }
2729
2730 #[rstest]
2731 fn test_reconnection_message_constant() {
2732 assert_eq!(RECONNECTED, "__RECONNECTED__");
2733 }
2734
2735 #[rstest]
2736 fn test_multiple_reconnection_signals() {
2737 for _ in 0..3 {
2739 let msg = Message::Text(RECONNECTED.to_string().into());
2740 let result = OKXWsFeedHandler::parse_raw_message(msg);
2741 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2742 }
2743 }
2744
2745 #[tokio::test]
2746 async fn test_wait_until_active_timeout() {
2747 let client = OKXWebSocketClient::new(
2748 None,
2749 Some("test_key".to_string()),
2750 Some("test_secret".to_string()),
2751 Some("test_passphrase".to_string()),
2752 Some(AccountId::from("test-account")),
2753 None,
2754 )
2755 .unwrap();
2756
2757 let result = client.wait_until_active(0.1).await;
2759
2760 assert!(result.is_err());
2761 assert!(!client.is_active());
2762 }
2763
2764 fn sample_canceled_order_msg() -> OKXOrderMsg {
2765 OKXOrderMsg {
2766 acc_fill_sz: Some("0".to_string()),
2767 avg_px: "0".to_string(),
2768 c_time: 0,
2769 cancel_source: None,
2770 cancel_source_reason: None,
2771 category: OKXOrderCategory::Normal,
2772 ccy: ustr::Ustr::from("USDT"),
2773 cl_ord_id: "order-1".to_string(),
2774 algo_cl_ord_id: None,
2775 fee: None,
2776 fee_ccy: ustr::Ustr::from("USDT"),
2777 fill_px: "0".to_string(),
2778 fill_sz: "0".to_string(),
2779 fill_time: 0,
2780 inst_id: ustr::Ustr::from("ETH-USDT-SWAP"),
2781 inst_type: OKXInstrumentType::Swap,
2782 lever: "1".to_string(),
2783 ord_id: ustr::Ustr::from("123456"),
2784 ord_type: OKXOrderType::Limit,
2785 pnl: "0".to_string(),
2786 pos_side: OKXPositionSide::Net,
2787 px: "0".to_string(),
2788 reduce_only: "false".to_string(),
2789 side: OKXSide::Buy,
2790 state: OKXOrderStatus::Canceled,
2791 exec_type: OKXExecType::None,
2792 sz: "1".to_string(),
2793 td_mode: OKXTradeMode::Cross,
2794 tgt_ccy: None,
2795 trade_id: String::new(),
2796 u_time: 0,
2797 }
2798 }
2799
2800 #[rstest]
2801 fn test_is_post_only_auto_cancel_detects_cancel_source() {
2802 let mut msg = sample_canceled_order_msg();
2803 msg.cancel_source = Some(OKX_POST_ONLY_CANCEL_SOURCE.to_string());
2804
2805 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2806 }
2807
2808 #[rstest]
2809 fn test_is_post_only_auto_cancel_detects_reason() {
2810 let mut msg = sample_canceled_order_msg();
2811 msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());
2812
2813 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2814 }
2815
2816 #[rstest]
2817 fn test_is_post_only_auto_cancel_false_without_markers() {
2818 let msg = sample_canceled_order_msg();
2819
2820 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2821 }
2822
2823 #[rstest]
2824 fn test_is_post_only_auto_cancel_false_for_order_type_only() {
2825 let mut msg = sample_canceled_order_msg();
2826 msg.ord_type = OKXOrderType::PostOnly;
2827
2828 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2829 }
2830
2831 #[tokio::test]
2832 async fn test_batch_cancel_orders_with_multiple_orders() {
2833 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2834
2835 let client = OKXWebSocketClient::new(
2836 Some("wss://test.okx.com".to_string()),
2837 None,
2838 None,
2839 None,
2840 None,
2841 None,
2842 )
2843 .expect("Failed to create client");
2844
2845 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2846 let client_order_id1 = ClientOrderId::new("order1");
2847 let client_order_id2 = ClientOrderId::new("order2");
2848 let venue_order_id1 = VenueOrderId::new("venue1");
2849 let venue_order_id2 = VenueOrderId::new("venue2");
2850
2851 let orders = vec![
2852 (instrument_id, Some(client_order_id1), Some(venue_order_id1)),
2853 (instrument_id, Some(client_order_id2), Some(venue_order_id2)),
2854 ];
2855
2856 let result = client.batch_cancel_orders(orders).await;
2858
2859 assert!(result.is_err());
2861 }
2862
2863 #[tokio::test]
2864 async fn test_batch_cancel_orders_with_only_client_order_id() {
2865 use nautilus_model::identifiers::{ClientOrderId, InstrumentId};
2866
2867 let client = OKXWebSocketClient::new(
2868 Some("wss://test.okx.com".to_string()),
2869 None,
2870 None,
2871 None,
2872 None,
2873 None,
2874 )
2875 .expect("Failed to create client");
2876
2877 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2878 let client_order_id = ClientOrderId::new("order1");
2879
2880 let orders = vec![(instrument_id, Some(client_order_id), None)];
2881
2882 let result = client.batch_cancel_orders(orders).await;
2883
2884 assert!(result.is_err());
2886 }
2887
2888 #[tokio::test]
2889 async fn test_batch_cancel_orders_with_only_venue_order_id() {
2890 use nautilus_model::identifiers::{InstrumentId, VenueOrderId};
2891
2892 let client = OKXWebSocketClient::new(
2893 Some("wss://test.okx.com".to_string()),
2894 None,
2895 None,
2896 None,
2897 None,
2898 None,
2899 )
2900 .expect("Failed to create client");
2901
2902 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2903 let venue_order_id = VenueOrderId::new("venue1");
2904
2905 let orders = vec![(instrument_id, None, Some(venue_order_id))];
2906
2907 let result = client.batch_cancel_orders(orders).await;
2908
2909 assert!(result.is_err());
2911 }
2912
2913 #[tokio::test]
2914 async fn test_batch_cancel_orders_with_both_ids() {
2915 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2916
2917 let client = OKXWebSocketClient::new(
2918 Some("wss://test.okx.com".to_string()),
2919 None,
2920 None,
2921 None,
2922 None,
2923 None,
2924 )
2925 .expect("Failed to create client");
2926
2927 let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");
2928 let client_order_id = ClientOrderId::new("order1");
2929 let venue_order_id = VenueOrderId::new("venue1");
2930
2931 let orders = vec![(instrument_id, Some(client_order_id), Some(venue_order_id))];
2932
2933 let result = client.batch_cancel_orders(orders).await;
2934
2935 assert!(result.is_err());
2937 }
2938
2939 #[rstest]
2940 fn test_race_unsubscribe_failure_recovery() {
2941 let client = OKXWebSocketClient::new(
2947 Some("wss://test.okx.com".to_string()),
2948 None,
2949 None,
2950 None,
2951 None,
2952 None,
2953 )
2954 .expect("Failed to create client");
2955
2956 let topic = "trades:BTC-USDT-SWAP";
2957
2958 client.subscriptions_state.mark_subscribe(topic);
2960 client.subscriptions_state.confirm_subscribe(topic);
2961 assert_eq!(client.subscriptions_state.len(), 1);
2962
2963 client.subscriptions_state.mark_unsubscribe(topic);
2965 assert_eq!(client.subscriptions_state.len(), 0);
2966 assert_eq!(
2967 client.subscriptions_state.pending_unsubscribe_topics(),
2968 vec![topic]
2969 );
2970
2971 client.subscriptions_state.confirm_unsubscribe(topic); client.subscriptions_state.mark_subscribe(topic); client.subscriptions_state.confirm_subscribe(topic); assert_eq!(client.subscriptions_state.len(), 1);
2979 assert!(
2980 client
2981 .subscriptions_state
2982 .pending_unsubscribe_topics()
2983 .is_empty()
2984 );
2985 assert!(
2986 client
2987 .subscriptions_state
2988 .pending_subscribe_topics()
2989 .is_empty()
2990 );
2991
2992 let all = client.subscriptions_state.all_topics();
2994 assert_eq!(all.len(), 1);
2995 assert!(all.contains(&topic.to_string()));
2996 }
2997
2998 #[rstest]
2999 fn test_race_resubscribe_before_unsubscribe_ack() {
3000 let client = OKXWebSocketClient::new(
3004 Some("wss://test.okx.com".to_string()),
3005 None,
3006 None,
3007 None,
3008 None,
3009 None,
3010 )
3011 .expect("Failed to create client");
3012
3013 let topic = "books:BTC-USDT";
3014
3015 client.subscriptions_state.mark_subscribe(topic);
3017 client.subscriptions_state.confirm_subscribe(topic);
3018 assert_eq!(client.subscriptions_state.len(), 1);
3019
3020 client.subscriptions_state.mark_unsubscribe(topic);
3022 assert_eq!(client.subscriptions_state.len(), 0);
3023 assert_eq!(
3024 client.subscriptions_state.pending_unsubscribe_topics(),
3025 vec![topic]
3026 );
3027
3028 client.subscriptions_state.mark_subscribe(topic);
3030 assert_eq!(
3031 client.subscriptions_state.pending_subscribe_topics(),
3032 vec![topic]
3033 );
3034
3035 client.subscriptions_state.confirm_unsubscribe(topic);
3037 assert!(
3038 client
3039 .subscriptions_state
3040 .pending_unsubscribe_topics()
3041 .is_empty()
3042 );
3043 assert_eq!(
3044 client.subscriptions_state.pending_subscribe_topics(),
3045 vec![topic]
3046 );
3047
3048 client.subscriptions_state.confirm_subscribe(topic);
3050 assert_eq!(client.subscriptions_state.len(), 1);
3051 assert!(
3052 client
3053 .subscriptions_state
3054 .pending_subscribe_topics()
3055 .is_empty()
3056 );
3057
3058 let all = client.subscriptions_state.all_topics();
3060 assert_eq!(all.len(), 1);
3061 assert!(all.contains(&topic.to_string()));
3062 }
3063
3064 #[rstest]
3065 fn test_race_late_subscribe_confirmation_after_unsubscribe() {
3066 let client = OKXWebSocketClient::new(
3069 Some("wss://test.okx.com".to_string()),
3070 None,
3071 None,
3072 None,
3073 None,
3074 None,
3075 )
3076 .expect("Failed to create client");
3077
3078 let topic = "tickers:ETH-USDT";
3079
3080 client.subscriptions_state.mark_subscribe(topic);
3082 assert_eq!(
3083 client.subscriptions_state.pending_subscribe_topics(),
3084 vec![topic]
3085 );
3086
3087 client.subscriptions_state.mark_unsubscribe(topic);
3089 assert!(
3090 client
3091 .subscriptions_state
3092 .pending_subscribe_topics()
3093 .is_empty()
3094 ); assert_eq!(
3096 client.subscriptions_state.pending_unsubscribe_topics(),
3097 vec![topic]
3098 );
3099
3100 client.subscriptions_state.confirm_subscribe(topic);
3102 assert_eq!(client.subscriptions_state.len(), 0); assert_eq!(
3104 client.subscriptions_state.pending_unsubscribe_topics(),
3105 vec![topic]
3106 );
3107
3108 client.subscriptions_state.confirm_unsubscribe(topic);
3110
3111 assert!(client.subscriptions_state.is_empty());
3113 assert!(client.subscriptions_state.all_topics().is_empty());
3114 }
3115
3116 #[rstest]
3117 fn test_race_reconnection_with_pending_states() {
3118 let client = OKXWebSocketClient::new(
3120 Some("wss://test.okx.com".to_string()),
3121 Some("test_key".to_string()),
3122 Some("test_secret".to_string()),
3123 Some("test_passphrase".to_string()),
3124 Some(AccountId::new("OKX-TEST")),
3125 None,
3126 )
3127 .expect("Failed to create client");
3128
3129 let trade_btc = "trades:BTC-USDT-SWAP";
3132 client.subscriptions_state.mark_subscribe(trade_btc);
3133 client.subscriptions_state.confirm_subscribe(trade_btc);
3134
3135 let trade_eth = "trades:ETH-USDT-SWAP";
3137 client.subscriptions_state.mark_subscribe(trade_eth);
3138
3139 let book_btc = "books:BTC-USDT";
3141 client.subscriptions_state.mark_subscribe(book_btc);
3142 client.subscriptions_state.confirm_subscribe(book_btc);
3143 client.subscriptions_state.mark_unsubscribe(book_btc);
3144
3145 let topics_to_restore = client.subscriptions_state.all_topics();
3147
3148 assert_eq!(topics_to_restore.len(), 2);
3150 assert!(topics_to_restore.contains(&trade_btc.to_string()));
3151 assert!(topics_to_restore.contains(&trade_eth.to_string()));
3152 assert!(!topics_to_restore.contains(&book_btc.to_string())); }
3154
3155 #[rstest]
3156 fn test_race_duplicate_subscribe_messages_idempotent() {
3157 let client = OKXWebSocketClient::new(
3160 Some("wss://test.okx.com".to_string()),
3161 None,
3162 None,
3163 None,
3164 None,
3165 None,
3166 )
3167 .expect("Failed to create client");
3168
3169 let topic = "trades:BTC-USDT-SWAP";
3170
3171 client.subscriptions_state.mark_subscribe(topic);
3173 client.subscriptions_state.confirm_subscribe(topic);
3174 assert_eq!(client.subscriptions_state.len(), 1);
3175
3176 client.subscriptions_state.mark_subscribe(topic);
3178 assert!(
3179 client
3180 .subscriptions_state
3181 .pending_subscribe_topics()
3182 .is_empty()
3183 ); assert_eq!(client.subscriptions_state.len(), 1); client.subscriptions_state.confirm_subscribe(topic);
3188 assert_eq!(client.subscriptions_state.len(), 1);
3189
3190 let all = client.subscriptions_state.all_topics();
3192 assert_eq!(all.len(), 1);
3193 assert_eq!(all[0], topic);
3194 }
3195}