1use std::{
26 fmt::Debug,
27 num::NonZeroU32,
28 sync::{
29 Arc, LazyLock,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32 time::{Duration, SystemTime},
33};
34
35use ahash::AHashSet;
36use arc_swap::ArcSwap;
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::live::runtime::get_runtime;
40use nautilus_core::{
41 consts::NAUTILUS_USER_AGENT,
42 env::{get_env_var, get_or_env_var},
43};
44use nautilus_model::{
45 data::BarType,
46 enums::{OrderSide, OrderType, PositionSide, TimeInForce, TriggerType},
47 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
48 instruments::{Instrument, InstrumentAny},
49 types::{Price, Quantity},
50};
51use nautilus_network::{
52 http::USER_AGENT,
53 mode::ConnectionMode,
54 ratelimiter::quota::Quota,
55 websocket::{
56 AUTHENTICATION_TIMEOUT_SECS, AuthTracker, PingHandler, SubscriptionState, TEXT_PING,
57 WebSocketClient, WebSocketConfig, channel_message_handler,
58 },
59};
60use serde_json::Value;
61use tokio_tungstenite::tungstenite::Error;
62use tokio_util::sync::CancellationToken;
63use ustr::Ustr;
64
65use super::{
66 enums::OKXWsChannel,
67 error::OKXWsError,
68 handler::{HandlerCommand, OKXWsFeedHandler},
69 messages::{
70 NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg, OKXSubscriptionArg,
71 WsAmendOrderParamsBuilder, WsCancelOrderParamsBuilder, WsPostAlgoOrderParamsBuilder,
72 WsPostOrderParamsBuilder,
73 },
74 subscription::topic_from_subscription_arg,
75};
76use crate::common::{
77 consts::{
78 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
79 OKX_WS_PUBLIC_URL, OKX_WS_TOPIC_DELIMITER,
80 },
81 credential::Credential,
82 enums::{
83 OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXTargetCurrency, OKXTradeMode,
84 OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type, is_conditional_order,
85 },
86 parse::{bar_spec_as_okx_channel, okx_instrument_type, okx_instrument_type_from_symbol},
87};
88
89pub static OKX_WS_CONNECTION_QUOTA: LazyLock<Quota> =
93 LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
94
95pub static OKX_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
100 LazyLock::new(|| Quota::per_hour(NonZeroU32::new(480).unwrap()));
101
102pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
107 LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
108
/// Rate-limit key paired with [`OKX_WS_SUBSCRIPTION_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_SUBSCRIPTION: &str = "subscription";
114
/// Rate-limit key for order placement, paired with [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_ORDER: &str = "order";
120
/// Rate-limit key for order cancellation, paired with [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_CANCEL: &str = "cancel";
127
/// Rate-limit key for order amendment, paired with [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_AMEND: &str = "amend";
132
133#[derive(Clone)]
135#[cfg_attr(
136 feature = "python",
137 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
138)]
139pub struct OKXWebSocketClient {
140 url: String,
141 account_id: AccountId,
142 vip_level: Arc<AtomicU8>,
143 credential: Option<Credential>,
144 heartbeat: Option<u64>,
145 auth_tracker: AuthTracker,
146 signal: Arc<AtomicBool>,
147 connection_mode: Arc<ArcSwap<AtomicU8>>,
148 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
149 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
150 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
151 subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
152 subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
153 subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
154 subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>, subscriptions_state: SubscriptionState,
156 request_id_counter: Arc<AtomicU64>,
157 active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
158 client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
159 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
160 cancellation_token: CancellationToken,
161}
162
163impl Default for OKXWebSocketClient {
164 fn default() -> Self {
165 Self::new(None, None, None, None, None, None).unwrap()
166 }
167}
168
169impl Debug for OKXWebSocketClient {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct(stringify!(OKXWebSocketClient))
172 .field("url", &self.url)
173 .field(
174 "credential",
175 &self.credential.as_ref().map(|_| "<redacted>"),
176 )
177 .field("heartbeat", &self.heartbeat)
178 .finish_non_exhaustive()
179 }
180}
181
182impl OKXWebSocketClient {
183 pub fn new(
189 url: Option<String>,
190 api_key: Option<String>,
191 api_secret: Option<String>,
192 api_passphrase: Option<String>,
193 account_id: Option<AccountId>,
194 heartbeat: Option<u64>,
195 ) -> anyhow::Result<Self> {
196 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
197 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
198
199 let credential = match (api_key, api_secret, api_passphrase) {
200 (Some(key), Some(secret), Some(passphrase)) => {
201 Some(Credential::new(key, secret, passphrase))
202 }
203 (None, None, None) => None,
204 _ => anyhow::bail!(
205 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
206 ),
207 };
208
209 let signal = Arc::new(AtomicBool::new(false));
210 let subscriptions_inst_type = Arc::new(DashMap::new());
211 let subscriptions_inst_family = Arc::new(DashMap::new());
212 let subscriptions_inst_id = Arc::new(DashMap::new());
213 let subscriptions_bare = Arc::new(DashMap::new());
214 let subscriptions_state = SubscriptionState::new(OKX_WS_TOPIC_DELIMITER);
215
216 Ok(Self {
217 url,
218 account_id,
219 vip_level: Arc::new(AtomicU8::new(0)), credential,
221 heartbeat,
222 auth_tracker: AuthTracker::new(),
223 signal,
224 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
225 ConnectionMode::Closed.as_u8(),
226 ))),
227 cmd_tx: {
228 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
230 Arc::new(tokio::sync::RwLock::new(tx))
231 },
232 out_rx: None,
233 task_handle: None,
234 subscriptions_inst_type,
235 subscriptions_inst_family,
236 subscriptions_inst_id,
237 subscriptions_bare,
238 subscriptions_state,
239 request_id_counter: Arc::new(AtomicU64::new(1)),
240 active_client_orders: Arc::new(DashMap::new()),
241 client_id_aliases: Arc::new(DashMap::new()),
242 instruments_cache: Arc::new(DashMap::new()),
243 cancellation_token: CancellationToken::new(),
244 })
245 }
246
247 pub fn with_credentials(
254 url: Option<String>,
255 api_key: Option<String>,
256 api_secret: Option<String>,
257 api_passphrase: Option<String>,
258 account_id: Option<AccountId>,
259 heartbeat: Option<u64>,
260 ) -> anyhow::Result<Self> {
261 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
262 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
263 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
264 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
265
266 Self::new(
267 Some(url),
268 Some(api_key),
269 Some(api_secret),
270 Some(api_passphrase),
271 account_id,
272 heartbeat,
273 )
274 }
275
276 pub fn from_env() -> anyhow::Result<Self> {
283 let url = get_env_var("OKX_WS_URL")?;
284 let api_key = get_env_var("OKX_API_KEY")?;
285 let api_secret = get_env_var("OKX_API_SECRET")?;
286 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
287
288 Self::new(
289 Some(url),
290 Some(api_key),
291 Some(api_secret),
292 Some(api_passphrase),
293 None,
294 None,
295 )
296 }
297
298 pub fn cancel_all_requests(&self) {
300 self.cancellation_token.cancel();
301 }
302
303 pub fn cancellation_token(&self) -> &CancellationToken {
305 &self.cancellation_token
306 }
307
308 pub fn url(&self) -> &str {
310 self.url.as_str()
311 }
312
313 pub fn api_key(&self) -> Option<&str> {
315 self.credential.clone().map(|c| c.api_key.as_str())
316 }
317
318 #[must_use]
320 pub fn api_key_masked(&self) -> Option<String> {
321 self.credential.clone().map(|c| c.api_key_masked())
322 }
323
324 pub fn is_active(&self) -> bool {
326 let connection_mode_arc = self.connection_mode.load();
327 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
328 && !self.signal.load(Ordering::Acquire)
329 }
330
331 pub fn is_closed(&self) -> bool {
333 let connection_mode_arc = self.connection_mode.load();
334 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
335 || self.signal.load(Ordering::Acquire)
336 }
337
338 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
342 for inst in &instruments {
343 self.instruments_cache
344 .insert(inst.symbol().inner(), inst.clone());
345 }
346
347 if !instruments.is_empty()
350 && let Ok(cmd_tx) = self.cmd_tx.try_read()
351 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
352 {
353 log::debug!("Failed to send bulk instrument update to handler: {e}");
354 }
355 }
356
357 pub fn cache_instrument(&self, instrument: InstrumentAny) {
361 self.instruments_cache
362 .insert(instrument.symbol().inner(), instrument.clone());
363
364 if let Ok(cmd_tx) = self.cmd_tx.try_read()
367 && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
368 {
369 log::debug!("Failed to send instrument update to handler: {e}");
370 }
371 }
372
373 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
377 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
378 }
379
380 pub fn vip_level(&self) -> OKXVipLevel {
382 let level = self.vip_level.load(Ordering::Relaxed);
383 OKXVipLevel::from(level)
384 }
385
386 pub async fn connect(&mut self) -> anyhow::Result<()> {
396 let (message_handler, raw_rx) = channel_message_handler();
397
398 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
401 });
403
404 let config = WebSocketConfig {
405 url: self.url.clone(),
406 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
407 heartbeat: self.heartbeat,
408 heartbeat_msg: Some(TEXT_PING.to_string()),
409 message_handler: Some(message_handler),
410 ping_handler: Some(ping_handler),
411 reconnect_timeout_ms: Some(5_000),
412 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, reconnect_max_attempts: None,
417 };
418
419 let keyed_quotas = vec![
421 (
422 OKX_RATE_LIMIT_KEY_SUBSCRIPTION.to_string(),
423 *OKX_WS_SUBSCRIPTION_QUOTA,
424 ),
425 (OKX_RATE_LIMIT_KEY_ORDER.to_string(), *OKX_WS_ORDER_QUOTA),
426 (OKX_RATE_LIMIT_KEY_CANCEL.to_string(), *OKX_WS_ORDER_QUOTA),
427 (OKX_RATE_LIMIT_KEY_AMEND.to_string(), *OKX_WS_ORDER_QUOTA),
428 ];
429
430 let client = WebSocketClient::connect(
431 config,
432 None, keyed_quotas,
434 Some(*OKX_WS_CONNECTION_QUOTA), )
436 .await?;
437
438 self.connection_mode.store(client.connection_mode_atomic());
440
441 let account_id = self.account_id;
442 let (msg_tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
443
444 self.out_rx = Some(Arc::new(rx));
445
446 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
448 *self.cmd_tx.write().await = cmd_tx.clone();
449
450 if !self.instruments_cache.is_empty() {
452 let cached_instruments: Vec<InstrumentAny> = self
453 .instruments_cache
454 .iter()
455 .map(|entry| entry.value().clone())
456 .collect();
457 if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(cached_instruments)) {
458 tracing::error!("Failed to replay instruments to handler: {e}");
459 }
460 }
461
462 let signal = self.signal.clone();
463 let active_client_orders = self.active_client_orders.clone();
464 let auth_tracker = self.auth_tracker.clone();
465 let subscriptions_state = self.subscriptions_state.clone();
466 let client_id_aliases = self.client_id_aliases.clone();
467
468 let stream_handle = get_runtime().spawn({
469 let auth_tracker = auth_tracker.clone();
470 let signal = signal.clone();
471 let credential = self.credential.clone();
472 let cmd_tx_for_reconnect = cmd_tx.clone();
473 let subscriptions_bare = self.subscriptions_bare.clone();
474 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
475 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
476 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
477 let mut has_reconnected = false;
478
479 async move {
480 let mut handler = OKXWsFeedHandler::new(
481 account_id,
482 signal.clone(),
483 cmd_rx,
484 raw_rx,
485 msg_tx,
486 active_client_orders,
487 client_id_aliases,
488 auth_tracker.clone(),
489 subscriptions_state.clone(),
490 );
491
492 let resubscribe_all = || {
494 for entry in subscriptions_inst_id.iter() {
495 let (channel, inst_ids) = entry.pair();
496 for inst_id in inst_ids {
497 let arg = OKXSubscriptionArg {
498 channel: channel.clone(),
499 inst_type: None,
500 inst_family: None,
501 inst_id: Some(*inst_id),
502 };
503 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
504 tracing::error!(error = %e, "Failed to send resubscribe command");
505 }
506 }
507 }
508
509 for entry in subscriptions_bare.iter() {
510 let channel = entry.key();
511 let arg = OKXSubscriptionArg {
512 channel: channel.clone(),
513 inst_type: None,
514 inst_family: None,
515 inst_id: None,
516 };
517 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
518 tracing::error!(error = %e, "Failed to send resubscribe command");
519 }
520 }
521
522 for entry in subscriptions_inst_type.iter() {
523 let (channel, inst_types) = entry.pair();
524 for inst_type in inst_types {
525 let arg = OKXSubscriptionArg {
526 channel: channel.clone(),
527 inst_type: Some(*inst_type),
528 inst_family: None,
529 inst_id: None,
530 };
531 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
532 tracing::error!(error = %e, "Failed to send resubscribe command");
533 }
534 }
535 }
536
537 for entry in subscriptions_inst_family.iter() {
538 let (channel, inst_families) = entry.pair();
539 for inst_family in inst_families {
540 let arg = OKXSubscriptionArg {
541 channel: channel.clone(),
542 inst_type: None,
543 inst_family: Some(*inst_family),
544 inst_id: None,
545 };
546 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
547 tracing::error!(error = %e, "Failed to send resubscribe command");
548 }
549 }
550 }
551 };
552
553 loop {
555 match handler.next().await {
556 Some(NautilusWsMessage::Reconnected) => {
557 if signal.load(Ordering::Acquire) {
558 continue;
559 }
560
561 has_reconnected = true;
562
563 let confirmed_topics_vec: Vec<String> = {
565 let confirmed = subscriptions_state.confirmed();
566 let mut topics = Vec::new();
567 for entry in confirmed.iter() {
568 let channel = entry.key();
569 for symbol in entry.value() {
570 if symbol.as_str() == "#" {
571 topics.push(channel.to_string());
572 } else {
573 topics.push(format!("{channel}{OKX_WS_TOPIC_DELIMITER}{symbol}"));
574 }
575 }
576 }
577 topics
578 };
579
580 if !confirmed_topics_vec.is_empty() {
581 tracing::debug!(count = confirmed_topics_vec.len(), "Marking confirmed subscriptions as pending for replay");
582 for topic in confirmed_topics_vec {
583 subscriptions_state.mark_failure(&topic);
584 }
585 }
586
587 if let Some(cred) = &credential {
588 tracing::debug!("Re-authenticating after reconnection");
589 let timestamp = std::time::SystemTime::now()
590 .duration_since(std::time::SystemTime::UNIX_EPOCH)
591 .expect("System time should be after UNIX epoch")
592 .as_secs()
593 .to_string();
594 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
595
596 let auth_message = super::messages::OKXAuthentication {
597 op: "login",
598 args: vec![super::messages::OKXAuthenticationArg {
599 api_key: cred.api_key.to_string(),
600 passphrase: cred.api_passphrase.clone(),
601 timestamp,
602 sign: signature,
603 }],
604 };
605
606 if let Ok(payload) = serde_json::to_string(&auth_message) {
607 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Authenticate { payload }) {
608 tracing::error!(error = %e, "Failed to send reconnection auth command");
609 }
610 } else {
611 tracing::error!("Failed to serialize reconnection auth message");
612 }
613 }
614
615 if credential.is_none() {
618 tracing::debug!("No authentication required, resubscribing immediately");
619 resubscribe_all();
620 }
621
622 continue;
627 }
628 Some(NautilusWsMessage::Authenticated) => {
629 if has_reconnected {
630 resubscribe_all();
631 }
632
633 continue;
638 }
639 Some(msg) => {
640 if handler.send(msg).is_err() {
641 tracing::error!(
642 "Failed to send message through channel: receiver dropped",
643 );
644 break;
645 }
646 }
647 None => {
648 if handler.is_stopped() {
649 tracing::debug!(
650 "Stop signal received, ending message processing",
651 );
652 break;
653 }
654 tracing::debug!("WebSocket stream closed");
655 break;
656 }
657 }
658 }
659
660 tracing::debug!("Handler task exiting");
661 }
662 });
663
664 self.task_handle = Some(Arc::new(stream_handle));
665
666 self.cmd_tx
667 .read()
668 .await
669 .send(HandlerCommand::SetClient(client))
670 .map_err(|e| {
671 OKXWsError::ClientError(format!("Failed to send WebSocket client to handler: {e}"))
672 })?;
673 tracing::debug!("Sent WebSocket client to handler");
674
675 if self.credential.is_some()
676 && let Err(e) = self.authenticate().await
677 {
678 anyhow::bail!("Authentication failed: {e}");
679 }
680
681 Ok(())
682 }
683
684 async fn authenticate(&self) -> Result<(), Error> {
686 let credential = self.credential.as_ref().ok_or_else(|| {
687 Error::Io(std::io::Error::other(
688 "API credentials not available to authenticate",
689 ))
690 })?;
691
692 let rx = self.auth_tracker.begin();
693
694 let timestamp = SystemTime::now()
695 .duration_since(SystemTime::UNIX_EPOCH)
696 .expect("System time should be after UNIX epoch")
697 .as_secs()
698 .to_string();
699 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
700
701 let auth_message = OKXAuthentication {
702 op: "login",
703 args: vec![OKXAuthenticationArg {
704 api_key: credential.api_key.to_string(),
705 passphrase: credential.api_passphrase.clone(),
706 timestamp,
707 sign: signature,
708 }],
709 };
710
711 let payload = serde_json::to_string(&auth_message).map_err(|e| {
712 Error::Io(std::io::Error::other(format!(
713 "Failed to serialize auth message: {e}"
714 )))
715 })?;
716
717 self.cmd_tx
718 .read()
719 .await
720 .send(HandlerCommand::Authenticate { payload })
721 .map_err(|e| {
722 Error::Io(std::io::Error::other(format!(
723 "Failed to send authenticate command: {e}"
724 )))
725 })?;
726
727 match self
728 .auth_tracker
729 .wait_for_result::<OKXWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
730 .await
731 {
732 Ok(()) => {
733 tracing::info!("WebSocket authenticated");
734 Ok(())
735 }
736 Err(e) => {
737 tracing::error!(error = %e, "WebSocket authentication failed");
738 Err(Error::Io(std::io::Error::other(e.to_string())))
739 }
740 }
741 }
742
743 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
751 let rx = self
752 .out_rx
753 .take()
754 .expect("Data stream receiver already taken or not connected");
755 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
756 async_stream::stream! {
757 while let Some(data) = rx.recv().await {
758 yield data;
759 }
760 }
761 }
762
763 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
769 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
770
771 tokio::time::timeout(timeout, async {
772 while !self.is_active() {
773 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
774 }
775 })
776 .await
777 .map_err(|_| {
778 OKXWsError::ClientError(format!(
779 "WebSocket connection timeout after {timeout_secs} seconds"
780 ))
781 })?;
782
783 Ok(())
784 }
785
786 pub async fn close(&mut self) -> Result<(), Error> {
793 log::debug!("Starting close process");
794
795 self.signal.store(true, Ordering::Release);
796
797 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
798 log::warn!("Failed to send disconnect command to handler: {e}");
799 } else {
800 log::debug!("Sent disconnect command to handler");
801 }
802
803 {
805 if false {
806 log::debug!("No active connection to disconnect");
807 }
808 }
809
810 if let Some(stream_handle) = self.task_handle.take() {
812 match Arc::try_unwrap(stream_handle) {
813 Ok(handle) => {
814 log::debug!("Waiting for stream handle to complete");
815 match tokio::time::timeout(Duration::from_secs(2), handle).await {
816 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
817 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
818 Err(_) => {
819 log::warn!(
820 "Timeout waiting for stream handle, task may still be running"
821 );
822 }
824 }
825 }
826 Err(arc_handle) => {
827 log::debug!(
828 "Cannot take ownership of stream handle - other references exist, aborting task"
829 );
830 arc_handle.abort();
831 }
832 }
833 } else {
834 log::debug!("No stream handle to await");
835 }
836
837 log::debug!("Close process completed");
838
839 Ok(())
840 }
841
842 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
844 let symbol = instrument_id.symbol.inner();
845 let mut channels = Vec::new();
846
847 for entry in self.subscriptions_inst_id.iter() {
848 let (channel, instruments) = entry.pair();
849 if instruments.contains(&symbol) {
850 channels.push(channel.clone());
851 }
852 }
853
854 channels
855 }
856
857 fn generate_unique_request_id(&self) -> String {
858 self.request_id_counter
859 .fetch_add(1, Ordering::SeqCst)
860 .to_string()
861 }
862
863 #[allow(
864 clippy::result_large_err,
865 reason = "OKXWsError contains large tungstenite::Error variant"
866 )]
867 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
868 for arg in &args {
869 let topic = topic_from_subscription_arg(arg);
870 self.subscriptions_state.mark_subscribe(&topic);
871
872 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
874 self.subscriptions_bare.insert(arg.channel.clone(), true);
876 } else {
877 if let Some(inst_type) = &arg.inst_type {
879 self.subscriptions_inst_type
880 .entry(arg.channel.clone())
881 .or_default()
882 .insert(*inst_type);
883 }
884
885 if let Some(inst_family) = &arg.inst_family {
887 self.subscriptions_inst_family
888 .entry(arg.channel.clone())
889 .or_default()
890 .insert(*inst_family);
891 }
892
893 if let Some(inst_id) = &arg.inst_id {
895 self.subscriptions_inst_id
896 .entry(arg.channel.clone())
897 .or_default()
898 .insert(*inst_id);
899 }
900 }
901 }
902
903 self.cmd_tx
904 .read()
905 .await
906 .send(HandlerCommand::Subscribe { args })
907 .map_err(|e| OKXWsError::ClientError(format!("Failed to send subscribe command: {e}")))
908 }
909
910 #[allow(clippy::collapsible_if)]
911 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
912 for arg in &args {
913 let topic = topic_from_subscription_arg(arg);
914 self.subscriptions_state.mark_unsubscribe(&topic);
915
916 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
918 self.subscriptions_bare.remove(&arg.channel);
920 } else {
921 if let Some(inst_type) = &arg.inst_type {
923 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
924 entry.remove(inst_type);
925 if entry.is_empty() {
926 drop(entry);
927 self.subscriptions_inst_type.remove(&arg.channel);
928 }
929 }
930 }
931
932 if let Some(inst_family) = &arg.inst_family {
934 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
935 entry.remove(inst_family);
936 if entry.is_empty() {
937 drop(entry);
938 self.subscriptions_inst_family.remove(&arg.channel);
939 }
940 }
941 }
942
943 if let Some(inst_id) = &arg.inst_id {
945 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
946 entry.remove(inst_id);
947 if entry.is_empty() {
948 drop(entry);
949 self.subscriptions_inst_id.remove(&arg.channel);
950 }
951 }
952 }
953 }
954 }
955
956 self.cmd_tx
957 .read()
958 .await
959 .send(HandlerCommand::Unsubscribe { args })
960 .map_err(|e| {
961 OKXWsError::ClientError(format!("Failed to send unsubscribe command: {e}"))
962 })
963 }
964
965 pub async fn unsubscribe_all(&self) -> Result<(), OKXWsError> {
974 let mut all_args = Vec::new();
975
976 for entry in self.subscriptions_inst_type.iter() {
977 let (channel, inst_types) = entry.pair();
978 for inst_type in inst_types {
979 all_args.push(OKXSubscriptionArg {
980 channel: channel.clone(),
981 inst_type: Some(*inst_type),
982 inst_family: None,
983 inst_id: None,
984 });
985 }
986 }
987
988 for entry in self.subscriptions_inst_family.iter() {
989 let (channel, inst_families) = entry.pair();
990 for inst_family in inst_families {
991 all_args.push(OKXSubscriptionArg {
992 channel: channel.clone(),
993 inst_type: None,
994 inst_family: Some(*inst_family),
995 inst_id: None,
996 });
997 }
998 }
999
1000 for entry in self.subscriptions_inst_id.iter() {
1001 let (channel, inst_ids) = entry.pair();
1002 for inst_id in inst_ids {
1003 all_args.push(OKXSubscriptionArg {
1004 channel: channel.clone(),
1005 inst_type: None,
1006 inst_family: None,
1007 inst_id: Some(*inst_id),
1008 });
1009 }
1010 }
1011
1012 for entry in self.subscriptions_bare.iter() {
1013 let channel = entry.key();
1014 all_args.push(OKXSubscriptionArg {
1015 channel: channel.clone(),
1016 inst_type: None,
1017 inst_family: None,
1018 inst_id: None,
1019 });
1020 }
1021
1022 if all_args.is_empty() {
1023 tracing::debug!("No active subscriptions to unsubscribe from");
1024 return Ok(());
1025 }
1026
1027 tracing::debug!("Batched unsubscribe from {} channels", all_args.len());
1028
1029 const BATCH_SIZE: usize = 256;
1030
1031 for chunk in all_args.chunks(BATCH_SIZE) {
1032 self.unsubscribe(chunk.to_vec()).await?;
1033 }
1034
1035 Ok(())
1036 }
1037
1038 pub async fn subscribe_instruments(
1050 &self,
1051 instrument_type: OKXInstrumentType,
1052 ) -> Result<(), OKXWsError> {
1053 let arg = OKXSubscriptionArg {
1054 channel: OKXWsChannel::Instruments,
1055 inst_type: Some(instrument_type),
1056 inst_family: None,
1057 inst_id: None,
1058 };
1059 self.subscribe(vec![arg]).await
1060 }
1061
1062 pub async fn subscribe_instrument(
1075 &self,
1076 instrument_id: InstrumentId,
1077 ) -> Result<(), OKXWsError> {
1078 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1079
1080 let already_subscribed = self
1081 .subscriptions_inst_type
1082 .get(&OKXWsChannel::Instruments)
1083 .is_some_and(|types| types.contains(&inst_type));
1084
1085 if already_subscribed {
1086 tracing::debug!(
1087 "Already subscribed to instrument type {inst_type:?} for {instrument_id}"
1088 );
1089 return Ok(());
1090 }
1091
1092 tracing::debug!("Subscribing to instrument type {inst_type:?} for {instrument_id}");
1093 self.subscribe_instruments(inst_type).await
1094 }
1095
1096 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1105 self.subscribe_book_with_depth(instrument_id, 0).await
1106 }
1107
1108 pub(crate) async fn subscribe_books_channel(
1110 &self,
1111 instrument_id: InstrumentId,
1112 ) -> Result<(), OKXWsError> {
1113 let arg = OKXSubscriptionArg {
1114 channel: OKXWsChannel::Books,
1115 inst_type: None,
1116 inst_family: None,
1117 inst_id: Some(instrument_id.symbol.inner()),
1118 };
1119 self.subscribe(vec![arg]).await
1120 }
1121
1122 pub async fn subscribe_book_depth5(
1134 &self,
1135 instrument_id: InstrumentId,
1136 ) -> Result<(), OKXWsError> {
1137 let arg = OKXSubscriptionArg {
1138 channel: OKXWsChannel::Books5,
1139 inst_type: None,
1140 inst_family: None,
1141 inst_id: Some(instrument_id.symbol.inner()),
1142 };
1143 self.subscribe(vec![arg]).await
1144 }
1145
1146 pub async fn subscribe_book50_l2_tbt(
1158 &self,
1159 instrument_id: InstrumentId,
1160 ) -> Result<(), OKXWsError> {
1161 let arg = OKXSubscriptionArg {
1162 channel: OKXWsChannel::Books50Tbt,
1163 inst_type: None,
1164 inst_family: None,
1165 inst_id: Some(instrument_id.symbol.inner()),
1166 };
1167 self.subscribe(vec![arg]).await
1168 }
1169
1170 pub async fn subscribe_book_l2_tbt(
1182 &self,
1183 instrument_id: InstrumentId,
1184 ) -> Result<(), OKXWsError> {
1185 let arg = OKXSubscriptionArg {
1186 channel: OKXWsChannel::BooksTbt,
1187 inst_type: None,
1188 inst_family: None,
1189 inst_id: Some(instrument_id.symbol.inner()),
1190 };
1191 self.subscribe(vec![arg]).await
1192 }
1193
1194 pub async fn subscribe_book_with_depth(
1208 &self,
1209 instrument_id: InstrumentId,
1210 depth: u16,
1211 ) -> anyhow::Result<()> {
1212 let vip = self.vip_level();
1213
1214 match depth {
1215 50 => {
1216 if vip < OKXVipLevel::Vip4 {
1217 anyhow::bail!(
1218 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1219 );
1220 }
1221 self.subscribe_book50_l2_tbt(instrument_id)
1222 .await
1223 .map_err(|e| anyhow::anyhow!(e))
1224 }
1225 0 | 400 => {
1226 if vip >= OKXVipLevel::Vip5 {
1227 self.subscribe_book_l2_tbt(instrument_id)
1228 .await
1229 .map_err(|e| anyhow::anyhow!(e))
1230 } else {
1231 self.subscribe_books_channel(instrument_id)
1232 .await
1233 .map_err(|e| anyhow::anyhow!(e))
1234 }
1235 }
1236 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1237 }
1238 }
1239
1240 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1253 let arg = OKXSubscriptionArg {
1254 channel: OKXWsChannel::BboTbt,
1255 inst_type: None,
1256 inst_family: None,
1257 inst_id: Some(instrument_id.symbol.inner()),
1258 };
1259 self.subscribe(vec![arg]).await
1260 }
1261
1262 pub async fn subscribe_trades(
1276 &self,
1277 instrument_id: InstrumentId,
1278 aggregated: bool,
1279 ) -> Result<(), OKXWsError> {
1280 let channel = if aggregated {
1281 OKXWsChannel::TradesAll
1282 } else {
1283 OKXWsChannel::Trades
1284 };
1285
1286 let arg = OKXSubscriptionArg {
1287 channel,
1288 inst_type: None,
1289 inst_family: None,
1290 inst_id: Some(instrument_id.symbol.inner()),
1291 };
1292 self.subscribe(vec![arg]).await
1293 }
1294
1295 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1307 let arg = OKXSubscriptionArg {
1308 channel: OKXWsChannel::Tickers,
1309 inst_type: None,
1310 inst_family: None,
1311 inst_id: Some(instrument_id.symbol.inner()),
1312 };
1313 self.subscribe(vec![arg]).await
1314 }
1315
1316 pub async fn subscribe_mark_prices(
1328 &self,
1329 instrument_id: InstrumentId,
1330 ) -> Result<(), OKXWsError> {
1331 let arg = OKXSubscriptionArg {
1332 channel: OKXWsChannel::MarkPrice,
1333 inst_type: None,
1334 inst_family: None,
1335 inst_id: Some(instrument_id.symbol.inner()),
1336 };
1337 self.subscribe(vec![arg]).await
1338 }
1339
1340 pub async fn subscribe_index_prices(
1352 &self,
1353 instrument_id: InstrumentId,
1354 ) -> Result<(), OKXWsError> {
1355 let arg = OKXSubscriptionArg {
1356 channel: OKXWsChannel::IndexTickers,
1357 inst_type: None,
1358 inst_family: None,
1359 inst_id: Some(instrument_id.symbol.inner()),
1360 };
1361 self.subscribe(vec![arg]).await
1362 }
1363
1364 pub async fn subscribe_funding_rates(
1376 &self,
1377 instrument_id: InstrumentId,
1378 ) -> Result<(), OKXWsError> {
1379 let arg = OKXSubscriptionArg {
1380 channel: OKXWsChannel::FundingRate,
1381 inst_type: None,
1382 inst_family: None,
1383 inst_id: Some(instrument_id.symbol.inner()),
1384 };
1385 self.subscribe(vec![arg]).await
1386 }
1387
1388 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1400 let channel = bar_spec_as_okx_channel(bar_type.spec())
1402 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1403
1404 let arg = OKXSubscriptionArg {
1405 channel,
1406 inst_type: None,
1407 inst_family: None,
1408 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1409 };
1410 self.subscribe(vec![arg]).await
1411 }
1412
1413 pub async fn unsubscribe_instruments(
1419 &self,
1420 instrument_type: OKXInstrumentType,
1421 ) -> Result<(), OKXWsError> {
1422 let arg = OKXSubscriptionArg {
1423 channel: OKXWsChannel::Instruments,
1424 inst_type: Some(instrument_type),
1425 inst_family: None,
1426 inst_id: None,
1427 };
1428 self.unsubscribe(vec![arg]).await
1429 }
1430
1431 pub async fn unsubscribe_instrument(
1437 &self,
1438 instrument_id: InstrumentId,
1439 ) -> Result<(), OKXWsError> {
1440 let arg = OKXSubscriptionArg {
1441 channel: OKXWsChannel::Instruments,
1442 inst_type: None,
1443 inst_family: None,
1444 inst_id: Some(instrument_id.symbol.inner()),
1445 };
1446 self.unsubscribe(vec![arg]).await
1447 }
1448
1449 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1455 let arg = OKXSubscriptionArg {
1456 channel: OKXWsChannel::Books,
1457 inst_type: None,
1458 inst_family: None,
1459 inst_id: Some(instrument_id.symbol.inner()),
1460 };
1461 self.unsubscribe(vec![arg]).await
1462 }
1463
1464 pub async fn unsubscribe_book_depth5(
1470 &self,
1471 instrument_id: InstrumentId,
1472 ) -> Result<(), OKXWsError> {
1473 let arg = OKXSubscriptionArg {
1474 channel: OKXWsChannel::Books5,
1475 inst_type: None,
1476 inst_family: None,
1477 inst_id: Some(instrument_id.symbol.inner()),
1478 };
1479 self.unsubscribe(vec![arg]).await
1480 }
1481
1482 pub async fn unsubscribe_book50_l2_tbt(
1488 &self,
1489 instrument_id: InstrumentId,
1490 ) -> Result<(), OKXWsError> {
1491 let arg = OKXSubscriptionArg {
1492 channel: OKXWsChannel::Books50Tbt,
1493 inst_type: None,
1494 inst_family: None,
1495 inst_id: Some(instrument_id.symbol.inner()),
1496 };
1497 self.unsubscribe(vec![arg]).await
1498 }
1499
1500 pub async fn unsubscribe_book_l2_tbt(
1506 &self,
1507 instrument_id: InstrumentId,
1508 ) -> Result<(), OKXWsError> {
1509 let arg = OKXSubscriptionArg {
1510 channel: OKXWsChannel::BooksTbt,
1511 inst_type: None,
1512 inst_family: None,
1513 inst_id: Some(instrument_id.symbol.inner()),
1514 };
1515 self.unsubscribe(vec![arg]).await
1516 }
1517
1518 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1524 let arg = OKXSubscriptionArg {
1525 channel: OKXWsChannel::BboTbt,
1526 inst_type: None,
1527 inst_family: None,
1528 inst_id: Some(instrument_id.symbol.inner()),
1529 };
1530 self.unsubscribe(vec![arg]).await
1531 }
1532
1533 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1539 let arg = OKXSubscriptionArg {
1540 channel: OKXWsChannel::Tickers,
1541 inst_type: None,
1542 inst_family: None,
1543 inst_id: Some(instrument_id.symbol.inner()),
1544 };
1545 self.unsubscribe(vec![arg]).await
1546 }
1547
1548 pub async fn unsubscribe_mark_prices(
1554 &self,
1555 instrument_id: InstrumentId,
1556 ) -> Result<(), OKXWsError> {
1557 let arg = OKXSubscriptionArg {
1558 channel: OKXWsChannel::MarkPrice,
1559 inst_type: None,
1560 inst_family: None,
1561 inst_id: Some(instrument_id.symbol.inner()),
1562 };
1563 self.unsubscribe(vec![arg]).await
1564 }
1565
1566 pub async fn unsubscribe_index_prices(
1572 &self,
1573 instrument_id: InstrumentId,
1574 ) -> Result<(), OKXWsError> {
1575 let arg = OKXSubscriptionArg {
1576 channel: OKXWsChannel::IndexTickers,
1577 inst_type: None,
1578 inst_family: None,
1579 inst_id: Some(instrument_id.symbol.inner()),
1580 };
1581 self.unsubscribe(vec![arg]).await
1582 }
1583
1584 pub async fn unsubscribe_funding_rates(
1590 &self,
1591 instrument_id: InstrumentId,
1592 ) -> Result<(), OKXWsError> {
1593 let arg = OKXSubscriptionArg {
1594 channel: OKXWsChannel::FundingRate,
1595 inst_type: None,
1596 inst_family: None,
1597 inst_id: Some(instrument_id.symbol.inner()),
1598 };
1599 self.unsubscribe(vec![arg]).await
1600 }
1601
1602 pub async fn unsubscribe_trades(
1608 &self,
1609 instrument_id: InstrumentId,
1610 aggregated: bool,
1611 ) -> Result<(), OKXWsError> {
1612 let channel = if aggregated {
1613 OKXWsChannel::TradesAll
1614 } else {
1615 OKXWsChannel::Trades
1616 };
1617
1618 let arg = OKXSubscriptionArg {
1619 channel,
1620 inst_type: None,
1621 inst_family: None,
1622 inst_id: Some(instrument_id.symbol.inner()),
1623 };
1624 self.unsubscribe(vec![arg]).await
1625 }
1626
1627 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1633 let channel = bar_spec_as_okx_channel(bar_type.spec())
1635 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1636
1637 let arg = OKXSubscriptionArg {
1638 channel,
1639 inst_type: None,
1640 inst_family: None,
1641 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1642 };
1643 self.unsubscribe(vec![arg]).await
1644 }
1645
1646 pub async fn subscribe_orders(
1652 &self,
1653 instrument_type: OKXInstrumentType,
1654 ) -> Result<(), OKXWsError> {
1655 let arg = OKXSubscriptionArg {
1656 channel: OKXWsChannel::Orders,
1657 inst_type: Some(instrument_type),
1658 inst_family: None,
1659 inst_id: None,
1660 };
1661 self.subscribe(vec![arg]).await
1662 }
1663
1664 pub async fn unsubscribe_orders(
1670 &self,
1671 instrument_type: OKXInstrumentType,
1672 ) -> Result<(), OKXWsError> {
1673 let arg = OKXSubscriptionArg {
1674 channel: OKXWsChannel::Orders,
1675 inst_type: Some(instrument_type),
1676 inst_family: None,
1677 inst_id: None,
1678 };
1679 self.unsubscribe(vec![arg]).await
1680 }
1681
1682 pub async fn subscribe_orders_algo(
1688 &self,
1689 instrument_type: OKXInstrumentType,
1690 ) -> Result<(), OKXWsError> {
1691 let arg = OKXSubscriptionArg {
1692 channel: OKXWsChannel::OrdersAlgo,
1693 inst_type: Some(instrument_type),
1694 inst_family: None,
1695 inst_id: None,
1696 };
1697 self.subscribe(vec![arg]).await
1698 }
1699
1700 pub async fn unsubscribe_orders_algo(
1706 &self,
1707 instrument_type: OKXInstrumentType,
1708 ) -> Result<(), OKXWsError> {
1709 let arg = OKXSubscriptionArg {
1710 channel: OKXWsChannel::OrdersAlgo,
1711 inst_type: Some(instrument_type),
1712 inst_family: None,
1713 inst_id: None,
1714 };
1715 self.unsubscribe(vec![arg]).await
1716 }
1717
1718 pub async fn subscribe_fills(
1724 &self,
1725 instrument_type: OKXInstrumentType,
1726 ) -> Result<(), OKXWsError> {
1727 let arg = OKXSubscriptionArg {
1728 channel: OKXWsChannel::Fills,
1729 inst_type: Some(instrument_type),
1730 inst_family: None,
1731 inst_id: None,
1732 };
1733 self.subscribe(vec![arg]).await
1734 }
1735
1736 pub async fn unsubscribe_fills(
1742 &self,
1743 instrument_type: OKXInstrumentType,
1744 ) -> Result<(), OKXWsError> {
1745 let arg = OKXSubscriptionArg {
1746 channel: OKXWsChannel::Fills,
1747 inst_type: Some(instrument_type),
1748 inst_family: None,
1749 inst_id: None,
1750 };
1751 self.unsubscribe(vec![arg]).await
1752 }
1753
1754 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1760 let arg = OKXSubscriptionArg {
1761 channel: OKXWsChannel::Account,
1762 inst_type: None,
1763 inst_family: None,
1764 inst_id: None,
1765 };
1766 self.subscribe(vec![arg]).await
1767 }
1768
1769 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1775 let arg = OKXSubscriptionArg {
1776 channel: OKXWsChannel::Account,
1777 inst_type: None,
1778 inst_family: None,
1779 inst_id: None,
1780 };
1781 self.unsubscribe(vec![arg]).await
1782 }
1783
1784 pub async fn subscribe_positions(
1794 &self,
1795 inst_type: OKXInstrumentType,
1796 ) -> Result<(), OKXWsError> {
1797 let arg = OKXSubscriptionArg {
1798 channel: OKXWsChannel::Positions,
1799 inst_type: Some(inst_type),
1800 inst_family: None,
1801 inst_id: None,
1802 };
1803 self.subscribe(vec![arg]).await
1804 }
1805
1806 pub async fn unsubscribe_positions(
1812 &self,
1813 inst_type: OKXInstrumentType,
1814 ) -> Result<(), OKXWsError> {
1815 let arg = OKXSubscriptionArg {
1816 channel: OKXWsChannel::Positions,
1817 inst_type: Some(inst_type),
1818 inst_family: None,
1819 inst_id: None,
1820 };
1821 self.unsubscribe(vec![arg]).await
1822 }
1823
1824 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1830 let request_id = self.generate_unique_request_id();
1831 let cmd = HandlerCommand::BatchPlaceOrders { args, request_id };
1832
1833 self.send_cmd(cmd).await
1834 }
1835
1836 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1842 let request_id = self.generate_unique_request_id();
1843 let cmd = HandlerCommand::BatchCancelOrders { args, request_id };
1844
1845 self.send_cmd(cmd).await
1846 }
1847
1848 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1854 let request_id = self.generate_unique_request_id();
1855 let cmd = HandlerCommand::BatchAmendOrders { args, request_id };
1856
1857 self.send_cmd(cmd).await
1858 }
1859
1860 #[allow(clippy::too_many_arguments)]
1872 pub async fn submit_order(
1873 &self,
1874 trader_id: TraderId,
1875 strategy_id: StrategyId,
1876 instrument_id: InstrumentId,
1877 td_mode: OKXTradeMode,
1878 client_order_id: ClientOrderId,
1879 order_side: OrderSide,
1880 order_type: OrderType,
1881 quantity: Quantity,
1882 time_in_force: Option<TimeInForce>,
1883 price: Option<Price>,
1884 trigger_price: Option<Price>,
1885 post_only: Option<bool>,
1886 reduce_only: Option<bool>,
1887 quote_quantity: Option<bool>,
1888 position_side: Option<PositionSide>,
1889 ) -> Result<(), OKXWsError> {
1890 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1891 return Err(OKXWsError::ClientError(format!(
1892 "Unsupported order type: {order_type:?}",
1893 )));
1894 }
1895
1896 if let Some(tif) = time_in_force
1897 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1898 {
1899 return Err(OKXWsError::ClientError(format!(
1900 "Unsupported time in force: {tif:?}",
1901 )));
1902 }
1903
1904 let mut builder = WsPostOrderParamsBuilder::default();
1905
1906 builder.inst_id(instrument_id.symbol.as_str());
1907 builder.td_mode(td_mode);
1908 builder.cl_ord_id(client_order_id.as_str());
1909
1910 let instrument = self
1911 .instruments_cache
1912 .get(&instrument_id.symbol.inner())
1913 .ok_or_else(|| {
1914 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1915 })?;
1916
1917 let instrument_type =
1918 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1919 let quote_currency = instrument.quote_currency();
1920
1921 match instrument_type {
1922 OKXInstrumentType::Spot => {
1923 builder.ccy(quote_currency.to_string());
1925 }
1926 OKXInstrumentType::Margin => {
1927 builder.ccy(quote_currency.to_string());
1928
1929 if let Some(ro) = reduce_only
1930 && ro
1931 {
1932 builder.reduce_only(ro);
1933 }
1934 }
1935 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1936 builder.ccy(quote_currency.to_string());
1938
1939 if position_side.is_none() {
1942 builder.pos_side(OKXPositionSide::Net);
1943 }
1944 }
1945 _ => {
1946 builder.ccy(quote_currency.to_string());
1947
1948 if position_side.is_none() {
1950 builder.pos_side(OKXPositionSide::Net);
1951 }
1952
1953 if let Some(ro) = reduce_only
1954 && ro
1955 {
1956 builder.reduce_only(ro);
1957 }
1958 }
1959 };
1960
1961 if instrument_type == OKXInstrumentType::Spot
1968 && order_type == OrderType::Market
1969 && td_mode == OKXTradeMode::Cash
1970 {
1971 match quote_quantity {
1972 Some(true) => {
1973 builder.tgt_ccy(OKXTargetCurrency::QuoteCcy);
1975 }
1976 Some(false) => {
1977 if order_side == OrderSide::Buy {
1978 builder.tgt_ccy(OKXTargetCurrency::BaseCcy);
1980 }
1981 }
1983 None => {
1984 }
1986 }
1987 }
1988
1989 builder.side(order_side);
1990
1991 if let Some(pos_side) = position_side {
1992 builder.pos_side(pos_side);
1993 };
1994
1995 let (okx_ord_type, price) = if post_only.unwrap_or(false) {
1998 (OKXOrderType::PostOnly, price)
1999 } else if let Some(tif) = time_in_force {
2000 match (order_type, tif) {
2001 (OrderType::Market, TimeInForce::Fok) => {
2002 return Err(OKXWsError::ClientError(
2003 "Market orders with FOK time-in-force are not supported by OKX. Use Limit order with FOK instead.".to_string()
2004 ));
2005 }
2006 (OrderType::Market, TimeInForce::Ioc) => (OKXOrderType::OptimalLimitIoc, price),
2007 (OrderType::Limit, TimeInForce::Fok) => (OKXOrderType::Fok, price),
2008 (OrderType::Limit, TimeInForce::Ioc) => (OKXOrderType::Ioc, price),
2009 _ => (OKXOrderType::from(order_type), price),
2010 }
2011 } else {
2012 (OKXOrderType::from(order_type), price)
2013 };
2014
2015 log::debug!(
2016 "Order type mapping: order_type={order_type:?}, time_in_force={time_in_force:?}, post_only={post_only:?} -> okx_ord_type={okx_ord_type:?}"
2017 );
2018
2019 builder.ord_type(okx_ord_type);
2020 builder.sz(quantity.to_string());
2021
2022 if let Some(tp) = trigger_price {
2023 builder.px(tp.to_string());
2024 } else if let Some(p) = price {
2025 builder.px(p.to_string());
2026 }
2027
2028 builder.tag(OKX_NAUTILUS_BROKER_ID);
2029
2030 let params = builder
2031 .build()
2032 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2033
2034 self.active_client_orders
2035 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2036
2037 let cmd = HandlerCommand::PlaceOrder {
2038 params,
2039 client_order_id,
2040 trader_id,
2041 strategy_id,
2042 instrument_id,
2043 };
2044
2045 self.send_cmd(cmd).await
2046 }
2047
2048 #[allow(clippy::too_many_arguments)]
2064 pub async fn modify_order(
2065 &self,
2066 trader_id: TraderId,
2067 strategy_id: StrategyId,
2068 instrument_id: InstrumentId,
2069 client_order_id: Option<ClientOrderId>,
2070 price: Option<Price>,
2071 quantity: Option<Quantity>,
2072 venue_order_id: Option<VenueOrderId>,
2073 ) -> Result<(), OKXWsError> {
2074 let mut builder = WsAmendOrderParamsBuilder::default();
2075
2076 builder.inst_id(instrument_id.symbol.as_str());
2077
2078 if let Some(venue_order_id) = venue_order_id {
2079 builder.ord_id(venue_order_id.as_str());
2080 }
2081
2082 if let Some(client_order_id) = client_order_id {
2083 builder.cl_ord_id(client_order_id.as_str());
2084 }
2085
2086 if let Some(price) = price {
2087 builder.new_px(price.to_string());
2088 }
2089
2090 if let Some(quantity) = quantity {
2091 builder.new_sz(quantity.to_string());
2092 }
2093
2094 let params = builder
2095 .build()
2096 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2097
2098 if let Some(client_order_id) = client_order_id {
2101 let cmd = HandlerCommand::AmendOrder {
2102 params,
2103 client_order_id,
2104 trader_id,
2105 strategy_id,
2106 instrument_id,
2107 venue_order_id,
2108 };
2109
2110 self.send_cmd(cmd).await
2111 } else {
2112 Err(OKXWsError::ClientError(
2114 "Cannot amend order without client_order_id".to_string(),
2115 ))
2116 }
2117 }
2118
2119 #[allow(clippy::too_many_arguments)]
2130 pub async fn cancel_order(
2131 &self,
2132 trader_id: TraderId,
2133 strategy_id: StrategyId,
2134 instrument_id: InstrumentId,
2135 client_order_id: Option<ClientOrderId>,
2136 venue_order_id: Option<VenueOrderId>,
2137 ) -> Result<(), OKXWsError> {
2138 let cmd = HandlerCommand::CancelOrder {
2139 client_order_id,
2140 venue_order_id,
2141 instrument_id,
2142 trader_id,
2143 strategy_id,
2144 };
2145
2146 self.send_cmd(cmd).await
2147 }
2148
2149 pub async fn mass_cancel_orders(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
2159 let cmd = HandlerCommand::MassCancel { instrument_id };
2160
2161 self.send_cmd(cmd).await
2162 }
2163
2164 #[allow(clippy::type_complexity)]
2171 #[allow(clippy::too_many_arguments)]
2172 pub async fn batch_submit_orders(
2173 &self,
2174 orders: Vec<(
2175 OKXInstrumentType,
2176 InstrumentId,
2177 OKXTradeMode,
2178 ClientOrderId,
2179 OrderSide,
2180 Option<PositionSide>,
2181 OrderType,
2182 Quantity,
2183 Option<Price>,
2184 Option<Price>,
2185 Option<bool>,
2186 Option<bool>,
2187 )>,
2188 ) -> Result<(), OKXWsError> {
2189 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2190 for (
2191 inst_type,
2192 inst_id,
2193 td_mode,
2194 cl_ord_id,
2195 ord_side,
2196 pos_side,
2197 ord_type,
2198 qty,
2199 pr,
2200 tp,
2201 post_only,
2202 reduce_only,
2203 ) in orders
2204 {
2205 let mut builder = WsPostOrderParamsBuilder::default();
2206 builder.inst_type(inst_type);
2207 builder.inst_id(inst_id.symbol.inner());
2208 builder.td_mode(td_mode);
2209 builder.cl_ord_id(cl_ord_id.as_str());
2210 builder.side(ord_side);
2211
2212 if let Some(ps) = pos_side {
2213 builder.pos_side(OKXPositionSide::from(ps));
2214 }
2215
2216 let okx_ord_type = if post_only.unwrap_or(false) {
2217 OKXOrderType::PostOnly
2218 } else {
2219 OKXOrderType::from(ord_type)
2220 };
2221
2222 builder.ord_type(okx_ord_type);
2223 builder.sz(qty.to_string());
2224
2225 if let Some(p) = pr {
2226 builder.px(p.to_string());
2227 } else if let Some(p) = tp {
2228 builder.px(p.to_string());
2229 }
2230
2231 if let Some(ro) = reduce_only {
2232 builder.reduce_only(ro);
2233 }
2234
2235 builder.tag(OKX_NAUTILUS_BROKER_ID);
2236
2237 let params = builder
2238 .build()
2239 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2240 let val =
2241 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2242 args.push(val);
2243 }
2244
2245 self.ws_batch_place_orders(args).await
2246 }
2247
2248 #[allow(clippy::type_complexity)]
2255 #[allow(clippy::too_many_arguments)]
2256 pub async fn batch_modify_orders(
2257 &self,
2258 orders: Vec<(
2259 OKXInstrumentType,
2260 InstrumentId,
2261 ClientOrderId,
2262 ClientOrderId,
2263 Option<Price>,
2264 Option<Quantity>,
2265 )>,
2266 ) -> Result<(), OKXWsError> {
2267 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2268 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2269 let mut builder = WsAmendOrderParamsBuilder::default();
2270 builder.inst_id(inst_id.symbol.inner());
2272 builder.cl_ord_id(cl_ord_id.as_str());
2273 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2274
2275 if let Some(p) = pr {
2276 builder.new_px(p.to_string());
2277 }
2278
2279 if let Some(q) = sz {
2280 builder.new_sz(q.to_string());
2281 }
2282
2283 let params = builder.build().map_err(|e| {
2284 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2285 })?;
2286 let val =
2287 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2288 args.push(val);
2289 }
2290
2291 self.ws_batch_amend_orders(args).await
2292 }
2293
2294 #[allow(clippy::type_complexity)]
2307 pub async fn batch_cancel_orders(
2308 &self,
2309 orders: Vec<(InstrumentId, Option<ClientOrderId>, Option<VenueOrderId>)>,
2310 ) -> Result<(), OKXWsError> {
2311 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2312 for (inst_id, cl_ord_id, ord_id) in orders {
2313 let mut builder = WsCancelOrderParamsBuilder::default();
2314 builder.inst_id(inst_id.symbol.inner());
2316
2317 if let Some(c) = cl_ord_id {
2318 builder.cl_ord_id(c.as_str());
2319 }
2320
2321 if let Some(o) = ord_id {
2322 builder.ord_id(o.as_str());
2323 }
2324
2325 let params = builder.build().map_err(|e| {
2326 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2327 })?;
2328 let val =
2329 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2330 args.push(val);
2331 }
2332
2333 self.ws_batch_cancel_orders(args).await
2334 }
2335
2336 #[allow(clippy::too_many_arguments)]
2347 pub async fn submit_algo_order(
2348 &self,
2349 trader_id: TraderId,
2350 strategy_id: StrategyId,
2351 instrument_id: InstrumentId,
2352 td_mode: OKXTradeMode,
2353 client_order_id: ClientOrderId,
2354 order_side: OrderSide,
2355 order_type: OrderType,
2356 quantity: Quantity,
2357 trigger_price: Price,
2358 trigger_type: Option<TriggerType>,
2359 limit_price: Option<Price>,
2360 reduce_only: Option<bool>,
2361 ) -> Result<(), OKXWsError> {
2362 if !is_conditional_order(order_type) {
2363 return Err(OKXWsError::ClientError(format!(
2364 "Order type {order_type:?} is not a conditional order"
2365 )));
2366 }
2367
2368 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2369 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2370 return Err(OKXWsError::ClientError(
2371 "Invalid order side for OKX".to_string(),
2372 ));
2373 }
2374
2375 builder.inst_id(instrument_id.symbol.inner());
2376 builder.td_mode(td_mode);
2377 builder.cl_ord_id(client_order_id.as_str());
2378 builder.side(order_side);
2379 builder.ord_type(
2380 conditional_order_to_algo_type(order_type)
2381 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2382 );
2383 builder.sz(quantity.to_string());
2384 builder.trigger_px(trigger_price.to_string());
2385
2386 let okx_trigger_type = trigger_type.map_or(OKXTriggerType::Last, Into::into);
2388 builder.trigger_px_type(okx_trigger_type);
2389
2390 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2392 && let Some(price) = limit_price
2393 {
2394 builder.order_px(price.to_string());
2395 }
2396
2397 if let Some(reduce) = reduce_only {
2398 builder.reduce_only(reduce);
2399 }
2400
2401 builder.tag(OKX_NAUTILUS_BROKER_ID);
2402
2403 let params = builder
2404 .build()
2405 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2406
2407 self.active_client_orders
2408 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2409
2410 let cmd = HandlerCommand::PlaceAlgoOrder {
2411 params,
2412 client_order_id,
2413 trader_id,
2414 strategy_id,
2415 instrument_id,
2416 };
2417
2418 self.send_cmd(cmd).await
2419 }
2420
2421 pub async fn cancel_algo_order(
2432 &self,
2433 trader_id: TraderId,
2434 strategy_id: StrategyId,
2435 instrument_id: InstrumentId,
2436 client_order_id: Option<ClientOrderId>,
2437 algo_order_id: Option<String>,
2438 ) -> Result<(), OKXWsError> {
2439 let cmd = HandlerCommand::CancelAlgoOrder {
2440 client_order_id,
2441 algo_order_id: algo_order_id.map(|id| VenueOrderId::from(id.as_str())),
2442 instrument_id,
2443 trader_id,
2444 strategy_id,
2445 };
2446
2447 self.send_cmd(cmd).await
2448 }
2449
2450 async fn send_cmd(&self, cmd: HandlerCommand) -> Result<(), OKXWsError> {
2452 self.cmd_tx
2453 .read()
2454 .await
2455 .send(cmd)
2456 .map_err(|e| OKXWsError::ClientError(format!("Handler not available: {e}")))
2457 }
2458}
2459
2460#[cfg(test)]
2461mod tests {
2462 use nautilus_core::time::get_atomic_clock_realtime;
2463 use nautilus_network::RECONNECTED;
2464 use rstest::rstest;
2465 use tokio_tungstenite::tungstenite::Message;
2466
2467 use super::*;
2468 use crate::{
2469 common::{
2470 consts::OKX_POST_ONLY_CANCEL_SOURCE,
2471 enums::{OKXExecType, OKXOrderCategory, OKXOrderStatus, OKXSide},
2472 },
2473 websocket::{
2474 handler::OKXWsFeedHandler,
2475 messages::{OKXOrderMsg, OKXWebSocketError, OKXWsMessage},
2476 },
2477 };
2478
2479 #[rstest]
2480 fn test_timestamp_format_for_websocket_auth() {
2481 let timestamp = SystemTime::now()
2482 .duration_since(SystemTime::UNIX_EPOCH)
2483 .expect("System time should be after UNIX epoch")
2484 .as_secs()
2485 .to_string();
2486
2487 assert!(timestamp.parse::<u64>().is_ok());
2488 assert_eq!(timestamp.len(), 10);
2489 assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
2490 }
2491
2492 #[rstest]
2493 fn test_new_without_credentials() {
2494 let client = OKXWebSocketClient::default();
2495 assert!(client.credential.is_none());
2496 assert_eq!(client.api_key(), None);
2497 }
2498
2499 #[rstest]
2500 fn test_new_with_credentials() {
2501 let client = OKXWebSocketClient::new(
2502 None,
2503 Some("test_key".to_string()),
2504 Some("test_secret".to_string()),
2505 Some("test_passphrase".to_string()),
2506 None,
2507 None,
2508 )
2509 .unwrap();
2510 assert!(client.credential.is_some());
2511 assert_eq!(client.api_key(), Some("test_key"));
2512 }
2513
2514 #[rstest]
2515 fn test_new_partial_credentials_fails() {
2516 let result = OKXWebSocketClient::new(
2517 None,
2518 Some("test_key".to_string()),
2519 None,
2520 Some("test_passphrase".to_string()),
2521 None,
2522 None,
2523 );
2524 assert!(result.is_err());
2525 }
2526
2527 #[rstest]
2528 fn test_request_id_generation() {
2529 let client = OKXWebSocketClient::default();
2530
2531 let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
2532
2533 let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2534 let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2535
2536 assert_eq!(id1, initial_counter);
2537 assert_eq!(id2, initial_counter + 1);
2538 assert_eq!(
2539 client.request_id_counter.load(Ordering::SeqCst),
2540 initial_counter + 2
2541 );
2542 }
2543
2544 #[rstest]
2545 fn test_client_state_management() {
2546 let client = OKXWebSocketClient::default();
2547
2548 assert!(client.is_closed());
2549 assert!(!client.is_active());
2550
2551 let client_with_heartbeat =
2552 OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();
2553
2554 assert!(client_with_heartbeat.heartbeat.is_some());
2555 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2556 }
2557
2558 #[rstest]
2563 fn test_websocket_error_handling() {
2564 let clock = get_atomic_clock_realtime();
2565 let ts = clock.get_time_ns().as_u64();
2566
2567 let error = OKXWebSocketError {
2568 code: "60012".to_string(),
2569 message: "Invalid request".to_string(),
2570 conn_id: None,
2571 timestamp: ts,
2572 };
2573
2574 assert_eq!(error.code, "60012");
2575 assert_eq!(error.message, "Invalid request");
2576 assert_eq!(error.timestamp, ts);
2577
2578 let nautilus_msg = NautilusWsMessage::Error(error);
2579 match nautilus_msg {
2580 NautilusWsMessage::Error(e) => {
2581 assert_eq!(e.code, "60012");
2582 assert_eq!(e.message, "Invalid request");
2583 }
2584 _ => panic!("Expected Error variant"),
2585 }
2586 }
2587
2588 #[rstest]
2589 fn test_request_id_generation_sequence() {
2590 let client = OKXWebSocketClient::default();
2591
2592 let initial_counter = client
2593 .request_id_counter
2594 .load(std::sync::atomic::Ordering::SeqCst);
2595 let mut ids = Vec::new();
2596 for _ in 0..10 {
2597 let id = client
2598 .request_id_counter
2599 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2600 ids.push(id);
2601 }
2602
2603 for (i, &id) in ids.iter().enumerate() {
2604 assert_eq!(id, initial_counter + i as u64);
2605 }
2606
2607 assert_eq!(
2608 client
2609 .request_id_counter
2610 .load(std::sync::atomic::Ordering::SeqCst),
2611 initial_counter + 10
2612 );
2613 }
2614
2615 #[rstest]
2616 fn test_client_state_transitions() {
2617 let client = OKXWebSocketClient::default();
2618
2619 assert!(client.is_closed());
2620 assert!(!client.is_active());
2621
2622 let client_with_heartbeat = OKXWebSocketClient::new(
2623 None,
2624 None,
2625 None,
2626 None,
2627 None,
2628 Some(30), )
2630 .unwrap();
2631
2632 assert!(client_with_heartbeat.heartbeat.is_some());
2633 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2634
2635 let account_id = AccountId::from("test-account-123");
2636 let client_with_account =
2637 OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();
2638
2639 assert_eq!(client_with_account.account_id, account_id);
2640 }
2641
2642 #[rstest]
2643 fn test_websocket_error_scenarios() {
2644 let clock = get_atomic_clock_realtime();
2645 let ts = clock.get_time_ns().as_u64();
2646
2647 let error_scenarios = vec![
2648 ("60012", "Invalid request", None),
2649 ("60009", "Invalid API key", Some("conn-123".to_string())),
2650 ("60014", "Too many requests", None),
2651 ("50001", "Order not found", None),
2652 ];
2653
2654 for (code, message, conn_id) in error_scenarios {
2655 let error = OKXWebSocketError {
2656 code: code.to_string(),
2657 message: message.to_string(),
2658 conn_id: conn_id.clone(),
2659 timestamp: ts,
2660 };
2661
2662 assert_eq!(error.code, code);
2663 assert_eq!(error.message, message);
2664 assert_eq!(error.conn_id, conn_id);
2665 assert_eq!(error.timestamp, ts);
2666
2667 let nautilus_msg = NautilusWsMessage::Error(error);
2668 match nautilus_msg {
2669 NautilusWsMessage::Error(e) => {
2670 assert_eq!(e.code, code);
2671 assert_eq!(e.message, message);
2672 assert_eq!(e.conn_id, conn_id);
2673 }
2674 _ => panic!("Expected Error variant"),
2675 }
2676 }
2677 }
2678
2679 #[rstest]
2680 fn test_feed_handler_reconnection_detection() {
2681 let msg = Message::Text(RECONNECTED.to_string().into());
2682 let result = OKXWsFeedHandler::parse_raw_message(msg);
2683 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2684 }
2685
2686 #[rstest]
2687 fn test_feed_handler_normal_message_processing() {
2688 let ping_msg = Message::Text(TEXT_PING.to_string().into());
2690 let result = OKXWsFeedHandler::parse_raw_message(ping_msg);
2691 assert!(matches!(result, Some(OKXWsMessage::Ping)));
2692
2693 let sub_msg = r#"{
2695 "event": "subscribe",
2696 "arg": {
2697 "channel": "tickers",
2698 "instType": "SPOT"
2699 },
2700 "connId": "a4d3ae55"
2701 }"#;
2702
2703 let sub_result =
2704 OKXWsFeedHandler::parse_raw_message(Message::Text(sub_msg.to_string().into()));
2705 assert!(matches!(
2706 sub_result,
2707 Some(OKXWsMessage::Subscription { .. })
2708 ));
2709 }
2710
2711 #[rstest]
2712 fn test_feed_handler_close_message() {
2713 let result = OKXWsFeedHandler::parse_raw_message(Message::Close(None));
2715 assert!(result.is_none());
2716 }
2717
2718 #[rstest]
2719 fn test_reconnection_message_constant() {
2720 assert_eq!(RECONNECTED, "__RECONNECTED__");
2721 }
2722
2723 #[rstest]
2724 fn test_multiple_reconnection_signals() {
2725 for _ in 0..3 {
2727 let msg = Message::Text(RECONNECTED.to_string().into());
2728 let result = OKXWsFeedHandler::parse_raw_message(msg);
2729 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2730 }
2731 }
2732
2733 #[tokio::test]
2734 async fn test_wait_until_active_timeout() {
2735 let client = OKXWebSocketClient::new(
2736 None,
2737 Some("test_key".to_string()),
2738 Some("test_secret".to_string()),
2739 Some("test_passphrase".to_string()),
2740 Some(AccountId::from("test-account")),
2741 None,
2742 )
2743 .unwrap();
2744
2745 let result = client.wait_until_active(0.1).await;
2747
2748 assert!(result.is_err());
2749 assert!(!client.is_active());
2750 }
2751
2752 fn sample_canceled_order_msg() -> OKXOrderMsg {
2753 OKXOrderMsg {
2754 acc_fill_sz: Some("0".to_string()),
2755 avg_px: "0".to_string(),
2756 c_time: 0,
2757 cancel_source: None,
2758 cancel_source_reason: None,
2759 category: OKXOrderCategory::Normal,
2760 ccy: ustr::Ustr::from("USDT"),
2761 cl_ord_id: "order-1".to_string(),
2762 algo_cl_ord_id: None,
2763 fee: None,
2764 fee_ccy: ustr::Ustr::from("USDT"),
2765 fill_px: "0".to_string(),
2766 fill_sz: "0".to_string(),
2767 fill_time: 0,
2768 inst_id: ustr::Ustr::from("ETH-USDT-SWAP"),
2769 inst_type: OKXInstrumentType::Swap,
2770 lever: "1".to_string(),
2771 ord_id: ustr::Ustr::from("123456"),
2772 ord_type: OKXOrderType::Limit,
2773 pnl: "0".to_string(),
2774 pos_side: OKXPositionSide::Net,
2775 px: "0".to_string(),
2776 reduce_only: "false".to_string(),
2777 side: OKXSide::Buy,
2778 state: OKXOrderStatus::Canceled,
2779 exec_type: OKXExecType::None,
2780 sz: "1".to_string(),
2781 td_mode: OKXTradeMode::Cross,
2782 tgt_ccy: None,
2783 trade_id: String::new(),
2784 u_time: 0,
2785 }
2786 }
2787
2788 #[rstest]
2789 fn test_is_post_only_auto_cancel_detects_cancel_source() {
2790 let mut msg = sample_canceled_order_msg();
2791 msg.cancel_source = Some(OKX_POST_ONLY_CANCEL_SOURCE.to_string());
2792
2793 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2794 }
2795
2796 #[rstest]
2797 fn test_is_post_only_auto_cancel_detects_reason() {
2798 let mut msg = sample_canceled_order_msg();
2799 msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());
2800
2801 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2802 }
2803
2804 #[rstest]
2805 fn test_is_post_only_auto_cancel_false_without_markers() {
2806 let msg = sample_canceled_order_msg();
2807
2808 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2809 }
2810
2811 #[rstest]
2812 fn test_is_post_only_auto_cancel_false_for_order_type_only() {
2813 let mut msg = sample_canceled_order_msg();
2814 msg.ord_type = OKXOrderType::PostOnly;
2815
2816 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2817 }
2818
2819 #[tokio::test]
2820 async fn test_batch_cancel_orders_with_multiple_orders() {
2821 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2822
2823 let client = OKXWebSocketClient::new(
2824 Some("wss://test.okx.com".to_string()),
2825 None,
2826 None,
2827 None,
2828 None,
2829 None,
2830 )
2831 .expect("Failed to create client");
2832
2833 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2834 let client_order_id1 = ClientOrderId::new("order1");
2835 let client_order_id2 = ClientOrderId::new("order2");
2836 let venue_order_id1 = VenueOrderId::new("venue1");
2837 let venue_order_id2 = VenueOrderId::new("venue2");
2838
2839 let orders = vec![
2840 (instrument_id, Some(client_order_id1), Some(venue_order_id1)),
2841 (instrument_id, Some(client_order_id2), Some(venue_order_id2)),
2842 ];
2843
2844 let result = client.batch_cancel_orders(orders).await;
2846
2847 assert!(result.is_err());
2849 }
2850
2851 #[tokio::test]
2852 async fn test_batch_cancel_orders_with_only_client_order_id() {
2853 use nautilus_model::identifiers::{ClientOrderId, InstrumentId};
2854
2855 let client = OKXWebSocketClient::new(
2856 Some("wss://test.okx.com".to_string()),
2857 None,
2858 None,
2859 None,
2860 None,
2861 None,
2862 )
2863 .expect("Failed to create client");
2864
2865 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2866 let client_order_id = ClientOrderId::new("order1");
2867
2868 let orders = vec![(instrument_id, Some(client_order_id), None)];
2869
2870 let result = client.batch_cancel_orders(orders).await;
2871
2872 assert!(result.is_err());
2874 }
2875
2876 #[tokio::test]
2877 async fn test_batch_cancel_orders_with_only_venue_order_id() {
2878 use nautilus_model::identifiers::{InstrumentId, VenueOrderId};
2879
2880 let client = OKXWebSocketClient::new(
2881 Some("wss://test.okx.com".to_string()),
2882 None,
2883 None,
2884 None,
2885 None,
2886 None,
2887 )
2888 .expect("Failed to create client");
2889
2890 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2891 let venue_order_id = VenueOrderId::new("venue1");
2892
2893 let orders = vec![(instrument_id, None, Some(venue_order_id))];
2894
2895 let result = client.batch_cancel_orders(orders).await;
2896
2897 assert!(result.is_err());
2899 }
2900
2901 #[tokio::test]
2902 async fn test_batch_cancel_orders_with_both_ids() {
2903 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2904
2905 let client = OKXWebSocketClient::new(
2906 Some("wss://test.okx.com".to_string()),
2907 None,
2908 None,
2909 None,
2910 None,
2911 None,
2912 )
2913 .expect("Failed to create client");
2914
2915 let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");
2916 let client_order_id = ClientOrderId::new("order1");
2917 let venue_order_id = VenueOrderId::new("venue1");
2918
2919 let orders = vec![(instrument_id, Some(client_order_id), Some(venue_order_id))];
2920
2921 let result = client.batch_cancel_orders(orders).await;
2922
2923 assert!(result.is_err());
2925 }
2926
2927 #[rstest]
2928 fn test_race_unsubscribe_failure_recovery() {
2929 let client = OKXWebSocketClient::new(
2935 Some("wss://test.okx.com".to_string()),
2936 None,
2937 None,
2938 None,
2939 None,
2940 None,
2941 )
2942 .expect("Failed to create client");
2943
2944 let topic = "trades:BTC-USDT-SWAP";
2945
2946 client.subscriptions_state.mark_subscribe(topic);
2948 client.subscriptions_state.confirm_subscribe(topic);
2949 assert_eq!(client.subscriptions_state.len(), 1);
2950
2951 client.subscriptions_state.mark_unsubscribe(topic);
2953 assert_eq!(client.subscriptions_state.len(), 0);
2954 assert_eq!(
2955 client.subscriptions_state.pending_unsubscribe_topics(),
2956 vec![topic]
2957 );
2958
2959 client.subscriptions_state.confirm_unsubscribe(topic); client.subscriptions_state.mark_subscribe(topic); client.subscriptions_state.confirm_subscribe(topic); assert_eq!(client.subscriptions_state.len(), 1);
2967 assert!(
2968 client
2969 .subscriptions_state
2970 .pending_unsubscribe_topics()
2971 .is_empty()
2972 );
2973 assert!(
2974 client
2975 .subscriptions_state
2976 .pending_subscribe_topics()
2977 .is_empty()
2978 );
2979
2980 let all = client.subscriptions_state.all_topics();
2982 assert_eq!(all.len(), 1);
2983 assert!(all.contains(&topic.to_string()));
2984 }
2985
2986 #[rstest]
2987 fn test_race_resubscribe_before_unsubscribe_ack() {
2988 let client = OKXWebSocketClient::new(
2992 Some("wss://test.okx.com".to_string()),
2993 None,
2994 None,
2995 None,
2996 None,
2997 None,
2998 )
2999 .expect("Failed to create client");
3000
3001 let topic = "books:BTC-USDT";
3002
3003 client.subscriptions_state.mark_subscribe(topic);
3005 client.subscriptions_state.confirm_subscribe(topic);
3006 assert_eq!(client.subscriptions_state.len(), 1);
3007
3008 client.subscriptions_state.mark_unsubscribe(topic);
3010 assert_eq!(client.subscriptions_state.len(), 0);
3011 assert_eq!(
3012 client.subscriptions_state.pending_unsubscribe_topics(),
3013 vec![topic]
3014 );
3015
3016 client.subscriptions_state.mark_subscribe(topic);
3018 assert_eq!(
3019 client.subscriptions_state.pending_subscribe_topics(),
3020 vec![topic]
3021 );
3022
3023 client.subscriptions_state.confirm_unsubscribe(topic);
3025 assert!(
3026 client
3027 .subscriptions_state
3028 .pending_unsubscribe_topics()
3029 .is_empty()
3030 );
3031 assert_eq!(
3032 client.subscriptions_state.pending_subscribe_topics(),
3033 vec![topic]
3034 );
3035
3036 client.subscriptions_state.confirm_subscribe(topic);
3038 assert_eq!(client.subscriptions_state.len(), 1);
3039 assert!(
3040 client
3041 .subscriptions_state
3042 .pending_subscribe_topics()
3043 .is_empty()
3044 );
3045
3046 let all = client.subscriptions_state.all_topics();
3048 assert_eq!(all.len(), 1);
3049 assert!(all.contains(&topic.to_string()));
3050 }
3051
3052 #[rstest]
3053 fn test_race_late_subscribe_confirmation_after_unsubscribe() {
3054 let client = OKXWebSocketClient::new(
3057 Some("wss://test.okx.com".to_string()),
3058 None,
3059 None,
3060 None,
3061 None,
3062 None,
3063 )
3064 .expect("Failed to create client");
3065
3066 let topic = "tickers:ETH-USDT";
3067
3068 client.subscriptions_state.mark_subscribe(topic);
3070 assert_eq!(
3071 client.subscriptions_state.pending_subscribe_topics(),
3072 vec![topic]
3073 );
3074
3075 client.subscriptions_state.mark_unsubscribe(topic);
3077 assert!(
3078 client
3079 .subscriptions_state
3080 .pending_subscribe_topics()
3081 .is_empty()
3082 ); assert_eq!(
3084 client.subscriptions_state.pending_unsubscribe_topics(),
3085 vec![topic]
3086 );
3087
3088 client.subscriptions_state.confirm_subscribe(topic);
3090 assert_eq!(client.subscriptions_state.len(), 0); assert_eq!(
3092 client.subscriptions_state.pending_unsubscribe_topics(),
3093 vec![topic]
3094 );
3095
3096 client.subscriptions_state.confirm_unsubscribe(topic);
3098
3099 assert!(client.subscriptions_state.is_empty());
3101 assert!(client.subscriptions_state.all_topics().is_empty());
3102 }
3103
3104 #[rstest]
3105 fn test_race_reconnection_with_pending_states() {
3106 let client = OKXWebSocketClient::new(
3108 Some("wss://test.okx.com".to_string()),
3109 Some("test_key".to_string()),
3110 Some("test_secret".to_string()),
3111 Some("test_passphrase".to_string()),
3112 Some(AccountId::new("OKX-TEST")),
3113 None,
3114 )
3115 .expect("Failed to create client");
3116
3117 let trade_btc = "trades:BTC-USDT-SWAP";
3120 client.subscriptions_state.mark_subscribe(trade_btc);
3121 client.subscriptions_state.confirm_subscribe(trade_btc);
3122
3123 let trade_eth = "trades:ETH-USDT-SWAP";
3125 client.subscriptions_state.mark_subscribe(trade_eth);
3126
3127 let book_btc = "books:BTC-USDT";
3129 client.subscriptions_state.mark_subscribe(book_btc);
3130 client.subscriptions_state.confirm_subscribe(book_btc);
3131 client.subscriptions_state.mark_unsubscribe(book_btc);
3132
3133 let topics_to_restore = client.subscriptions_state.all_topics();
3135
3136 assert_eq!(topics_to_restore.len(), 2);
3138 assert!(topics_to_restore.contains(&trade_btc.to_string()));
3139 assert!(topics_to_restore.contains(&trade_eth.to_string()));
3140 assert!(!topics_to_restore.contains(&book_btc.to_string())); }
3142
3143 #[rstest]
3144 fn test_race_duplicate_subscribe_messages_idempotent() {
3145 let client = OKXWebSocketClient::new(
3148 Some("wss://test.okx.com".to_string()),
3149 None,
3150 None,
3151 None,
3152 None,
3153 None,
3154 )
3155 .expect("Failed to create client");
3156
3157 let topic = "trades:BTC-USDT-SWAP";
3158
3159 client.subscriptions_state.mark_subscribe(topic);
3161 client.subscriptions_state.confirm_subscribe(topic);
3162 assert_eq!(client.subscriptions_state.len(), 1);
3163
3164 client.subscriptions_state.mark_subscribe(topic);
3166 assert!(
3167 client
3168 .subscriptions_state
3169 .pending_subscribe_topics()
3170 .is_empty()
3171 ); assert_eq!(client.subscriptions_state.len(), 1); client.subscriptions_state.confirm_subscribe(topic);
3176 assert_eq!(client.subscriptions_state.len(), 1);
3177
3178 let all = client.subscriptions_state.all_topics();
3180 assert_eq!(all.len(), 1);
3181 assert_eq!(all[0], topic);
3182 }
3183}