1use std::{
26 fmt::Debug,
27 num::NonZeroU32,
28 sync::{
29 Arc, LazyLock,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32 time::{Duration, SystemTime},
33};
34
35use ahash::AHashSet;
36use arc_swap::ArcSwap;
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::live::get_runtime;
40use nautilus_core::{
41 consts::NAUTILUS_USER_AGENT,
42 env::{get_env_var, get_or_env_var},
43};
44use nautilus_model::{
45 data::BarType,
46 enums::{OrderSide, OrderType, PositionSide, TimeInForce, TriggerType},
47 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
48 instruments::{Instrument, InstrumentAny},
49 types::{Price, Quantity},
50};
51use nautilus_network::{
52 http::USER_AGENT,
53 mode::ConnectionMode,
54 ratelimiter::quota::Quota,
55 websocket::{
56 AUTHENTICATION_TIMEOUT_SECS, AuthTracker, PingHandler, SubscriptionState, TEXT_PING,
57 WebSocketClient, WebSocketConfig, channel_message_handler,
58 },
59};
60use serde_json::Value;
61use tokio_tungstenite::tungstenite::Error;
62use tokio_util::sync::CancellationToken;
63use ustr::Ustr;
64
65use super::{
66 enums::OKXWsChannel,
67 error::OKXWsError,
68 handler::{HandlerCommand, OKXWsFeedHandler},
69 messages::{
70 NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg, OKXSubscriptionArg,
71 WsAmendOrderParamsBuilder, WsCancelOrderParamsBuilder, WsPostAlgoOrderParamsBuilder,
72 WsPostOrderParamsBuilder,
73 },
74 subscription::topic_from_subscription_arg,
75};
76use crate::common::{
77 consts::{
78 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
79 OKX_WS_PUBLIC_URL, OKX_WS_TOPIC_DELIMITER,
80 },
81 credential::Credential,
82 enums::{
83 OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXTargetCurrency, OKXTradeMode,
84 OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type, is_conditional_order,
85 },
86 parse::{bar_spec_as_okx_channel, okx_instrument_type, okx_instrument_type_from_symbol},
87};
88
/// Quota of 3 per second, passed to `WebSocketClient::connect` as the
/// connection-level quota (see `OKXWebSocketClient::connect`).
pub static OKX_WS_CONNECTION_QUOTA: LazyLock<Quota> =
    LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));

/// Quota of 480 per hour, registered under [`OKX_RATE_LIMIT_KEY_SUBSCRIPTION`].
pub static OKX_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
    LazyLock::new(|| Quota::per_hour(NonZeroU32::new(480).unwrap()));

/// Quota of 250 per second, registered under the order, cancel and amend keys.
pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
    LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
108
/// Rate limit key paired with [`OKX_WS_SUBSCRIPTION_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_SUBSCRIPTION: &str = "subscription";

/// Rate limit key paired with [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_ORDER: &str = "order";

/// Rate limit key for cancels; shares [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_CANCEL: &str = "cancel";

/// Rate limit key for amends; shares [`OKX_WS_ORDER_QUOTA`] when connecting.
pub const OKX_RATE_LIMIT_KEY_AMEND: &str = "amend";
132
/// WebSocket client for OKX.
///
/// The client is `Clone`; most state lives behind `Arc` so clones share it.
/// Messages are processed by a spawned task running an `OKXWsFeedHandler`,
/// which this client talks to through the `cmd_tx` command channel.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.okx")
)]
pub struct OKXWebSocketClient {
    /// WebSocket endpoint used when connecting.
    url: String,
    /// Account identifier handed to the feed handler.
    account_id: AccountId,
    /// Current VIP level stored as its `u8` representation.
    vip_level: Arc<AtomicU8>,
    /// API credentials; `None` means no authentication is attempted.
    credential: Option<Credential>,
    /// Heartbeat setting forwarded to `WebSocketConfig::heartbeat`
    /// (presumably an interval in seconds — TODO confirm).
    heartbeat: Option<u64>,
    /// Tracks the outcome of in-flight authentication attempts.
    auth_tracker: AuthTracker,
    /// Stop flag: set to `true` by `close` and shared with the handler task.
    signal: Arc<AtomicBool>,
    /// Connection mode; swapped for the underlying client's atomic on connect.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Command channel to the handler; replaced with a live sender on connect.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver of handler output; taken (once) by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Handle of the spawned handler task; awaited or aborted in `close`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Subscriptions keyed by channel and tracked per instrument type.
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    /// Subscriptions keyed by channel and tracked per instrument family.
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Subscriptions keyed by channel and tracked per instrument ID (symbol).
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Channels subscribed without any instrument type, family or ID.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    /// Subscribe/unsubscribe state per topic, shared with the handler.
    subscriptions_state: SubscriptionState,
    /// Monotonic counter backing generated request IDs (starts at 1).
    request_id_counter: Arc<AtomicU64>,
    /// Client order IDs mapped to their trader, strategy and instrument.
    active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
    /// Client order ID aliases, shared with the handler.
    client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
    /// Instruments keyed by symbol; replayed to the handler on connect.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    /// Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
162
163impl Default for OKXWebSocketClient {
164 fn default() -> Self {
165 Self::new(None, None, None, None, None, None).unwrap()
166 }
167}
168
169impl Debug for OKXWebSocketClient {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct(stringify!(OKXWebSocketClient))
172 .field("url", &self.url)
173 .field(
174 "credential",
175 &self.credential.as_ref().map(|_| "<redacted>"),
176 )
177 .field("heartbeat", &self.heartbeat)
178 .finish_non_exhaustive()
179 }
180}
181
182impl OKXWebSocketClient {
183 pub fn new(
189 url: Option<String>,
190 api_key: Option<String>,
191 api_secret: Option<String>,
192 api_passphrase: Option<String>,
193 account_id: Option<AccountId>,
194 heartbeat: Option<u64>,
195 ) -> anyhow::Result<Self> {
196 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
197 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
198
199 let credential = match (api_key, api_secret, api_passphrase) {
200 (Some(key), Some(secret), Some(passphrase)) => {
201 Some(Credential::new(key, secret, passphrase))
202 }
203 (None, None, None) => None,
204 _ => anyhow::bail!(
205 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
206 ),
207 };
208
209 let signal = Arc::new(AtomicBool::new(false));
210 let subscriptions_inst_type = Arc::new(DashMap::new());
211 let subscriptions_inst_family = Arc::new(DashMap::new());
212 let subscriptions_inst_id = Arc::new(DashMap::new());
213 let subscriptions_bare = Arc::new(DashMap::new());
214 let subscriptions_state = SubscriptionState::new(OKX_WS_TOPIC_DELIMITER);
215
216 Ok(Self {
217 url,
218 account_id,
219 vip_level: Arc::new(AtomicU8::new(0)), credential,
221 heartbeat,
222 auth_tracker: AuthTracker::new(),
223 signal,
224 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
225 ConnectionMode::Closed.as_u8(),
226 ))),
227 cmd_tx: {
228 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
230 Arc::new(tokio::sync::RwLock::new(tx))
231 },
232 out_rx: None,
233 task_handle: None,
234 subscriptions_inst_type,
235 subscriptions_inst_family,
236 subscriptions_inst_id,
237 subscriptions_bare,
238 subscriptions_state,
239 request_id_counter: Arc::new(AtomicU64::new(1)),
240 active_client_orders: Arc::new(DashMap::new()),
241 client_id_aliases: Arc::new(DashMap::new()),
242 instruments_cache: Arc::new(DashMap::new()),
243 cancellation_token: CancellationToken::new(),
244 })
245 }
246
247 pub fn with_credentials(
254 url: Option<String>,
255 api_key: Option<String>,
256 api_secret: Option<String>,
257 api_passphrase: Option<String>,
258 account_id: Option<AccountId>,
259 heartbeat: Option<u64>,
260 ) -> anyhow::Result<Self> {
261 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
262 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
263 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
264 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
265
266 Self::new(
267 Some(url),
268 Some(api_key),
269 Some(api_secret),
270 Some(api_passphrase),
271 account_id,
272 heartbeat,
273 )
274 }
275
276 pub fn from_env() -> anyhow::Result<Self> {
283 let url = get_env_var("OKX_WS_URL")?;
284 let api_key = get_env_var("OKX_API_KEY")?;
285 let api_secret = get_env_var("OKX_API_SECRET")?;
286 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
287
288 Self::new(
289 Some(url),
290 Some(api_key),
291 Some(api_secret),
292 Some(api_passphrase),
293 None,
294 None,
295 )
296 }
297
298 pub fn cancel_all_requests(&self) {
300 self.cancellation_token.cancel();
301 }
302
303 pub fn cancellation_token(&self) -> &CancellationToken {
305 &self.cancellation_token
306 }
307
308 pub fn url(&self) -> &str {
310 self.url.as_str()
311 }
312
313 pub fn api_key(&self) -> Option<&str> {
315 self.credential.clone().map(|c| c.api_key.as_str())
316 }
317
318 #[must_use]
320 pub fn api_key_masked(&self) -> Option<String> {
321 self.credential.clone().map(|c| c.api_key_masked())
322 }
323
324 pub fn is_active(&self) -> bool {
326 let connection_mode_arc = self.connection_mode.load();
327 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
328 && !self.signal.load(Ordering::Acquire)
329 }
330
331 pub fn is_closed(&self) -> bool {
333 let connection_mode_arc = self.connection_mode.load();
334 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
335 || self.signal.load(Ordering::Acquire)
336 }
337
338 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
342 for inst in &instruments {
343 self.instruments_cache
344 .insert(inst.symbol().inner(), inst.clone());
345 }
346
347 if !instruments.is_empty()
350 && let Ok(cmd_tx) = self.cmd_tx.try_read()
351 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
352 {
353 log::debug!("Failed to send bulk instrument update to handler: {e}");
354 }
355 }
356
357 pub fn cache_instrument(&self, instrument: InstrumentAny) {
361 self.instruments_cache
362 .insert(instrument.symbol().inner(), instrument.clone());
363
364 if let Ok(cmd_tx) = self.cmd_tx.try_read()
367 && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
368 {
369 log::debug!("Failed to send instrument update to handler: {e}");
370 }
371 }
372
373 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
377 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
378 }
379
380 pub fn vip_level(&self) -> OKXVipLevel {
382 let level = self.vip_level.load(Ordering::Relaxed);
383 OKXVipLevel::from(level)
384 }
385
386 pub async fn connect(&mut self) -> anyhow::Result<()> {
396 let (message_handler, raw_rx) = channel_message_handler();
397
398 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
401 });
403
404 let config = WebSocketConfig {
405 url: self.url.clone(),
406 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
407 heartbeat: self.heartbeat,
408 heartbeat_msg: Some(TEXT_PING.to_string()),
409 reconnect_timeout_ms: Some(5_000),
410 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, reconnect_max_attempts: None,
415 };
416
417 let keyed_quotas = vec![
419 (
420 OKX_RATE_LIMIT_KEY_SUBSCRIPTION.to_string(),
421 *OKX_WS_SUBSCRIPTION_QUOTA,
422 ),
423 (OKX_RATE_LIMIT_KEY_ORDER.to_string(), *OKX_WS_ORDER_QUOTA),
424 (OKX_RATE_LIMIT_KEY_CANCEL.to_string(), *OKX_WS_ORDER_QUOTA),
425 (OKX_RATE_LIMIT_KEY_AMEND.to_string(), *OKX_WS_ORDER_QUOTA),
426 ];
427
428 let client = WebSocketClient::connect(
429 config,
430 Some(message_handler),
431 Some(ping_handler),
432 None, keyed_quotas,
434 Some(*OKX_WS_CONNECTION_QUOTA), )
436 .await?;
437
438 self.connection_mode.store(client.connection_mode_atomic());
440
441 let account_id = self.account_id;
442 let (msg_tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
443
444 self.out_rx = Some(Arc::new(rx));
445
446 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
448 *self.cmd_tx.write().await = cmd_tx.clone();
449
450 if !self.instruments_cache.is_empty() {
452 let cached_instruments: Vec<InstrumentAny> = self
453 .instruments_cache
454 .iter()
455 .map(|entry| entry.value().clone())
456 .collect();
457 if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(cached_instruments)) {
458 log::error!("Failed to replay instruments to handler: {e}");
459 }
460 }
461
462 let signal = self.signal.clone();
463 let active_client_orders = self.active_client_orders.clone();
464 let auth_tracker = self.auth_tracker.clone();
465 let subscriptions_state = self.subscriptions_state.clone();
466 let client_id_aliases = self.client_id_aliases.clone();
467
468 let stream_handle = get_runtime().spawn({
469 let auth_tracker = auth_tracker.clone();
470 let signal = signal.clone();
471 let credential = self.credential.clone();
472 let cmd_tx_for_reconnect = cmd_tx.clone();
473 let subscriptions_bare = self.subscriptions_bare.clone();
474 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
475 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
476 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
477 let mut has_reconnected = false;
478
479 async move {
480 let mut handler = OKXWsFeedHandler::new(
481 account_id,
482 signal.clone(),
483 cmd_rx,
484 raw_rx,
485 msg_tx,
486 active_client_orders,
487 client_id_aliases,
488 auth_tracker.clone(),
489 subscriptions_state.clone(),
490 );
491
492 let resubscribe_all = || {
494 for entry in subscriptions_inst_id.iter() {
495 let (channel, inst_ids) = entry.pair();
496 for inst_id in inst_ids {
497 let arg = OKXSubscriptionArg {
498 channel: channel.clone(),
499 inst_type: None,
500 inst_family: None,
501 inst_id: Some(*inst_id),
502 };
503 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
504 log::error!("Failed to send resubscribe command: error={e}");
505 }
506 }
507 }
508
509 for entry in subscriptions_bare.iter() {
510 let channel = entry.key();
511 let arg = OKXSubscriptionArg {
512 channel: channel.clone(),
513 inst_type: None,
514 inst_family: None,
515 inst_id: None,
516 };
517 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
518 log::error!("Failed to send resubscribe command: error={e}");
519 }
520 }
521
522 for entry in subscriptions_inst_type.iter() {
523 let (channel, inst_types) = entry.pair();
524 for inst_type in inst_types {
525 let arg = OKXSubscriptionArg {
526 channel: channel.clone(),
527 inst_type: Some(*inst_type),
528 inst_family: None,
529 inst_id: None,
530 };
531 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
532 log::error!("Failed to send resubscribe command: error={e}");
533 }
534 }
535 }
536
537 for entry in subscriptions_inst_family.iter() {
538 let (channel, inst_families) = entry.pair();
539 for inst_family in inst_families {
540 let arg = OKXSubscriptionArg {
541 channel: channel.clone(),
542 inst_type: None,
543 inst_family: Some(*inst_family),
544 inst_id: None,
545 };
546 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
547 log::error!("Failed to send resubscribe command: error={e}");
548 }
549 }
550 }
551 };
552
553 loop {
555 match handler.next().await {
556 Some(NautilusWsMessage::Reconnected) => {
557 if signal.load(Ordering::Acquire) {
558 continue;
559 }
560
561 has_reconnected = true;
562
563 let confirmed_topics_vec: Vec<String> = {
565 let confirmed = subscriptions_state.confirmed();
566 let mut topics = Vec::new();
567 for entry in confirmed.iter() {
568 let channel = entry.key();
569 for symbol in entry.value() {
570 if symbol.as_str() == "#" {
571 topics.push(channel.to_string());
572 } else {
573 topics.push(format!("{channel}{OKX_WS_TOPIC_DELIMITER}{symbol}"));
574 }
575 }
576 }
577 topics
578 };
579
580 if !confirmed_topics_vec.is_empty() {
581 log::debug!("Marking confirmed subscriptions as pending for replay: count={}", confirmed_topics_vec.len());
582 for topic in confirmed_topics_vec {
583 subscriptions_state.mark_failure(&topic);
584 }
585 }
586
587 if let Some(cred) = &credential {
588 log::debug!("Re-authenticating after reconnection");
589 let timestamp = std::time::SystemTime::now()
590 .duration_since(std::time::SystemTime::UNIX_EPOCH)
591 .expect("System time should be after UNIX epoch")
592 .as_secs()
593 .to_string();
594 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
595
596 let auth_message = super::messages::OKXAuthentication {
597 op: "login",
598 args: vec![super::messages::OKXAuthenticationArg {
599 api_key: cred.api_key.to_string(),
600 passphrase: cred.api_passphrase.clone(),
601 timestamp,
602 sign: signature,
603 }],
604 };
605
606 if let Ok(payload) = serde_json::to_string(&auth_message) {
607 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Authenticate { payload }) {
608 log::error!("Failed to send reconnection auth command: error={e}");
609 }
610 } else {
611 log::error!("Failed to serialize reconnection auth message");
612 }
613 }
614
615 if credential.is_none() {
618 log::debug!("No authentication required, resubscribing immediately");
619 resubscribe_all();
620 }
621
622 continue;
627 }
628 Some(NautilusWsMessage::Authenticated) => {
629 if has_reconnected {
630 resubscribe_all();
631 }
632
633 continue;
638 }
639 Some(msg) => {
640 if handler.send(msg).is_err() {
641 log::error!(
642 "Failed to send message through channel: receiver dropped",
643 );
644 break;
645 }
646 }
647 None => {
648 if handler.is_stopped() {
649 log::debug!(
650 "Stop signal received, ending message processing",
651 );
652 break;
653 }
654 log::debug!("WebSocket stream closed");
655 break;
656 }
657 }
658 }
659
660 log::debug!("Handler task exiting");
661 }
662 });
663
664 self.task_handle = Some(Arc::new(stream_handle));
665
666 self.cmd_tx
667 .read()
668 .await
669 .send(HandlerCommand::SetClient(client))
670 .map_err(|e| {
671 OKXWsError::ClientError(format!("Failed to send WebSocket client to handler: {e}"))
672 })?;
673 log::debug!("Sent WebSocket client to handler");
674
675 if self.credential.is_some()
676 && let Err(e) = self.authenticate().await
677 {
678 anyhow::bail!("Authentication failed: {e}");
679 }
680
681 Ok(())
682 }
683
684 async fn authenticate(&self) -> Result<(), Error> {
686 let credential = self.credential.as_ref().ok_or_else(|| {
687 Error::Io(std::io::Error::other(
688 "API credentials not available to authenticate",
689 ))
690 })?;
691
692 let rx = self.auth_tracker.begin();
693
694 let timestamp = SystemTime::now()
695 .duration_since(SystemTime::UNIX_EPOCH)
696 .expect("System time should be after UNIX epoch")
697 .as_secs()
698 .to_string();
699 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
700
701 let auth_message = OKXAuthentication {
702 op: "login",
703 args: vec![OKXAuthenticationArg {
704 api_key: credential.api_key.to_string(),
705 passphrase: credential.api_passphrase.clone(),
706 timestamp,
707 sign: signature,
708 }],
709 };
710
711 let payload = serde_json::to_string(&auth_message).map_err(|e| {
712 Error::Io(std::io::Error::other(format!(
713 "Failed to serialize auth message: {e}"
714 )))
715 })?;
716
717 self.cmd_tx
718 .read()
719 .await
720 .send(HandlerCommand::Authenticate { payload })
721 .map_err(|e| {
722 Error::Io(std::io::Error::other(format!(
723 "Failed to send authenticate command: {e}"
724 )))
725 })?;
726
727 match self
728 .auth_tracker
729 .wait_for_result::<OKXWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
730 .await
731 {
732 Ok(()) => {
733 log::info!("WebSocket authenticated");
734 Ok(())
735 }
736 Err(e) => {
737 log::error!("WebSocket authentication failed: error={e}");
738 Err(Error::Io(std::io::Error::other(e.to_string())))
739 }
740 }
741 }
742
743 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
751 let rx = self
752 .out_rx
753 .take()
754 .expect("Data stream receiver already taken or not connected");
755 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
756 async_stream::stream! {
757 while let Some(data) = rx.recv().await {
758 yield data;
759 }
760 }
761 }
762
763 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
769 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
770
771 tokio::time::timeout(timeout, async {
772 while !self.is_active() {
773 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
774 }
775 })
776 .await
777 .map_err(|_| {
778 OKXWsError::ClientError(format!(
779 "WebSocket connection timeout after {timeout_secs} seconds"
780 ))
781 })?;
782
783 Ok(())
784 }
785
786 pub async fn close(&mut self) -> Result<(), Error> {
793 log::debug!("Starting close process");
794
795 self.signal.store(true, Ordering::Release);
796
797 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
798 log::warn!("Failed to send disconnect command to handler: {e}");
799 } else {
800 log::debug!("Sent disconnect command to handler");
801 }
802
803 {
805 if false {
806 log::debug!("No active connection to disconnect");
807 }
808 }
809
810 if let Some(stream_handle) = self.task_handle.take() {
812 match Arc::try_unwrap(stream_handle) {
813 Ok(handle) => {
814 log::debug!("Waiting for stream handle to complete");
815 match tokio::time::timeout(Duration::from_secs(2), handle).await {
816 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
817 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
818 Err(_) => {
819 log::warn!(
820 "Timeout waiting for stream handle, task may still be running"
821 );
822 }
824 }
825 }
826 Err(arc_handle) => {
827 log::debug!(
828 "Cannot take ownership of stream handle - other references exist, aborting task"
829 );
830 arc_handle.abort();
831 }
832 }
833 } else {
834 log::debug!("No stream handle to await");
835 }
836
837 log::debug!("Close process completed");
838
839 Ok(())
840 }
841
842 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
844 let symbol = instrument_id.symbol.inner();
845 let mut channels = Vec::new();
846
847 for entry in self.subscriptions_inst_id.iter() {
848 let (channel, instruments) = entry.pair();
849 if instruments.contains(&symbol) {
850 channels.push(channel.clone());
851 }
852 }
853
854 channels
855 }
856
857 fn generate_unique_request_id(&self) -> String {
858 self.request_id_counter
859 .fetch_add(1, Ordering::SeqCst)
860 .to_string()
861 }
862
863 #[allow(
864 clippy::result_large_err,
865 reason = "OKXWsError contains large tungstenite::Error variant"
866 )]
867 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
868 for arg in &args {
869 let topic = topic_from_subscription_arg(arg);
870 self.subscriptions_state.mark_subscribe(&topic);
871
872 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
874 self.subscriptions_bare.insert(arg.channel.clone(), true);
876 } else {
877 if let Some(inst_type) = &arg.inst_type {
879 self.subscriptions_inst_type
880 .entry(arg.channel.clone())
881 .or_default()
882 .insert(*inst_type);
883 }
884
885 if let Some(inst_family) = &arg.inst_family {
887 self.subscriptions_inst_family
888 .entry(arg.channel.clone())
889 .or_default()
890 .insert(*inst_family);
891 }
892
893 if let Some(inst_id) = &arg.inst_id {
895 self.subscriptions_inst_id
896 .entry(arg.channel.clone())
897 .or_default()
898 .insert(*inst_id);
899 }
900 }
901 }
902
903 self.cmd_tx
904 .read()
905 .await
906 .send(HandlerCommand::Subscribe { args })
907 .map_err(|e| OKXWsError::ClientError(format!("Failed to send subscribe command: {e}")))
908 }
909
910 #[allow(clippy::collapsible_if)]
911 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
912 for arg in &args {
913 let topic = topic_from_subscription_arg(arg);
914 self.subscriptions_state.mark_unsubscribe(&topic);
915
916 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
918 self.subscriptions_bare.remove(&arg.channel);
920 } else {
921 if let Some(inst_type) = &arg.inst_type {
923 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
924 entry.remove(inst_type);
925 if entry.is_empty() {
926 drop(entry);
927 self.subscriptions_inst_type.remove(&arg.channel);
928 }
929 }
930 }
931
932 if let Some(inst_family) = &arg.inst_family {
934 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
935 entry.remove(inst_family);
936 if entry.is_empty() {
937 drop(entry);
938 self.subscriptions_inst_family.remove(&arg.channel);
939 }
940 }
941 }
942
943 if let Some(inst_id) = &arg.inst_id {
945 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
946 entry.remove(inst_id);
947 if entry.is_empty() {
948 drop(entry);
949 self.subscriptions_inst_id.remove(&arg.channel);
950 }
951 }
952 }
953 }
954 }
955
956 self.cmd_tx
957 .read()
958 .await
959 .send(HandlerCommand::Unsubscribe { args })
960 .map_err(|e| {
961 OKXWsError::ClientError(format!("Failed to send unsubscribe command: {e}"))
962 })
963 }
964
965 pub async fn unsubscribe_all(&self) -> Result<(), OKXWsError> {
974 const BATCH_SIZE: usize = 256;
975
976 let mut all_args = Vec::new();
977
978 for entry in self.subscriptions_inst_type.iter() {
979 let (channel, inst_types) = entry.pair();
980 for inst_type in inst_types {
981 all_args.push(OKXSubscriptionArg {
982 channel: channel.clone(),
983 inst_type: Some(*inst_type),
984 inst_family: None,
985 inst_id: None,
986 });
987 }
988 }
989
990 for entry in self.subscriptions_inst_family.iter() {
991 let (channel, inst_families) = entry.pair();
992 for inst_family in inst_families {
993 all_args.push(OKXSubscriptionArg {
994 channel: channel.clone(),
995 inst_type: None,
996 inst_family: Some(*inst_family),
997 inst_id: None,
998 });
999 }
1000 }
1001
1002 for entry in self.subscriptions_inst_id.iter() {
1003 let (channel, inst_ids) = entry.pair();
1004 for inst_id in inst_ids {
1005 all_args.push(OKXSubscriptionArg {
1006 channel: channel.clone(),
1007 inst_type: None,
1008 inst_family: None,
1009 inst_id: Some(*inst_id),
1010 });
1011 }
1012 }
1013
1014 for entry in self.subscriptions_bare.iter() {
1015 let channel = entry.key();
1016 all_args.push(OKXSubscriptionArg {
1017 channel: channel.clone(),
1018 inst_type: None,
1019 inst_family: None,
1020 inst_id: None,
1021 });
1022 }
1023
1024 if all_args.is_empty() {
1025 log::debug!("No active subscriptions to unsubscribe from");
1026 return Ok(());
1027 }
1028
1029 log::debug!("Batched unsubscribe from {} channels", all_args.len());
1030
1031 for chunk in all_args.chunks(BATCH_SIZE) {
1032 self.unsubscribe(chunk.to_vec()).await?;
1033 }
1034
1035 Ok(())
1036 }
1037
1038 pub async fn subscribe_instruments(
1050 &self,
1051 instrument_type: OKXInstrumentType,
1052 ) -> Result<(), OKXWsError> {
1053 let arg = OKXSubscriptionArg {
1054 channel: OKXWsChannel::Instruments,
1055 inst_type: Some(instrument_type),
1056 inst_family: None,
1057 inst_id: None,
1058 };
1059 self.subscribe(vec![arg]).await
1060 }
1061
1062 pub async fn subscribe_instrument(
1075 &self,
1076 instrument_id: InstrumentId,
1077 ) -> Result<(), OKXWsError> {
1078 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1079
1080 let already_subscribed = self
1081 .subscriptions_inst_type
1082 .get(&OKXWsChannel::Instruments)
1083 .is_some_and(|types| types.contains(&inst_type));
1084
1085 if already_subscribed {
1086 log::debug!("Already subscribed to instrument type {inst_type:?} for {instrument_id}");
1087 return Ok(());
1088 }
1089
1090 log::debug!("Subscribing to instrument type {inst_type:?} for {instrument_id}");
1091 self.subscribe_instruments(inst_type).await
1092 }
1093
1094 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1103 self.subscribe_book_with_depth(instrument_id, 0).await
1104 }
1105
1106 pub(crate) async fn subscribe_books_channel(
1108 &self,
1109 instrument_id: InstrumentId,
1110 ) -> Result<(), OKXWsError> {
1111 let arg = OKXSubscriptionArg {
1112 channel: OKXWsChannel::Books,
1113 inst_type: None,
1114 inst_family: None,
1115 inst_id: Some(instrument_id.symbol.inner()),
1116 };
1117 self.subscribe(vec![arg]).await
1118 }
1119
1120 pub async fn subscribe_book_depth5(
1132 &self,
1133 instrument_id: InstrumentId,
1134 ) -> Result<(), OKXWsError> {
1135 let arg = OKXSubscriptionArg {
1136 channel: OKXWsChannel::Books5,
1137 inst_type: None,
1138 inst_family: None,
1139 inst_id: Some(instrument_id.symbol.inner()),
1140 };
1141 self.subscribe(vec![arg]).await
1142 }
1143
1144 pub async fn subscribe_book50_l2_tbt(
1156 &self,
1157 instrument_id: InstrumentId,
1158 ) -> Result<(), OKXWsError> {
1159 let arg = OKXSubscriptionArg {
1160 channel: OKXWsChannel::Books50Tbt,
1161 inst_type: None,
1162 inst_family: None,
1163 inst_id: Some(instrument_id.symbol.inner()),
1164 };
1165 self.subscribe(vec![arg]).await
1166 }
1167
1168 pub async fn subscribe_book_l2_tbt(
1180 &self,
1181 instrument_id: InstrumentId,
1182 ) -> Result<(), OKXWsError> {
1183 let arg = OKXSubscriptionArg {
1184 channel: OKXWsChannel::BooksTbt,
1185 inst_type: None,
1186 inst_family: None,
1187 inst_id: Some(instrument_id.symbol.inner()),
1188 };
1189 self.subscribe(vec![arg]).await
1190 }
1191
1192 pub async fn subscribe_book_with_depth(
1206 &self,
1207 instrument_id: InstrumentId,
1208 depth: u16,
1209 ) -> anyhow::Result<()> {
1210 let vip = self.vip_level();
1211
1212 match depth {
1213 50 => {
1214 if vip < OKXVipLevel::Vip4 {
1215 anyhow::bail!(
1216 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1217 );
1218 }
1219 self.subscribe_book50_l2_tbt(instrument_id)
1220 .await
1221 .map_err(|e| anyhow::anyhow!(e))
1222 }
1223 0 | 400 => {
1224 if vip >= OKXVipLevel::Vip5 {
1225 self.subscribe_book_l2_tbt(instrument_id)
1226 .await
1227 .map_err(|e| anyhow::anyhow!(e))
1228 } else {
1229 self.subscribe_books_channel(instrument_id)
1230 .await
1231 .map_err(|e| anyhow::anyhow!(e))
1232 }
1233 }
1234 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1235 }
1236 }
1237
1238 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1251 let arg = OKXSubscriptionArg {
1252 channel: OKXWsChannel::BboTbt,
1253 inst_type: None,
1254 inst_family: None,
1255 inst_id: Some(instrument_id.symbol.inner()),
1256 };
1257 self.subscribe(vec![arg]).await
1258 }
1259
1260 pub async fn subscribe_trades(
1274 &self,
1275 instrument_id: InstrumentId,
1276 aggregated: bool,
1277 ) -> Result<(), OKXWsError> {
1278 let channel = if aggregated {
1279 OKXWsChannel::TradesAll
1280 } else {
1281 OKXWsChannel::Trades
1282 };
1283
1284 let arg = OKXSubscriptionArg {
1285 channel,
1286 inst_type: None,
1287 inst_family: None,
1288 inst_id: Some(instrument_id.symbol.inner()),
1289 };
1290 self.subscribe(vec![arg]).await
1291 }
1292
1293 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1305 let arg = OKXSubscriptionArg {
1306 channel: OKXWsChannel::Tickers,
1307 inst_type: None,
1308 inst_family: None,
1309 inst_id: Some(instrument_id.symbol.inner()),
1310 };
1311 self.subscribe(vec![arg]).await
1312 }
1313
1314 pub async fn subscribe_mark_prices(
1326 &self,
1327 instrument_id: InstrumentId,
1328 ) -> Result<(), OKXWsError> {
1329 let arg = OKXSubscriptionArg {
1330 channel: OKXWsChannel::MarkPrice,
1331 inst_type: None,
1332 inst_family: None,
1333 inst_id: Some(instrument_id.symbol.inner()),
1334 };
1335 self.subscribe(vec![arg]).await
1336 }
1337
1338 pub async fn subscribe_index_prices(
1350 &self,
1351 instrument_id: InstrumentId,
1352 ) -> Result<(), OKXWsError> {
1353 let arg = OKXSubscriptionArg {
1354 channel: OKXWsChannel::IndexTickers,
1355 inst_type: None,
1356 inst_family: None,
1357 inst_id: Some(instrument_id.symbol.inner()),
1358 };
1359 self.subscribe(vec![arg]).await
1360 }
1361
1362 pub async fn subscribe_funding_rates(
1374 &self,
1375 instrument_id: InstrumentId,
1376 ) -> Result<(), OKXWsError> {
1377 let arg = OKXSubscriptionArg {
1378 channel: OKXWsChannel::FundingRate,
1379 inst_type: None,
1380 inst_family: None,
1381 inst_id: Some(instrument_id.symbol.inner()),
1382 };
1383 self.subscribe(vec![arg]).await
1384 }
1385
1386 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1398 let channel = bar_spec_as_okx_channel(bar_type.spec())
1400 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1401
1402 let arg = OKXSubscriptionArg {
1403 channel,
1404 inst_type: None,
1405 inst_family: None,
1406 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1407 };
1408 self.subscribe(vec![arg]).await
1409 }
1410
1411 pub async fn unsubscribe_instruments(
1417 &self,
1418 instrument_type: OKXInstrumentType,
1419 ) -> Result<(), OKXWsError> {
1420 let arg = OKXSubscriptionArg {
1421 channel: OKXWsChannel::Instruments,
1422 inst_type: Some(instrument_type),
1423 inst_family: None,
1424 inst_id: None,
1425 };
1426 self.unsubscribe(vec![arg]).await
1427 }
1428
1429 pub async fn unsubscribe_instrument(
1435 &self,
1436 instrument_id: InstrumentId,
1437 ) -> Result<(), OKXWsError> {
1438 let arg = OKXSubscriptionArg {
1439 channel: OKXWsChannel::Instruments,
1440 inst_type: None,
1441 inst_family: None,
1442 inst_id: Some(instrument_id.symbol.inner()),
1443 };
1444 self.unsubscribe(vec![arg]).await
1445 }
1446
1447 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1453 let arg = OKXSubscriptionArg {
1454 channel: OKXWsChannel::Books,
1455 inst_type: None,
1456 inst_family: None,
1457 inst_id: Some(instrument_id.symbol.inner()),
1458 };
1459 self.unsubscribe(vec![arg]).await
1460 }
1461
1462 pub async fn unsubscribe_book_depth5(
1468 &self,
1469 instrument_id: InstrumentId,
1470 ) -> Result<(), OKXWsError> {
1471 let arg = OKXSubscriptionArg {
1472 channel: OKXWsChannel::Books5,
1473 inst_type: None,
1474 inst_family: None,
1475 inst_id: Some(instrument_id.symbol.inner()),
1476 };
1477 self.unsubscribe(vec![arg]).await
1478 }
1479
1480 pub async fn unsubscribe_book50_l2_tbt(
1486 &self,
1487 instrument_id: InstrumentId,
1488 ) -> Result<(), OKXWsError> {
1489 let arg = OKXSubscriptionArg {
1490 channel: OKXWsChannel::Books50Tbt,
1491 inst_type: None,
1492 inst_family: None,
1493 inst_id: Some(instrument_id.symbol.inner()),
1494 };
1495 self.unsubscribe(vec![arg]).await
1496 }
1497
1498 pub async fn unsubscribe_book_l2_tbt(
1504 &self,
1505 instrument_id: InstrumentId,
1506 ) -> Result<(), OKXWsError> {
1507 let arg = OKXSubscriptionArg {
1508 channel: OKXWsChannel::BooksTbt,
1509 inst_type: None,
1510 inst_family: None,
1511 inst_id: Some(instrument_id.symbol.inner()),
1512 };
1513 self.unsubscribe(vec![arg]).await
1514 }
1515
1516 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1522 let arg = OKXSubscriptionArg {
1523 channel: OKXWsChannel::BboTbt,
1524 inst_type: None,
1525 inst_family: None,
1526 inst_id: Some(instrument_id.symbol.inner()),
1527 };
1528 self.unsubscribe(vec![arg]).await
1529 }
1530
1531 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1537 let arg = OKXSubscriptionArg {
1538 channel: OKXWsChannel::Tickers,
1539 inst_type: None,
1540 inst_family: None,
1541 inst_id: Some(instrument_id.symbol.inner()),
1542 };
1543 self.unsubscribe(vec![arg]).await
1544 }
1545
1546 pub async fn unsubscribe_mark_prices(
1552 &self,
1553 instrument_id: InstrumentId,
1554 ) -> Result<(), OKXWsError> {
1555 let arg = OKXSubscriptionArg {
1556 channel: OKXWsChannel::MarkPrice,
1557 inst_type: None,
1558 inst_family: None,
1559 inst_id: Some(instrument_id.symbol.inner()),
1560 };
1561 self.unsubscribe(vec![arg]).await
1562 }
1563
1564 pub async fn unsubscribe_index_prices(
1570 &self,
1571 instrument_id: InstrumentId,
1572 ) -> Result<(), OKXWsError> {
1573 let arg = OKXSubscriptionArg {
1574 channel: OKXWsChannel::IndexTickers,
1575 inst_type: None,
1576 inst_family: None,
1577 inst_id: Some(instrument_id.symbol.inner()),
1578 };
1579 self.unsubscribe(vec![arg]).await
1580 }
1581
1582 pub async fn unsubscribe_funding_rates(
1588 &self,
1589 instrument_id: InstrumentId,
1590 ) -> Result<(), OKXWsError> {
1591 let arg = OKXSubscriptionArg {
1592 channel: OKXWsChannel::FundingRate,
1593 inst_type: None,
1594 inst_family: None,
1595 inst_id: Some(instrument_id.symbol.inner()),
1596 };
1597 self.unsubscribe(vec![arg]).await
1598 }
1599
1600 pub async fn unsubscribe_trades(
1606 &self,
1607 instrument_id: InstrumentId,
1608 aggregated: bool,
1609 ) -> Result<(), OKXWsError> {
1610 let channel = if aggregated {
1611 OKXWsChannel::TradesAll
1612 } else {
1613 OKXWsChannel::Trades
1614 };
1615
1616 let arg = OKXSubscriptionArg {
1617 channel,
1618 inst_type: None,
1619 inst_family: None,
1620 inst_id: Some(instrument_id.symbol.inner()),
1621 };
1622 self.unsubscribe(vec![arg]).await
1623 }
1624
1625 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1631 let channel = bar_spec_as_okx_channel(bar_type.spec())
1633 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1634
1635 let arg = OKXSubscriptionArg {
1636 channel,
1637 inst_type: None,
1638 inst_family: None,
1639 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1640 };
1641 self.unsubscribe(vec![arg]).await
1642 }
1643
1644 pub async fn subscribe_orders(
1650 &self,
1651 instrument_type: OKXInstrumentType,
1652 ) -> Result<(), OKXWsError> {
1653 let arg = OKXSubscriptionArg {
1654 channel: OKXWsChannel::Orders,
1655 inst_type: Some(instrument_type),
1656 inst_family: None,
1657 inst_id: None,
1658 };
1659 self.subscribe(vec![arg]).await
1660 }
1661
1662 pub async fn unsubscribe_orders(
1668 &self,
1669 instrument_type: OKXInstrumentType,
1670 ) -> Result<(), OKXWsError> {
1671 let arg = OKXSubscriptionArg {
1672 channel: OKXWsChannel::Orders,
1673 inst_type: Some(instrument_type),
1674 inst_family: None,
1675 inst_id: None,
1676 };
1677 self.unsubscribe(vec![arg]).await
1678 }
1679
1680 pub async fn subscribe_orders_algo(
1686 &self,
1687 instrument_type: OKXInstrumentType,
1688 ) -> Result<(), OKXWsError> {
1689 let arg = OKXSubscriptionArg {
1690 channel: OKXWsChannel::OrdersAlgo,
1691 inst_type: Some(instrument_type),
1692 inst_family: None,
1693 inst_id: None,
1694 };
1695 self.subscribe(vec![arg]).await
1696 }
1697
1698 pub async fn unsubscribe_orders_algo(
1704 &self,
1705 instrument_type: OKXInstrumentType,
1706 ) -> Result<(), OKXWsError> {
1707 let arg = OKXSubscriptionArg {
1708 channel: OKXWsChannel::OrdersAlgo,
1709 inst_type: Some(instrument_type),
1710 inst_family: None,
1711 inst_id: None,
1712 };
1713 self.unsubscribe(vec![arg]).await
1714 }
1715
1716 pub async fn subscribe_fills(
1722 &self,
1723 instrument_type: OKXInstrumentType,
1724 ) -> Result<(), OKXWsError> {
1725 let arg = OKXSubscriptionArg {
1726 channel: OKXWsChannel::Fills,
1727 inst_type: Some(instrument_type),
1728 inst_family: None,
1729 inst_id: None,
1730 };
1731 self.subscribe(vec![arg]).await
1732 }
1733
1734 pub async fn unsubscribe_fills(
1740 &self,
1741 instrument_type: OKXInstrumentType,
1742 ) -> Result<(), OKXWsError> {
1743 let arg = OKXSubscriptionArg {
1744 channel: OKXWsChannel::Fills,
1745 inst_type: Some(instrument_type),
1746 inst_family: None,
1747 inst_id: None,
1748 };
1749 self.unsubscribe(vec![arg]).await
1750 }
1751
1752 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1758 let arg = OKXSubscriptionArg {
1759 channel: OKXWsChannel::Account,
1760 inst_type: None,
1761 inst_family: None,
1762 inst_id: None,
1763 };
1764 self.subscribe(vec![arg]).await
1765 }
1766
1767 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1773 let arg = OKXSubscriptionArg {
1774 channel: OKXWsChannel::Account,
1775 inst_type: None,
1776 inst_family: None,
1777 inst_id: None,
1778 };
1779 self.unsubscribe(vec![arg]).await
1780 }
1781
1782 pub async fn subscribe_positions(
1792 &self,
1793 inst_type: OKXInstrumentType,
1794 ) -> Result<(), OKXWsError> {
1795 let arg = OKXSubscriptionArg {
1796 channel: OKXWsChannel::Positions,
1797 inst_type: Some(inst_type),
1798 inst_family: None,
1799 inst_id: None,
1800 };
1801 self.subscribe(vec![arg]).await
1802 }
1803
1804 pub async fn unsubscribe_positions(
1810 &self,
1811 inst_type: OKXInstrumentType,
1812 ) -> Result<(), OKXWsError> {
1813 let arg = OKXSubscriptionArg {
1814 channel: OKXWsChannel::Positions,
1815 inst_type: Some(inst_type),
1816 inst_family: None,
1817 inst_id: None,
1818 };
1819 self.unsubscribe(vec![arg]).await
1820 }
1821
1822 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1828 let request_id = self.generate_unique_request_id();
1829 let cmd = HandlerCommand::BatchPlaceOrders { args, request_id };
1830
1831 self.send_cmd(cmd).await
1832 }
1833
1834 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1840 let request_id = self.generate_unique_request_id();
1841 let cmd = HandlerCommand::BatchCancelOrders { args, request_id };
1842
1843 self.send_cmd(cmd).await
1844 }
1845
1846 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1852 let request_id = self.generate_unique_request_id();
1853 let cmd = HandlerCommand::BatchAmendOrders { args, request_id };
1854
1855 self.send_cmd(cmd).await
1856 }
1857
1858 #[allow(clippy::too_many_arguments)]
1870 pub async fn submit_order(
1871 &self,
1872 trader_id: TraderId,
1873 strategy_id: StrategyId,
1874 instrument_id: InstrumentId,
1875 td_mode: OKXTradeMode,
1876 client_order_id: ClientOrderId,
1877 order_side: OrderSide,
1878 order_type: OrderType,
1879 quantity: Quantity,
1880 time_in_force: Option<TimeInForce>,
1881 price: Option<Price>,
1882 trigger_price: Option<Price>,
1883 post_only: Option<bool>,
1884 reduce_only: Option<bool>,
1885 quote_quantity: Option<bool>,
1886 position_side: Option<PositionSide>,
1887 ) -> Result<(), OKXWsError> {
1888 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1889 return Err(OKXWsError::ClientError(format!(
1890 "Unsupported order type: {order_type:?}",
1891 )));
1892 }
1893
1894 if let Some(tif) = time_in_force
1895 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1896 {
1897 return Err(OKXWsError::ClientError(format!(
1898 "Unsupported time in force: {tif:?}",
1899 )));
1900 }
1901
1902 let mut builder = WsPostOrderParamsBuilder::default();
1903
1904 builder.inst_id(instrument_id.symbol.as_str());
1905 builder.td_mode(td_mode);
1906 builder.cl_ord_id(client_order_id.as_str());
1907
1908 let instrument = self
1909 .instruments_cache
1910 .get(&instrument_id.symbol.inner())
1911 .ok_or_else(|| {
1912 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1913 })?;
1914
1915 let instrument_type =
1916 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1917 let quote_currency = instrument.quote_currency();
1918
1919 match instrument_type {
1920 OKXInstrumentType::Spot => {
1921 builder.ccy(quote_currency.to_string());
1923 }
1924 OKXInstrumentType::Margin => {
1925 builder.ccy(quote_currency.to_string());
1926
1927 if let Some(ro) = reduce_only
1928 && ro
1929 {
1930 builder.reduce_only(ro);
1931 }
1932 }
1933 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1934 builder.ccy(quote_currency.to_string());
1936
1937 if position_side.is_none() {
1940 builder.pos_side(OKXPositionSide::Net);
1941 }
1942 }
1943 _ => {
1944 builder.ccy(quote_currency.to_string());
1945
1946 if position_side.is_none() {
1948 builder.pos_side(OKXPositionSide::Net);
1949 }
1950
1951 if let Some(ro) = reduce_only
1952 && ro
1953 {
1954 builder.reduce_only(ro);
1955 }
1956 }
1957 };
1958
1959 if instrument_type == OKXInstrumentType::Spot
1966 && order_type == OrderType::Market
1967 && td_mode == OKXTradeMode::Cash
1968 {
1969 match quote_quantity {
1970 Some(true) => {
1971 builder.tgt_ccy(OKXTargetCurrency::QuoteCcy);
1973 }
1974 Some(false) => {
1975 if order_side == OrderSide::Buy {
1976 builder.tgt_ccy(OKXTargetCurrency::BaseCcy);
1978 }
1979 }
1981 None => {
1982 }
1984 }
1985 }
1986
1987 builder.side(order_side);
1988
1989 if let Some(pos_side) = position_side {
1990 builder.pos_side(pos_side);
1991 };
1992
1993 let (okx_ord_type, price) = if post_only.unwrap_or(false) {
1996 (OKXOrderType::PostOnly, price)
1997 } else if let Some(tif) = time_in_force {
1998 match (order_type, tif) {
1999 (OrderType::Market, TimeInForce::Fok) => {
2000 return Err(OKXWsError::ClientError(
2001 "Market orders with FOK time-in-force are not supported by OKX. Use Limit order with FOK instead.".to_string()
2002 ));
2003 }
2004 (OrderType::Market, TimeInForce::Ioc) => (OKXOrderType::OptimalLimitIoc, price),
2005 (OrderType::Limit, TimeInForce::Fok) => (OKXOrderType::Fok, price),
2006 (OrderType::Limit, TimeInForce::Ioc) => (OKXOrderType::Ioc, price),
2007 _ => (OKXOrderType::from(order_type), price),
2008 }
2009 } else {
2010 (OKXOrderType::from(order_type), price)
2011 };
2012
2013 log::debug!(
2014 "Order type mapping: order_type={order_type:?}, time_in_force={time_in_force:?}, post_only={post_only:?} -> okx_ord_type={okx_ord_type:?}"
2015 );
2016
2017 builder.ord_type(okx_ord_type);
2018 builder.sz(quantity.to_string());
2019
2020 if let Some(tp) = trigger_price {
2021 builder.px(tp.to_string());
2022 } else if let Some(p) = price {
2023 builder.px(p.to_string());
2024 }
2025
2026 builder.tag(OKX_NAUTILUS_BROKER_ID);
2027
2028 let params = builder
2029 .build()
2030 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2031
2032 self.active_client_orders
2033 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2034
2035 let cmd = HandlerCommand::PlaceOrder {
2036 params,
2037 client_order_id,
2038 trader_id,
2039 strategy_id,
2040 instrument_id,
2041 };
2042
2043 self.send_cmd(cmd).await
2044 }
2045
2046 #[allow(clippy::too_many_arguments)]
2062 pub async fn modify_order(
2063 &self,
2064 trader_id: TraderId,
2065 strategy_id: StrategyId,
2066 instrument_id: InstrumentId,
2067 client_order_id: Option<ClientOrderId>,
2068 price: Option<Price>,
2069 quantity: Option<Quantity>,
2070 venue_order_id: Option<VenueOrderId>,
2071 ) -> Result<(), OKXWsError> {
2072 let mut builder = WsAmendOrderParamsBuilder::default();
2073
2074 builder.inst_id(instrument_id.symbol.as_str());
2075
2076 if let Some(venue_order_id) = venue_order_id {
2077 builder.ord_id(venue_order_id.as_str());
2078 }
2079
2080 if let Some(client_order_id) = client_order_id {
2081 builder.cl_ord_id(client_order_id.as_str());
2082 }
2083
2084 if let Some(price) = price {
2085 builder.new_px(price.to_string());
2086 }
2087
2088 if let Some(quantity) = quantity {
2089 builder.new_sz(quantity.to_string());
2090 }
2091
2092 let params = builder
2093 .build()
2094 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2095
2096 if let Some(client_order_id) = client_order_id {
2099 let cmd = HandlerCommand::AmendOrder {
2100 params,
2101 client_order_id,
2102 trader_id,
2103 strategy_id,
2104 instrument_id,
2105 venue_order_id,
2106 };
2107
2108 self.send_cmd(cmd).await
2109 } else {
2110 Err(OKXWsError::ClientError(
2112 "Cannot amend order without client_order_id".to_string(),
2113 ))
2114 }
2115 }
2116
2117 #[allow(clippy::too_many_arguments)]
2128 pub async fn cancel_order(
2129 &self,
2130 trader_id: TraderId,
2131 strategy_id: StrategyId,
2132 instrument_id: InstrumentId,
2133 client_order_id: Option<ClientOrderId>,
2134 venue_order_id: Option<VenueOrderId>,
2135 ) -> Result<(), OKXWsError> {
2136 let cmd = HandlerCommand::CancelOrder {
2137 client_order_id,
2138 venue_order_id,
2139 instrument_id,
2140 trader_id,
2141 strategy_id,
2142 };
2143
2144 self.send_cmd(cmd).await
2145 }
2146
2147 pub async fn mass_cancel_orders(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
2157 let cmd = HandlerCommand::MassCancel { instrument_id };
2158
2159 self.send_cmd(cmd).await
2160 }
2161
2162 #[allow(clippy::type_complexity)]
2169 #[allow(clippy::too_many_arguments)]
2170 pub async fn batch_submit_orders(
2171 &self,
2172 orders: Vec<(
2173 OKXInstrumentType,
2174 InstrumentId,
2175 OKXTradeMode,
2176 ClientOrderId,
2177 OrderSide,
2178 Option<PositionSide>,
2179 OrderType,
2180 Quantity,
2181 Option<Price>,
2182 Option<Price>,
2183 Option<bool>,
2184 Option<bool>,
2185 )>,
2186 ) -> Result<(), OKXWsError> {
2187 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2188 for (
2189 inst_type,
2190 inst_id,
2191 td_mode,
2192 cl_ord_id,
2193 ord_side,
2194 pos_side,
2195 ord_type,
2196 qty,
2197 pr,
2198 tp,
2199 post_only,
2200 reduce_only,
2201 ) in orders
2202 {
2203 let mut builder = WsPostOrderParamsBuilder::default();
2204 builder.inst_type(inst_type);
2205 builder.inst_id(inst_id.symbol.inner());
2206 builder.td_mode(td_mode);
2207 builder.cl_ord_id(cl_ord_id.as_str());
2208 builder.side(ord_side);
2209
2210 if let Some(ps) = pos_side {
2211 builder.pos_side(OKXPositionSide::from(ps));
2212 }
2213
2214 let okx_ord_type = if post_only.unwrap_or(false) {
2215 OKXOrderType::PostOnly
2216 } else {
2217 OKXOrderType::from(ord_type)
2218 };
2219
2220 builder.ord_type(okx_ord_type);
2221 builder.sz(qty.to_string());
2222
2223 if let Some(p) = pr {
2224 builder.px(p.to_string());
2225 } else if let Some(p) = tp {
2226 builder.px(p.to_string());
2227 }
2228
2229 if let Some(ro) = reduce_only {
2230 builder.reduce_only(ro);
2231 }
2232
2233 builder.tag(OKX_NAUTILUS_BROKER_ID);
2234
2235 let params = builder
2236 .build()
2237 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2238 let val =
2239 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2240 args.push(val);
2241 }
2242
2243 self.ws_batch_place_orders(args).await
2244 }
2245
2246 #[allow(clippy::type_complexity)]
2253 #[allow(clippy::too_many_arguments)]
2254 pub async fn batch_modify_orders(
2255 &self,
2256 orders: Vec<(
2257 OKXInstrumentType,
2258 InstrumentId,
2259 ClientOrderId,
2260 ClientOrderId,
2261 Option<Price>,
2262 Option<Quantity>,
2263 )>,
2264 ) -> Result<(), OKXWsError> {
2265 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2266 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2267 let mut builder = WsAmendOrderParamsBuilder::default();
2268 builder.inst_id(inst_id.symbol.inner());
2270 builder.cl_ord_id(cl_ord_id.as_str());
2271 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2272
2273 if let Some(p) = pr {
2274 builder.new_px(p.to_string());
2275 }
2276
2277 if let Some(q) = sz {
2278 builder.new_sz(q.to_string());
2279 }
2280
2281 let params = builder.build().map_err(|e| {
2282 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2283 })?;
2284 let val =
2285 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2286 args.push(val);
2287 }
2288
2289 self.ws_batch_amend_orders(args).await
2290 }
2291
2292 #[allow(clippy::type_complexity)]
2305 pub async fn batch_cancel_orders(
2306 &self,
2307 orders: Vec<(InstrumentId, Option<ClientOrderId>, Option<VenueOrderId>)>,
2308 ) -> Result<(), OKXWsError> {
2309 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2310 for (inst_id, cl_ord_id, ord_id) in orders {
2311 let mut builder = WsCancelOrderParamsBuilder::default();
2312 builder.inst_id(inst_id.symbol.inner());
2314
2315 if let Some(c) = cl_ord_id {
2316 builder.cl_ord_id(c.as_str());
2317 }
2318
2319 if let Some(o) = ord_id {
2320 builder.ord_id(o.as_str());
2321 }
2322
2323 let params = builder.build().map_err(|e| {
2324 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2325 })?;
2326 let val =
2327 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2328 args.push(val);
2329 }
2330
2331 self.ws_batch_cancel_orders(args).await
2332 }
2333
2334 #[allow(clippy::too_many_arguments)]
2345 pub async fn submit_algo_order(
2346 &self,
2347 trader_id: TraderId,
2348 strategy_id: StrategyId,
2349 instrument_id: InstrumentId,
2350 td_mode: OKXTradeMode,
2351 client_order_id: ClientOrderId,
2352 order_side: OrderSide,
2353 order_type: OrderType,
2354 quantity: Quantity,
2355 trigger_price: Price,
2356 trigger_type: Option<TriggerType>,
2357 limit_price: Option<Price>,
2358 reduce_only: Option<bool>,
2359 ) -> Result<(), OKXWsError> {
2360 if !is_conditional_order(order_type) {
2361 return Err(OKXWsError::ClientError(format!(
2362 "Order type {order_type:?} is not a conditional order"
2363 )));
2364 }
2365
2366 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2367 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2368 return Err(OKXWsError::ClientError(
2369 "Invalid order side for OKX".to_string(),
2370 ));
2371 }
2372
2373 builder.inst_id(instrument_id.symbol.inner());
2374 builder.td_mode(td_mode);
2375 builder.cl_ord_id(client_order_id.as_str());
2376 builder.side(order_side);
2377 builder.ord_type(
2378 conditional_order_to_algo_type(order_type)
2379 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2380 );
2381 builder.sz(quantity.to_string());
2382 builder.trigger_px(trigger_price.to_string());
2383
2384 let okx_trigger_type = trigger_type.map_or(OKXTriggerType::Last, Into::into);
2386 builder.trigger_px_type(okx_trigger_type);
2387
2388 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2390 && let Some(price) = limit_price
2391 {
2392 builder.order_px(price.to_string());
2393 }
2394
2395 if let Some(reduce) = reduce_only {
2396 builder.reduce_only(reduce);
2397 }
2398
2399 builder.tag(OKX_NAUTILUS_BROKER_ID);
2400
2401 let params = builder
2402 .build()
2403 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2404
2405 self.active_client_orders
2406 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2407
2408 let cmd = HandlerCommand::PlaceAlgoOrder {
2409 params,
2410 client_order_id,
2411 trader_id,
2412 strategy_id,
2413 instrument_id,
2414 };
2415
2416 self.send_cmd(cmd).await
2417 }
2418
2419 pub async fn cancel_algo_order(
2430 &self,
2431 trader_id: TraderId,
2432 strategy_id: StrategyId,
2433 instrument_id: InstrumentId,
2434 client_order_id: Option<ClientOrderId>,
2435 algo_order_id: Option<String>,
2436 ) -> Result<(), OKXWsError> {
2437 let cmd = HandlerCommand::CancelAlgoOrder {
2438 client_order_id,
2439 algo_order_id: algo_order_id.map(|id| VenueOrderId::from(id.as_str())),
2440 instrument_id,
2441 trader_id,
2442 strategy_id,
2443 };
2444
2445 self.send_cmd(cmd).await
2446 }
2447
2448 async fn send_cmd(&self, cmd: HandlerCommand) -> Result<(), OKXWsError> {
2450 self.cmd_tx
2451 .read()
2452 .await
2453 .send(cmd)
2454 .map_err(|e| OKXWsError::ClientError(format!("Handler not available: {e}")))
2455 }
2456}
2457
2458#[cfg(test)]
2459mod tests {
2460 use nautilus_core::time::get_atomic_clock_realtime;
2461 use nautilus_network::RECONNECTED;
2462 use rstest::rstest;
2463 use tokio_tungstenite::tungstenite::Message;
2464
2465 use super::*;
2466 use crate::{
2467 common::{
2468 consts::OKX_POST_ONLY_CANCEL_SOURCE,
2469 enums::{OKXExecType, OKXOrderCategory, OKXOrderStatus, OKXSide},
2470 },
2471 websocket::{
2472 handler::OKXWsFeedHandler,
2473 messages::{OKXOrderMsg, OKXWebSocketError, OKXWsMessage},
2474 },
2475 };
2476
2477 #[rstest]
2478 fn test_timestamp_format_for_websocket_auth() {
2479 let timestamp = SystemTime::now()
2480 .duration_since(SystemTime::UNIX_EPOCH)
2481 .expect("System time should be after UNIX epoch")
2482 .as_secs()
2483 .to_string();
2484
2485 assert!(timestamp.parse::<u64>().is_ok());
2486 assert_eq!(timestamp.len(), 10);
2487 assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
2488 }
2489
2490 #[rstest]
2491 fn test_new_without_credentials() {
2492 let client = OKXWebSocketClient::default();
2493 assert!(client.credential.is_none());
2494 assert_eq!(client.api_key(), None);
2495 }
2496
2497 #[rstest]
2498 fn test_new_with_credentials() {
2499 let client = OKXWebSocketClient::new(
2500 None,
2501 Some("test_key".to_string()),
2502 Some("test_secret".to_string()),
2503 Some("test_passphrase".to_string()),
2504 None,
2505 None,
2506 )
2507 .unwrap();
2508 assert!(client.credential.is_some());
2509 assert_eq!(client.api_key(), Some("test_key"));
2510 }
2511
2512 #[rstest]
2513 fn test_new_partial_credentials_fails() {
2514 let result = OKXWebSocketClient::new(
2515 None,
2516 Some("test_key".to_string()),
2517 None,
2518 Some("test_passphrase".to_string()),
2519 None,
2520 None,
2521 );
2522 assert!(result.is_err());
2523 }
2524
2525 #[rstest]
2526 fn test_request_id_generation() {
2527 let client = OKXWebSocketClient::default();
2528
2529 let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
2530
2531 let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2532 let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2533
2534 assert_eq!(id1, initial_counter);
2535 assert_eq!(id2, initial_counter + 1);
2536 assert_eq!(
2537 client.request_id_counter.load(Ordering::SeqCst),
2538 initial_counter + 2
2539 );
2540 }
2541
2542 #[rstest]
2543 fn test_client_state_management() {
2544 let client = OKXWebSocketClient::default();
2545
2546 assert!(client.is_closed());
2547 assert!(!client.is_active());
2548
2549 let client_with_heartbeat =
2550 OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();
2551
2552 assert!(client_with_heartbeat.heartbeat.is_some());
2553 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2554 }
2555
2556 #[rstest]
2561 fn test_websocket_error_handling() {
2562 let clock = get_atomic_clock_realtime();
2563 let ts = clock.get_time_ns().as_u64();
2564
2565 let error = OKXWebSocketError {
2566 code: "60012".to_string(),
2567 message: "Invalid request".to_string(),
2568 conn_id: None,
2569 timestamp: ts,
2570 };
2571
2572 assert_eq!(error.code, "60012");
2573 assert_eq!(error.message, "Invalid request");
2574 assert_eq!(error.timestamp, ts);
2575
2576 let nautilus_msg = NautilusWsMessage::Error(error);
2577 match nautilus_msg {
2578 NautilusWsMessage::Error(e) => {
2579 assert_eq!(e.code, "60012");
2580 assert_eq!(e.message, "Invalid request");
2581 }
2582 _ => panic!("Expected Error variant"),
2583 }
2584 }
2585
2586 #[rstest]
2587 fn test_request_id_generation_sequence() {
2588 let client = OKXWebSocketClient::default();
2589
2590 let initial_counter = client
2591 .request_id_counter
2592 .load(std::sync::atomic::Ordering::SeqCst);
2593 let mut ids = Vec::new();
2594 for _ in 0..10 {
2595 let id = client
2596 .request_id_counter
2597 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2598 ids.push(id);
2599 }
2600
2601 for (i, &id) in ids.iter().enumerate() {
2602 assert_eq!(id, initial_counter + i as u64);
2603 }
2604
2605 assert_eq!(
2606 client
2607 .request_id_counter
2608 .load(std::sync::atomic::Ordering::SeqCst),
2609 initial_counter + 10
2610 );
2611 }
2612
2613 #[rstest]
2614 fn test_client_state_transitions() {
2615 let client = OKXWebSocketClient::default();
2616
2617 assert!(client.is_closed());
2618 assert!(!client.is_active());
2619
2620 let client_with_heartbeat = OKXWebSocketClient::new(
2621 None,
2622 None,
2623 None,
2624 None,
2625 None,
2626 Some(30), )
2628 .unwrap();
2629
2630 assert!(client_with_heartbeat.heartbeat.is_some());
2631 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2632
2633 let account_id = AccountId::from("test-account-123");
2634 let client_with_account =
2635 OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();
2636
2637 assert_eq!(client_with_account.account_id, account_id);
2638 }
2639
2640 #[rstest]
2641 fn test_websocket_error_scenarios() {
2642 let clock = get_atomic_clock_realtime();
2643 let ts = clock.get_time_ns().as_u64();
2644
2645 let error_scenarios = vec![
2646 ("60012", "Invalid request", None),
2647 ("60009", "Invalid API key", Some("conn-123".to_string())),
2648 ("60014", "Too many requests", None),
2649 ("50001", "Order not found", None),
2650 ];
2651
2652 for (code, message, conn_id) in error_scenarios {
2653 let error = OKXWebSocketError {
2654 code: code.to_string(),
2655 message: message.to_string(),
2656 conn_id: conn_id.clone(),
2657 timestamp: ts,
2658 };
2659
2660 assert_eq!(error.code, code);
2661 assert_eq!(error.message, message);
2662 assert_eq!(error.conn_id, conn_id);
2663 assert_eq!(error.timestamp, ts);
2664
2665 let nautilus_msg = NautilusWsMessage::Error(error);
2666 match nautilus_msg {
2667 NautilusWsMessage::Error(e) => {
2668 assert_eq!(e.code, code);
2669 assert_eq!(e.message, message);
2670 assert_eq!(e.conn_id, conn_id);
2671 }
2672 _ => panic!("Expected Error variant"),
2673 }
2674 }
2675 }
2676
2677 #[rstest]
2678 fn test_feed_handler_reconnection_detection() {
2679 let msg = Message::Text(RECONNECTED.to_string().into());
2680 let result = OKXWsFeedHandler::parse_raw_message(msg);
2681 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2682 }
2683
2684 #[rstest]
2685 fn test_feed_handler_normal_message_processing() {
2686 let ping_msg = Message::Text(TEXT_PING.to_string().into());
2688 let result = OKXWsFeedHandler::parse_raw_message(ping_msg);
2689 assert!(matches!(result, Some(OKXWsMessage::Ping)));
2690
2691 let sub_msg = r#"{
2693 "event": "subscribe",
2694 "arg": {
2695 "channel": "tickers",
2696 "instType": "SPOT"
2697 },
2698 "connId": "a4d3ae55"
2699 }"#;
2700
2701 let sub_result =
2702 OKXWsFeedHandler::parse_raw_message(Message::Text(sub_msg.to_string().into()));
2703 assert!(matches!(
2704 sub_result,
2705 Some(OKXWsMessage::Subscription { .. })
2706 ));
2707 }
2708
2709 #[rstest]
2710 fn test_feed_handler_close_message() {
2711 let result = OKXWsFeedHandler::parse_raw_message(Message::Close(None));
2713 assert!(result.is_none());
2714 }
2715
2716 #[rstest]
2717 fn test_reconnection_message_constant() {
2718 assert_eq!(RECONNECTED, "__RECONNECTED__");
2719 }
2720
2721 #[rstest]
2722 fn test_multiple_reconnection_signals() {
2723 for _ in 0..3 {
2725 let msg = Message::Text(RECONNECTED.to_string().into());
2726 let result = OKXWsFeedHandler::parse_raw_message(msg);
2727 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2728 }
2729 }
2730
2731 #[tokio::test]
2732 async fn test_wait_until_active_timeout() {
2733 let client = OKXWebSocketClient::new(
2734 None,
2735 Some("test_key".to_string()),
2736 Some("test_secret".to_string()),
2737 Some("test_passphrase".to_string()),
2738 Some(AccountId::from("test-account")),
2739 None,
2740 )
2741 .unwrap();
2742
2743 let result = client.wait_until_active(0.1).await;
2745
2746 assert!(result.is_err());
2747 assert!(!client.is_active());
2748 }
2749
2750 fn sample_canceled_order_msg() -> OKXOrderMsg {
2751 OKXOrderMsg {
2752 acc_fill_sz: Some("0".to_string()),
2753 avg_px: "0".to_string(),
2754 c_time: 0,
2755 cancel_source: None,
2756 cancel_source_reason: None,
2757 category: OKXOrderCategory::Normal,
2758 ccy: Ustr::from("USDT"),
2759 cl_ord_id: "order-1".to_string(),
2760 algo_cl_ord_id: None,
2761 fee: None,
2762 fee_ccy: Ustr::from("USDT"),
2763 fill_px: "0".to_string(),
2764 fill_sz: "0".to_string(),
2765 fill_time: 0,
2766 inst_id: Ustr::from("ETH-USDT-SWAP"),
2767 inst_type: OKXInstrumentType::Swap,
2768 lever: "1".to_string(),
2769 ord_id: Ustr::from("123456"),
2770 ord_type: OKXOrderType::Limit,
2771 pnl: "0".to_string(),
2772 pos_side: OKXPositionSide::Net,
2773 px: "0".to_string(),
2774 reduce_only: "false".to_string(),
2775 side: OKXSide::Buy,
2776 state: OKXOrderStatus::Canceled,
2777 exec_type: OKXExecType::None,
2778 sz: "1".to_string(),
2779 td_mode: OKXTradeMode::Cross,
2780 tgt_ccy: None,
2781 trade_id: String::new(),
2782 u_time: 0,
2783 }
2784 }
2785
2786 #[rstest]
2787 fn test_is_post_only_auto_cancel_detects_cancel_source() {
2788 let mut msg = sample_canceled_order_msg();
2789 msg.cancel_source = Some(OKX_POST_ONLY_CANCEL_SOURCE.to_string());
2790
2791 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2792 }
2793
2794 #[rstest]
2795 fn test_is_post_only_auto_cancel_detects_reason() {
2796 let mut msg = sample_canceled_order_msg();
2797 msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());
2798
2799 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2800 }
2801
2802 #[rstest]
2803 fn test_is_post_only_auto_cancel_false_without_markers() {
2804 let msg = sample_canceled_order_msg();
2805
2806 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2807 }
2808
2809 #[rstest]
2810 fn test_is_post_only_auto_cancel_false_for_order_type_only() {
2811 let mut msg = sample_canceled_order_msg();
2812 msg.ord_type = OKXOrderType::PostOnly;
2813
2814 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2815 }
2816
2817 #[tokio::test]
2818 async fn test_batch_cancel_orders_with_multiple_orders() {
2819 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2820
2821 let client = OKXWebSocketClient::new(
2822 Some("wss://test.okx.com".to_string()),
2823 None,
2824 None,
2825 None,
2826 None,
2827 None,
2828 )
2829 .expect("Failed to create client");
2830
2831 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2832 let client_order_id1 = ClientOrderId::new("order1");
2833 let client_order_id2 = ClientOrderId::new("order2");
2834 let venue_order_id1 = VenueOrderId::new("venue1");
2835 let venue_order_id2 = VenueOrderId::new("venue2");
2836
2837 let orders = vec![
2838 (instrument_id, Some(client_order_id1), Some(venue_order_id1)),
2839 (instrument_id, Some(client_order_id2), Some(venue_order_id2)),
2840 ];
2841
2842 let result = client.batch_cancel_orders(orders).await;
2844
2845 assert!(result.is_err());
2847 }
2848
2849 #[tokio::test]
2850 async fn test_batch_cancel_orders_with_only_client_order_id() {
2851 use nautilus_model::identifiers::{ClientOrderId, InstrumentId};
2852
2853 let client = OKXWebSocketClient::new(
2854 Some("wss://test.okx.com".to_string()),
2855 None,
2856 None,
2857 None,
2858 None,
2859 None,
2860 )
2861 .expect("Failed to create client");
2862
2863 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2864 let client_order_id = ClientOrderId::new("order1");
2865
2866 let orders = vec![(instrument_id, Some(client_order_id), None)];
2867
2868 let result = client.batch_cancel_orders(orders).await;
2869
2870 assert!(result.is_err());
2872 }
2873
2874 #[tokio::test]
2875 async fn test_batch_cancel_orders_with_only_venue_order_id() {
2876 use nautilus_model::identifiers::{InstrumentId, VenueOrderId};
2877
2878 let client = OKXWebSocketClient::new(
2879 Some("wss://test.okx.com".to_string()),
2880 None,
2881 None,
2882 None,
2883 None,
2884 None,
2885 )
2886 .expect("Failed to create client");
2887
2888 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2889 let venue_order_id = VenueOrderId::new("venue1");
2890
2891 let orders = vec![(instrument_id, None, Some(venue_order_id))];
2892
2893 let result = client.batch_cancel_orders(orders).await;
2894
2895 assert!(result.is_err());
2897 }
2898
2899 #[tokio::test]
2900 async fn test_batch_cancel_orders_with_both_ids() {
2901 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2902
2903 let client = OKXWebSocketClient::new(
2904 Some("wss://test.okx.com".to_string()),
2905 None,
2906 None,
2907 None,
2908 None,
2909 None,
2910 )
2911 .expect("Failed to create client");
2912
2913 let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");
2914 let client_order_id = ClientOrderId::new("order1");
2915 let venue_order_id = VenueOrderId::new("venue1");
2916
2917 let orders = vec![(instrument_id, Some(client_order_id), Some(venue_order_id))];
2918
2919 let result = client.batch_cancel_orders(orders).await;
2920
2921 assert!(result.is_err());
2923 }
2924
2925 #[rstest]
2926 fn test_race_unsubscribe_failure_recovery() {
2927 let client = OKXWebSocketClient::new(
2933 Some("wss://test.okx.com".to_string()),
2934 None,
2935 None,
2936 None,
2937 None,
2938 None,
2939 )
2940 .expect("Failed to create client");
2941
2942 let topic = "trades:BTC-USDT-SWAP";
2943
2944 client.subscriptions_state.mark_subscribe(topic);
2946 client.subscriptions_state.confirm_subscribe(topic);
2947 assert_eq!(client.subscriptions_state.len(), 1);
2948
2949 client.subscriptions_state.mark_unsubscribe(topic);
2951 assert_eq!(client.subscriptions_state.len(), 0);
2952 assert_eq!(
2953 client.subscriptions_state.pending_unsubscribe_topics(),
2954 vec![topic]
2955 );
2956
2957 client.subscriptions_state.confirm_unsubscribe(topic); client.subscriptions_state.mark_subscribe(topic); client.subscriptions_state.confirm_subscribe(topic); assert_eq!(client.subscriptions_state.len(), 1);
2965 assert!(
2966 client
2967 .subscriptions_state
2968 .pending_unsubscribe_topics()
2969 .is_empty()
2970 );
2971 assert!(
2972 client
2973 .subscriptions_state
2974 .pending_subscribe_topics()
2975 .is_empty()
2976 );
2977
2978 let all = client.subscriptions_state.all_topics();
2980 assert_eq!(all.len(), 1);
2981 assert!(all.contains(&topic.to_string()));
2982 }
2983
2984 #[rstest]
2985 fn test_race_resubscribe_before_unsubscribe_ack() {
2986 let client = OKXWebSocketClient::new(
2990 Some("wss://test.okx.com".to_string()),
2991 None,
2992 None,
2993 None,
2994 None,
2995 None,
2996 )
2997 .expect("Failed to create client");
2998
2999 let topic = "books:BTC-USDT";
3000
3001 client.subscriptions_state.mark_subscribe(topic);
3003 client.subscriptions_state.confirm_subscribe(topic);
3004 assert_eq!(client.subscriptions_state.len(), 1);
3005
3006 client.subscriptions_state.mark_unsubscribe(topic);
3008 assert_eq!(client.subscriptions_state.len(), 0);
3009 assert_eq!(
3010 client.subscriptions_state.pending_unsubscribe_topics(),
3011 vec![topic]
3012 );
3013
3014 client.subscriptions_state.mark_subscribe(topic);
3016 assert_eq!(
3017 client.subscriptions_state.pending_subscribe_topics(),
3018 vec![topic]
3019 );
3020
3021 client.subscriptions_state.confirm_unsubscribe(topic);
3023 assert!(
3024 client
3025 .subscriptions_state
3026 .pending_unsubscribe_topics()
3027 .is_empty()
3028 );
3029 assert_eq!(
3030 client.subscriptions_state.pending_subscribe_topics(),
3031 vec![topic]
3032 );
3033
3034 client.subscriptions_state.confirm_subscribe(topic);
3036 assert_eq!(client.subscriptions_state.len(), 1);
3037 assert!(
3038 client
3039 .subscriptions_state
3040 .pending_subscribe_topics()
3041 .is_empty()
3042 );
3043
3044 let all = client.subscriptions_state.all_topics();
3046 assert_eq!(all.len(), 1);
3047 assert!(all.contains(&topic.to_string()));
3048 }
3049
3050 #[rstest]
3051 fn test_race_late_subscribe_confirmation_after_unsubscribe() {
3052 let client = OKXWebSocketClient::new(
3055 Some("wss://test.okx.com".to_string()),
3056 None,
3057 None,
3058 None,
3059 None,
3060 None,
3061 )
3062 .expect("Failed to create client");
3063
3064 let topic = "tickers:ETH-USDT";
3065
3066 client.subscriptions_state.mark_subscribe(topic);
3068 assert_eq!(
3069 client.subscriptions_state.pending_subscribe_topics(),
3070 vec![topic]
3071 );
3072
3073 client.subscriptions_state.mark_unsubscribe(topic);
3075 assert!(
3076 client
3077 .subscriptions_state
3078 .pending_subscribe_topics()
3079 .is_empty()
3080 ); assert_eq!(
3082 client.subscriptions_state.pending_unsubscribe_topics(),
3083 vec![topic]
3084 );
3085
3086 client.subscriptions_state.confirm_subscribe(topic);
3088 assert_eq!(client.subscriptions_state.len(), 0); assert_eq!(
3090 client.subscriptions_state.pending_unsubscribe_topics(),
3091 vec![topic]
3092 );
3093
3094 client.subscriptions_state.confirm_unsubscribe(topic);
3096
3097 assert!(client.subscriptions_state.is_empty());
3099 assert!(client.subscriptions_state.all_topics().is_empty());
3100 }
3101
3102 #[rstest]
3103 fn test_race_reconnection_with_pending_states() {
3104 let client = OKXWebSocketClient::new(
3106 Some("wss://test.okx.com".to_string()),
3107 Some("test_key".to_string()),
3108 Some("test_secret".to_string()),
3109 Some("test_passphrase".to_string()),
3110 Some(AccountId::new("OKX-TEST")),
3111 None,
3112 )
3113 .expect("Failed to create client");
3114
3115 let trade_btc = "trades:BTC-USDT-SWAP";
3118 client.subscriptions_state.mark_subscribe(trade_btc);
3119 client.subscriptions_state.confirm_subscribe(trade_btc);
3120
3121 let trade_eth = "trades:ETH-USDT-SWAP";
3123 client.subscriptions_state.mark_subscribe(trade_eth);
3124
3125 let book_btc = "books:BTC-USDT";
3127 client.subscriptions_state.mark_subscribe(book_btc);
3128 client.subscriptions_state.confirm_subscribe(book_btc);
3129 client.subscriptions_state.mark_unsubscribe(book_btc);
3130
3131 let topics_to_restore = client.subscriptions_state.all_topics();
3133
3134 assert_eq!(topics_to_restore.len(), 2);
3136 assert!(topics_to_restore.contains(&trade_btc.to_string()));
3137 assert!(topics_to_restore.contains(&trade_eth.to_string()));
3138 assert!(!topics_to_restore.contains(&book_btc.to_string())); }
3140
3141 #[rstest]
3142 fn test_race_duplicate_subscribe_messages_idempotent() {
3143 let client = OKXWebSocketClient::new(
3146 Some("wss://test.okx.com".to_string()),
3147 None,
3148 None,
3149 None,
3150 None,
3151 None,
3152 )
3153 .expect("Failed to create client");
3154
3155 let topic = "trades:BTC-USDT-SWAP";
3156
3157 client.subscriptions_state.mark_subscribe(topic);
3159 client.subscriptions_state.confirm_subscribe(topic);
3160 assert_eq!(client.subscriptions_state.len(), 1);
3161
3162 client.subscriptions_state.mark_subscribe(topic);
3164 assert!(
3165 client
3166 .subscriptions_state
3167 .pending_subscribe_topics()
3168 .is_empty()
3169 ); assert_eq!(client.subscriptions_state.len(), 1); client.subscriptions_state.confirm_subscribe(topic);
3174 assert_eq!(client.subscriptions_state.len(), 1);
3175
3176 let all = client.subscriptions_state.all_topics();
3178 assert_eq!(all.len(), 1);
3179 assert_eq!(all[0], topic);
3180 }
3181}