1use std::{
26 fmt::Debug,
27 num::NonZeroU32,
28 sync::{
29 Arc, LazyLock,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32 time::{Duration, SystemTime},
33};
34
35use ahash::AHashSet;
36use arc_swap::ArcSwap;
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::live::get_runtime;
40use nautilus_core::{
41 consts::NAUTILUS_USER_AGENT,
42 env::{get_env_var, get_or_env_var},
43};
44use nautilus_model::{
45 data::BarType,
46 enums::{OrderSide, OrderType, PositionSide, TimeInForce, TriggerType},
47 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
48 instruments::{Instrument, InstrumentAny},
49 types::{Price, Quantity},
50};
51use nautilus_network::{
52 http::USER_AGENT,
53 mode::ConnectionMode,
54 ratelimiter::quota::Quota,
55 websocket::{
56 AUTHENTICATION_TIMEOUT_SECS, AuthTracker, PingHandler, SubscriptionState, TEXT_PING,
57 WebSocketClient, WebSocketConfig, channel_message_handler,
58 },
59};
60use serde_json::Value;
61use tokio_tungstenite::tungstenite::Error;
62use tokio_util::sync::CancellationToken;
63use ustr::Ustr;
64
65use super::{
66 enums::OKXWsChannel,
67 error::OKXWsError,
68 handler::{HandlerCommand, OKXWsFeedHandler},
69 messages::{
70 NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg, OKXSubscriptionArg,
71 WsAmendOrderParamsBuilder, WsCancelOrderParamsBuilder, WsPostAlgoOrderParamsBuilder,
72 WsPostOrderParamsBuilder,
73 },
74 subscription::topic_from_subscription_arg,
75};
76use crate::common::{
77 consts::{
78 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
79 OKX_WS_PUBLIC_URL, OKX_WS_TOPIC_DELIMITER,
80 },
81 credential::Credential,
82 enums::{
83 OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXTargetCurrency, OKXTradeMode,
84 OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type, is_conditional_order,
85 },
86 parse::{bar_spec_as_okx_channel, okx_instrument_type, okx_instrument_type_from_symbol},
87};
88
89pub static OKX_WS_CONNECTION_QUOTA: LazyLock<Quota> =
93 LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
94
95pub static OKX_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
100 LazyLock::new(|| Quota::per_hour(NonZeroU32::new(480).unwrap()));
101
102pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
107 LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
108
109pub static OKX_RATE_LIMIT_KEY_SUBSCRIPTION: LazyLock<[Ustr; 1]> =
114 LazyLock::new(|| [Ustr::from("subscription")]);
115
116pub static OKX_RATE_LIMIT_KEY_ORDER: LazyLock<[Ustr; 1]> = LazyLock::new(|| [Ustr::from("order")]);
121
122pub static OKX_RATE_LIMIT_KEY_CANCEL: LazyLock<[Ustr; 1]> =
128 LazyLock::new(|| [Ustr::from("cancel")]);
129
130pub static OKX_RATE_LIMIT_KEY_AMEND: LazyLock<[Ustr; 1]> = LazyLock::new(|| [Ustr::from("amend")]);
134
/// WebSocket client for OKX.
///
/// Holds the shared state (subscription tracking, caches, command channel) that is
/// handed to a spawned `OKXWsFeedHandler` task when connecting. Deriving `Clone`
/// means the `Arc`-wrapped fields of a clone share state with the original.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.okx", from_py_object)
)]
pub struct OKXWebSocketClient {
    /// Endpoint URL used when connecting.
    url: String,
    /// Account identifier passed to the feed handler.
    account_id: AccountId,
    /// Raw `OKXVipLevel` value; selects which order-book channel depth subscriptions use.
    vip_level: Arc<AtomicU8>,
    /// API credentials; when present, `connect` authenticates after connecting.
    credential: Option<Credential>,
    /// Heartbeat setting forwarded to the WebSocket config.
    heartbeat: Option<u64>,
    /// Tracks pending authentication attempts and their outcome.
    auth_tracker: AuthTracker,
    /// Stop flag set by `close`; read by `is_active`, `is_closed` and the handler task.
    signal: Arc<AtomicBool>,
    /// Connection-mode atomic; swapped for the live client's atomic on `connect`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender for commands to the feed handler; replaced on every `connect`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver for messages produced by the handler; taken by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Handle of the spawned handler task; awaited or aborted by `close`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Tracked subscriptions keyed by channel, selected by instrument type.
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    /// Tracked subscriptions keyed by channel, selected by instrument family.
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Tracked subscriptions keyed by channel, selected by instrument ID.
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Channels subscribed without any instrument selector.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    /// Topic-level subscription state, also handed to the feed handler.
    subscriptions_state: SubscriptionState,
    /// Counter backing `generate_unique_request_id`.
    request_id_counter: Arc<AtomicU64>,
    /// Shared with the feed handler.
    /// NOTE(review): presumably tracks open client orders' trader/strategy/instrument — confirm.
    active_client_orders: Arc<DashMap<ClientOrderId, (TraderId, StrategyId, InstrumentId)>>,
    /// Shared with the feed handler.
    /// NOTE(review): presumably maps alias client order IDs to canonical ones — confirm.
    client_id_aliases: Arc<DashMap<ClientOrderId, ClientOrderId>>,
    /// Instruments keyed by symbol; replayed to the handler on `connect`.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    /// Instrument ID to OKX instrument ID code, shared with the feed handler.
    inst_id_code_cache: Arc<DashMap<Ustr, u64>>,
    /// Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
165
166impl Default for OKXWebSocketClient {
167 fn default() -> Self {
168 Self::new(None, None, None, None, None, None).unwrap()
169 }
170}
171
172impl Debug for OKXWebSocketClient {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 f.debug_struct(stringify!(OKXWebSocketClient))
175 .field("url", &self.url)
176 .field(
177 "credential",
178 &self.credential.as_ref().map(|_| "<redacted>"),
179 )
180 .field("heartbeat", &self.heartbeat)
181 .finish_non_exhaustive()
182 }
183}
184
185impl OKXWebSocketClient {
186 pub fn new(
192 url: Option<String>,
193 api_key: Option<String>,
194 api_secret: Option<String>,
195 api_passphrase: Option<String>,
196 account_id: Option<AccountId>,
197 heartbeat: Option<u64>,
198 ) -> anyhow::Result<Self> {
199 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
200 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
201
202 let credential = match (api_key, api_secret, api_passphrase) {
203 (Some(key), Some(secret), Some(passphrase)) => {
204 Some(Credential::new(key, secret, passphrase))
205 }
206 (None, None, None) => None,
207 _ => anyhow::bail!(
208 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
209 ),
210 };
211
212 let signal = Arc::new(AtomicBool::new(false));
213 let subscriptions_inst_type = Arc::new(DashMap::new());
214 let subscriptions_inst_family = Arc::new(DashMap::new());
215 let subscriptions_inst_id = Arc::new(DashMap::new());
216 let subscriptions_bare = Arc::new(DashMap::new());
217 let subscriptions_state = SubscriptionState::new(OKX_WS_TOPIC_DELIMITER);
218
219 Ok(Self {
220 url,
221 account_id,
222 vip_level: Arc::new(AtomicU8::new(0)), credential,
224 heartbeat,
225 auth_tracker: AuthTracker::new(),
226 signal,
227 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
228 ConnectionMode::Closed.as_u8(),
229 ))),
230 cmd_tx: {
231 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
233 Arc::new(tokio::sync::RwLock::new(tx))
234 },
235 out_rx: None,
236 task_handle: None,
237 subscriptions_inst_type,
238 subscriptions_inst_family,
239 subscriptions_inst_id,
240 subscriptions_bare,
241 subscriptions_state,
242 request_id_counter: Arc::new(AtomicU64::new(1)),
243 active_client_orders: Arc::new(DashMap::new()),
244 client_id_aliases: Arc::new(DashMap::new()),
245 instruments_cache: Arc::new(DashMap::new()),
246 inst_id_code_cache: Arc::new(DashMap::new()),
247 cancellation_token: CancellationToken::new(),
248 })
249 }
250
251 pub fn with_credentials(
258 url: Option<String>,
259 api_key: Option<String>,
260 api_secret: Option<String>,
261 api_passphrase: Option<String>,
262 account_id: Option<AccountId>,
263 heartbeat: Option<u64>,
264 ) -> anyhow::Result<Self> {
265 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
266 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
267 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
268 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
269
270 Self::new(
271 Some(url),
272 Some(api_key),
273 Some(api_secret),
274 Some(api_passphrase),
275 account_id,
276 heartbeat,
277 )
278 }
279
280 pub fn from_env() -> anyhow::Result<Self> {
287 let url = get_env_var("OKX_WS_URL")?;
288 let api_key = get_env_var("OKX_API_KEY")?;
289 let api_secret = get_env_var("OKX_API_SECRET")?;
290 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
291
292 Self::new(
293 Some(url),
294 Some(api_key),
295 Some(api_secret),
296 Some(api_passphrase),
297 None,
298 None,
299 )
300 }
301
302 pub fn cancel_all_requests(&self) {
304 self.cancellation_token.cancel();
305 }
306
307 pub fn cancellation_token(&self) -> &CancellationToken {
309 &self.cancellation_token
310 }
311
312 pub fn url(&self) -> &str {
314 self.url.as_str()
315 }
316
317 pub fn api_key(&self) -> Option<&str> {
319 self.credential.clone().map(|c| c.api_key.as_str())
320 }
321
322 #[must_use]
324 pub fn api_key_masked(&self) -> Option<String> {
325 self.credential.clone().map(|c| c.api_key_masked())
326 }
327
328 pub fn is_active(&self) -> bool {
330 let connection_mode_arc = self.connection_mode.load();
331 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
332 && !self.signal.load(Ordering::Acquire)
333 }
334
335 pub fn is_closed(&self) -> bool {
337 let connection_mode_arc = self.connection_mode.load();
338 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
339 || self.signal.load(Ordering::Acquire)
340 }
341
342 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
346 for inst in &instruments {
347 self.instruments_cache
348 .insert(inst.symbol().inner(), inst.clone());
349 }
350
351 if !instruments.is_empty()
354 && let Ok(cmd_tx) = self.cmd_tx.try_read()
355 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
356 {
357 log::debug!("Failed to send bulk instrument update to handler: {e}");
358 }
359 }
360
361 pub fn cache_instrument(&self, instrument: InstrumentAny) {
365 self.instruments_cache
366 .insert(instrument.symbol().inner(), instrument.clone());
367
368 if let Ok(cmd_tx) = self.cmd_tx.try_read()
371 && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
372 {
373 log::debug!("Failed to send instrument update to handler: {e}");
374 }
375 }
376
377 pub fn cache_inst_id_code(&self, inst_id: Ustr, inst_id_code: u64) {
381 self.inst_id_code_cache.insert(inst_id, inst_id_code);
382 }
383
384 pub fn cache_inst_id_codes(&self, mappings: impl IntoIterator<Item = (Ustr, u64)>) {
388 for (inst_id, inst_id_code) in mappings {
389 self.inst_id_code_cache.insert(inst_id, inst_id_code);
390 }
391 }
392
393 #[must_use]
397 pub fn get_inst_id_code(&self, inst_id: &Ustr) -> Option<u64> {
398 self.inst_id_code_cache.get(inst_id).map(|r| *r.value())
399 }
400
401 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
405 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
406 }
407
408 pub fn vip_level(&self) -> OKXVipLevel {
410 let level = self.vip_level.load(Ordering::Relaxed);
411 OKXVipLevel::from(level)
412 }
413
414 pub async fn connect(&mut self) -> anyhow::Result<()> {
424 let (message_handler, raw_rx) = channel_message_handler();
425
426 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
429 });
431
432 let config = WebSocketConfig {
433 url: self.url.clone(),
434 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
435 heartbeat: self.heartbeat,
436 heartbeat_msg: Some(TEXT_PING.to_string()),
437 reconnect_timeout_ms: Some(5_000),
438 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, reconnect_max_attempts: None,
443 };
444
445 let keyed_quotas = vec![
447 (
448 OKX_RATE_LIMIT_KEY_SUBSCRIPTION[0].as_str().to_string(),
449 *OKX_WS_SUBSCRIPTION_QUOTA,
450 ),
451 (
452 OKX_RATE_LIMIT_KEY_ORDER[0].as_str().to_string(),
453 *OKX_WS_ORDER_QUOTA,
454 ),
455 (
456 OKX_RATE_LIMIT_KEY_CANCEL[0].as_str().to_string(),
457 *OKX_WS_ORDER_QUOTA,
458 ),
459 (
460 OKX_RATE_LIMIT_KEY_AMEND[0].as_str().to_string(),
461 *OKX_WS_ORDER_QUOTA,
462 ),
463 ];
464
465 let client = WebSocketClient::connect(
466 config,
467 Some(message_handler),
468 Some(ping_handler),
469 None, keyed_quotas,
471 Some(*OKX_WS_CONNECTION_QUOTA), )
473 .await?;
474
475 self.connection_mode.store(client.connection_mode_atomic());
477
478 let account_id = self.account_id;
479 let (msg_tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
480
481 self.out_rx = Some(Arc::new(rx));
482
483 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
485 *self.cmd_tx.write().await = cmd_tx.clone();
486
487 if !self.instruments_cache.is_empty() {
489 let cached_instruments: Vec<InstrumentAny> = self
490 .instruments_cache
491 .iter()
492 .map(|entry| entry.value().clone())
493 .collect();
494 if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(cached_instruments)) {
495 log::error!("Failed to replay instruments to handler: {e}");
496 }
497 }
498
499 let signal = self.signal.clone();
500 let active_client_orders = self.active_client_orders.clone();
501 let auth_tracker = self.auth_tracker.clone();
502 let subscriptions_state = self.subscriptions_state.clone();
503 let client_id_aliases = self.client_id_aliases.clone();
504 let inst_id_code_cache = self.inst_id_code_cache.clone();
505
506 let stream_handle = get_runtime().spawn({
507 let auth_tracker = auth_tracker.clone();
508 let signal = signal.clone();
509 let credential = self.credential.clone();
510 let cmd_tx_for_reconnect = cmd_tx.clone();
511 let subscriptions_bare = self.subscriptions_bare.clone();
512 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
513 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
514 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
515 let mut has_reconnected = false;
516
517 async move {
518 let mut handler = OKXWsFeedHandler::new(
519 account_id,
520 signal.clone(),
521 cmd_rx,
522 raw_rx,
523 msg_tx,
524 active_client_orders,
525 client_id_aliases,
526 inst_id_code_cache,
527 auth_tracker.clone(),
528 subscriptions_state.clone(),
529 );
530
531 let resubscribe_all = || {
533 for entry in subscriptions_inst_id.iter() {
534 let (channel, inst_ids) = entry.pair();
535 for inst_id in inst_ids {
536 let arg = OKXSubscriptionArg {
537 channel: channel.clone(),
538 inst_type: None,
539 inst_family: None,
540 inst_id: Some(*inst_id),
541 };
542 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
543 log::error!("Failed to send resubscribe command: error={e}");
544 }
545 }
546 }
547
548 for entry in subscriptions_bare.iter() {
549 let channel = entry.key();
550 let arg = OKXSubscriptionArg {
551 channel: channel.clone(),
552 inst_type: None,
553 inst_family: None,
554 inst_id: None,
555 };
556 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
557 log::error!("Failed to send resubscribe command: error={e}");
558 }
559 }
560
561 for entry in subscriptions_inst_type.iter() {
562 let (channel, inst_types) = entry.pair();
563 for inst_type in inst_types {
564 let arg = OKXSubscriptionArg {
565 channel: channel.clone(),
566 inst_type: Some(*inst_type),
567 inst_family: None,
568 inst_id: None,
569 };
570 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
571 log::error!("Failed to send resubscribe command: error={e}");
572 }
573 }
574 }
575
576 for entry in subscriptions_inst_family.iter() {
577 let (channel, inst_families) = entry.pair();
578 for inst_family in inst_families {
579 let arg = OKXSubscriptionArg {
580 channel: channel.clone(),
581 inst_type: None,
582 inst_family: Some(*inst_family),
583 inst_id: None,
584 };
585 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
586 log::error!("Failed to send resubscribe command: error={e}");
587 }
588 }
589 }
590 };
591
592 loop {
594 match handler.next().await {
595 Some(NautilusWsMessage::Reconnected) => {
596 if signal.load(Ordering::Acquire) {
597 continue;
598 }
599
600 has_reconnected = true;
601
602 let confirmed_topics_vec: Vec<String> = {
604 let confirmed = subscriptions_state.confirmed();
605 let mut topics = Vec::new();
606 for entry in confirmed.iter() {
607 let channel = entry.key();
608 for symbol in entry.value() {
609 if symbol.as_str() == "#" {
610 topics.push(channel.to_string());
611 } else {
612 topics.push(format!("{channel}{OKX_WS_TOPIC_DELIMITER}{symbol}"));
613 }
614 }
615 }
616 topics
617 };
618
619 if !confirmed_topics_vec.is_empty() {
620 log::debug!("Marking confirmed subscriptions as pending for replay: count={}", confirmed_topics_vec.len());
621 for topic in confirmed_topics_vec {
622 subscriptions_state.mark_failure(&topic);
623 }
624 }
625
626 if let Some(cred) = &credential {
627 log::debug!("Re-authenticating after reconnection");
628 let timestamp = std::time::SystemTime::now()
629 .duration_since(std::time::SystemTime::UNIX_EPOCH)
630 .expect("System time should be after UNIX epoch")
631 .as_secs()
632 .to_string();
633 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
634
635 let auth_message = super::messages::OKXAuthentication {
636 op: "login",
637 args: vec![super::messages::OKXAuthenticationArg {
638 api_key: cred.api_key.to_string(),
639 passphrase: cred.api_passphrase.clone(),
640 timestamp,
641 sign: signature,
642 }],
643 };
644
645 if let Ok(payload) = serde_json::to_string(&auth_message) {
646 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Authenticate { payload }) {
647 log::error!("Failed to send reconnection auth command: error={e}");
648 }
649 } else {
650 log::error!("Failed to serialize reconnection auth message");
651 }
652 }
653
654 if credential.is_none() {
657 log::debug!("No authentication required, resubscribing immediately");
658 resubscribe_all();
659 }
660
661 continue;
666 }
667 Some(NautilusWsMessage::Authenticated) => {
668 if has_reconnected {
669 resubscribe_all();
670 }
671
672 continue;
677 }
678 Some(msg) => {
679 if handler.send(msg).is_err() {
680 log::error!(
681 "Failed to send message through channel: receiver dropped",
682 );
683 break;
684 }
685 }
686 None => {
687 if handler.is_stopped() {
688 log::debug!(
689 "Stop signal received, ending message processing",
690 );
691 break;
692 }
693 log::debug!("WebSocket stream closed");
694 break;
695 }
696 }
697 }
698
699 log::debug!("Handler task exiting");
700 }
701 });
702
703 self.task_handle = Some(Arc::new(stream_handle));
704
705 self.cmd_tx
706 .read()
707 .await
708 .send(HandlerCommand::SetClient(client))
709 .map_err(|e| {
710 OKXWsError::ClientError(format!("Failed to send WebSocket client to handler: {e}"))
711 })?;
712 log::debug!("Sent WebSocket client to handler");
713
714 if self.credential.is_some()
715 && let Err(e) = self.authenticate().await
716 {
717 anyhow::bail!("Authentication failed: {e}");
718 }
719
720 Ok(())
721 }
722
723 async fn authenticate(&self) -> Result<(), Error> {
725 let credential = self.credential.as_ref().ok_or_else(|| {
726 Error::Io(std::io::Error::other(
727 "API credentials not available to authenticate",
728 ))
729 })?;
730
731 let rx = self.auth_tracker.begin();
732
733 let timestamp = SystemTime::now()
734 .duration_since(SystemTime::UNIX_EPOCH)
735 .expect("System time should be after UNIX epoch")
736 .as_secs()
737 .to_string();
738 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
739
740 let auth_message = OKXAuthentication {
741 op: "login",
742 args: vec![OKXAuthenticationArg {
743 api_key: credential.api_key.to_string(),
744 passphrase: credential.api_passphrase.clone(),
745 timestamp,
746 sign: signature,
747 }],
748 };
749
750 let payload = serde_json::to_string(&auth_message).map_err(|e| {
751 Error::Io(std::io::Error::other(format!(
752 "Failed to serialize auth message: {e}"
753 )))
754 })?;
755
756 self.cmd_tx
757 .read()
758 .await
759 .send(HandlerCommand::Authenticate { payload })
760 .map_err(|e| {
761 Error::Io(std::io::Error::other(format!(
762 "Failed to send authenticate command: {e}"
763 )))
764 })?;
765
766 match self
767 .auth_tracker
768 .wait_for_result::<OKXWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
769 .await
770 {
771 Ok(()) => {
772 log::info!("WebSocket authenticated");
773 Ok(())
774 }
775 Err(e) => {
776 log::error!("WebSocket authentication failed: error={e}");
777 Err(Error::Io(std::io::Error::other(e.to_string())))
778 }
779 }
780 }
781
782 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
790 let rx = self
791 .out_rx
792 .take()
793 .expect("Data stream receiver already taken or not connected");
794 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
795 async_stream::stream! {
796 while let Some(data) = rx.recv().await {
797 yield data;
798 }
799 }
800 }
801
802 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
808 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
809
810 tokio::time::timeout(timeout, async {
811 while !self.is_active() {
812 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
813 }
814 })
815 .await
816 .map_err(|_| {
817 OKXWsError::ClientError(format!(
818 "WebSocket connection timeout after {timeout_secs} seconds"
819 ))
820 })?;
821
822 Ok(())
823 }
824
825 pub async fn close(&mut self) -> Result<(), Error> {
832 log::debug!("Starting close process");
833
834 self.signal.store(true, Ordering::Release);
835
836 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
837 log::warn!("Failed to send disconnect command to handler: {e}");
838 } else {
839 log::debug!("Sent disconnect command to handler");
840 }
841
842 {
844 if false {
845 log::debug!("No active connection to disconnect");
846 }
847 }
848
849 if let Some(stream_handle) = self.task_handle.take() {
851 match Arc::try_unwrap(stream_handle) {
852 Ok(handle) => {
853 log::debug!("Waiting for stream handle to complete");
854 match tokio::time::timeout(Duration::from_secs(2), handle).await {
855 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
856 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
857 Err(_) => {
858 log::warn!(
859 "Timeout waiting for stream handle, task may still be running"
860 );
861 }
863 }
864 }
865 Err(arc_handle) => {
866 log::debug!(
867 "Cannot take ownership of stream handle - other references exist, aborting task"
868 );
869 arc_handle.abort();
870 }
871 }
872 } else {
873 log::debug!("No stream handle to await");
874 }
875
876 log::debug!("Close process completed");
877
878 Ok(())
879 }
880
881 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
883 let symbol = instrument_id.symbol.inner();
884 let mut channels = Vec::new();
885
886 for entry in self.subscriptions_inst_id.iter() {
887 let (channel, instruments) = entry.pair();
888 if instruments.contains(&symbol) {
889 channels.push(channel.clone());
890 }
891 }
892
893 channels
894 }
895
896 fn generate_unique_request_id(&self) -> String {
897 self.request_id_counter
898 .fetch_add(1, Ordering::SeqCst)
899 .to_string()
900 }
901
902 #[allow(
903 clippy::result_large_err,
904 reason = "OKXWsError contains large tungstenite::Error variant"
905 )]
906 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
907 for arg in &args {
908 let topic = topic_from_subscription_arg(arg);
909 self.subscriptions_state.mark_subscribe(&topic);
910
911 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
913 self.subscriptions_bare.insert(arg.channel.clone(), true);
915 } else {
916 if let Some(inst_type) = &arg.inst_type {
918 self.subscriptions_inst_type
919 .entry(arg.channel.clone())
920 .or_default()
921 .insert(*inst_type);
922 }
923
924 if let Some(inst_family) = &arg.inst_family {
926 self.subscriptions_inst_family
927 .entry(arg.channel.clone())
928 .or_default()
929 .insert(*inst_family);
930 }
931
932 if let Some(inst_id) = &arg.inst_id {
934 self.subscriptions_inst_id
935 .entry(arg.channel.clone())
936 .or_default()
937 .insert(*inst_id);
938 }
939 }
940 }
941
942 self.cmd_tx
943 .read()
944 .await
945 .send(HandlerCommand::Subscribe { args })
946 .map_err(|e| OKXWsError::ClientError(format!("Failed to send subscribe command: {e}")))
947 }
948
949 #[allow(clippy::collapsible_if)]
950 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
951 for arg in &args {
952 let topic = topic_from_subscription_arg(arg);
953 self.subscriptions_state.mark_unsubscribe(&topic);
954
955 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
957 self.subscriptions_bare.remove(&arg.channel);
959 } else {
960 if let Some(inst_type) = &arg.inst_type {
962 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
963 entry.remove(inst_type);
964 if entry.is_empty() {
965 drop(entry);
966 self.subscriptions_inst_type.remove(&arg.channel);
967 }
968 }
969 }
970
971 if let Some(inst_family) = &arg.inst_family {
973 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
974 entry.remove(inst_family);
975 if entry.is_empty() {
976 drop(entry);
977 self.subscriptions_inst_family.remove(&arg.channel);
978 }
979 }
980 }
981
982 if let Some(inst_id) = &arg.inst_id {
984 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
985 entry.remove(inst_id);
986 if entry.is_empty() {
987 drop(entry);
988 self.subscriptions_inst_id.remove(&arg.channel);
989 }
990 }
991 }
992 }
993 }
994
995 self.cmd_tx
996 .read()
997 .await
998 .send(HandlerCommand::Unsubscribe { args })
999 .map_err(|e| {
1000 OKXWsError::ClientError(format!("Failed to send unsubscribe command: {e}"))
1001 })
1002 }
1003
1004 pub async fn unsubscribe_all(&self) -> Result<(), OKXWsError> {
1013 const BATCH_SIZE: usize = 256;
1014
1015 let mut all_args = Vec::new();
1016
1017 for entry in self.subscriptions_inst_type.iter() {
1018 let (channel, inst_types) = entry.pair();
1019 for inst_type in inst_types {
1020 all_args.push(OKXSubscriptionArg {
1021 channel: channel.clone(),
1022 inst_type: Some(*inst_type),
1023 inst_family: None,
1024 inst_id: None,
1025 });
1026 }
1027 }
1028
1029 for entry in self.subscriptions_inst_family.iter() {
1030 let (channel, inst_families) = entry.pair();
1031 for inst_family in inst_families {
1032 all_args.push(OKXSubscriptionArg {
1033 channel: channel.clone(),
1034 inst_type: None,
1035 inst_family: Some(*inst_family),
1036 inst_id: None,
1037 });
1038 }
1039 }
1040
1041 for entry in self.subscriptions_inst_id.iter() {
1042 let (channel, inst_ids) = entry.pair();
1043 for inst_id in inst_ids {
1044 all_args.push(OKXSubscriptionArg {
1045 channel: channel.clone(),
1046 inst_type: None,
1047 inst_family: None,
1048 inst_id: Some(*inst_id),
1049 });
1050 }
1051 }
1052
1053 for entry in self.subscriptions_bare.iter() {
1054 let channel = entry.key();
1055 all_args.push(OKXSubscriptionArg {
1056 channel: channel.clone(),
1057 inst_type: None,
1058 inst_family: None,
1059 inst_id: None,
1060 });
1061 }
1062
1063 if all_args.is_empty() {
1064 log::debug!("No active subscriptions to unsubscribe from");
1065 return Ok(());
1066 }
1067
1068 log::debug!("Batched unsubscribe from {} channels", all_args.len());
1069
1070 for chunk in all_args.chunks(BATCH_SIZE) {
1071 self.unsubscribe(chunk.to_vec()).await?;
1072 }
1073
1074 Ok(())
1075 }
1076
1077 pub async fn subscribe_instruments(
1089 &self,
1090 instrument_type: OKXInstrumentType,
1091 ) -> Result<(), OKXWsError> {
1092 let arg = OKXSubscriptionArg {
1093 channel: OKXWsChannel::Instruments,
1094 inst_type: Some(instrument_type),
1095 inst_family: None,
1096 inst_id: None,
1097 };
1098 self.subscribe(vec![arg]).await
1099 }
1100
1101 pub async fn subscribe_instrument(
1114 &self,
1115 instrument_id: InstrumentId,
1116 ) -> Result<(), OKXWsError> {
1117 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1118
1119 let already_subscribed = self
1120 .subscriptions_inst_type
1121 .get(&OKXWsChannel::Instruments)
1122 .is_some_and(|types| types.contains(&inst_type));
1123
1124 if already_subscribed {
1125 log::debug!("Already subscribed to instrument type {inst_type:?} for {instrument_id}");
1126 return Ok(());
1127 }
1128
1129 log::debug!("Subscribing to instrument type {inst_type:?} for {instrument_id}");
1130 self.subscribe_instruments(inst_type).await
1131 }
1132
1133 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1142 self.subscribe_book_with_depth(instrument_id, 0).await
1143 }
1144
1145 pub(crate) async fn subscribe_books_channel(
1147 &self,
1148 instrument_id: InstrumentId,
1149 ) -> Result<(), OKXWsError> {
1150 let arg = OKXSubscriptionArg {
1151 channel: OKXWsChannel::Books,
1152 inst_type: None,
1153 inst_family: None,
1154 inst_id: Some(instrument_id.symbol.inner()),
1155 };
1156 self.subscribe(vec![arg]).await
1157 }
1158
1159 pub async fn subscribe_book_depth5(
1171 &self,
1172 instrument_id: InstrumentId,
1173 ) -> Result<(), OKXWsError> {
1174 let arg = OKXSubscriptionArg {
1175 channel: OKXWsChannel::Books5,
1176 inst_type: None,
1177 inst_family: None,
1178 inst_id: Some(instrument_id.symbol.inner()),
1179 };
1180 self.subscribe(vec![arg]).await
1181 }
1182
1183 pub async fn subscribe_book50_l2_tbt(
1195 &self,
1196 instrument_id: InstrumentId,
1197 ) -> Result<(), OKXWsError> {
1198 let arg = OKXSubscriptionArg {
1199 channel: OKXWsChannel::Books50Tbt,
1200 inst_type: None,
1201 inst_family: None,
1202 inst_id: Some(instrument_id.symbol.inner()),
1203 };
1204 self.subscribe(vec![arg]).await
1205 }
1206
1207 pub async fn subscribe_book_l2_tbt(
1219 &self,
1220 instrument_id: InstrumentId,
1221 ) -> Result<(), OKXWsError> {
1222 let arg = OKXSubscriptionArg {
1223 channel: OKXWsChannel::BooksTbt,
1224 inst_type: None,
1225 inst_family: None,
1226 inst_id: Some(instrument_id.symbol.inner()),
1227 };
1228 self.subscribe(vec![arg]).await
1229 }
1230
1231 pub async fn subscribe_book_with_depth(
1245 &self,
1246 instrument_id: InstrumentId,
1247 depth: u16,
1248 ) -> anyhow::Result<()> {
1249 let vip = self.vip_level();
1250
1251 match depth {
1252 50 => {
1253 if vip < OKXVipLevel::Vip4 {
1254 anyhow::bail!(
1255 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1256 );
1257 }
1258 self.subscribe_book50_l2_tbt(instrument_id)
1259 .await
1260 .map_err(|e| anyhow::anyhow!(e))
1261 }
1262 0 | 400 => {
1263 if vip >= OKXVipLevel::Vip5 {
1264 self.subscribe_book_l2_tbt(instrument_id)
1265 .await
1266 .map_err(|e| anyhow::anyhow!(e))
1267 } else {
1268 self.subscribe_books_channel(instrument_id)
1269 .await
1270 .map_err(|e| anyhow::anyhow!(e))
1271 }
1272 }
1273 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1274 }
1275 }
1276
1277 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1290 let arg = OKXSubscriptionArg {
1291 channel: OKXWsChannel::BboTbt,
1292 inst_type: None,
1293 inst_family: None,
1294 inst_id: Some(instrument_id.symbol.inner()),
1295 };
1296 self.subscribe(vec![arg]).await
1297 }
1298
1299 pub async fn subscribe_trades(
1313 &self,
1314 instrument_id: InstrumentId,
1315 aggregated: bool,
1316 ) -> Result<(), OKXWsError> {
1317 let channel = if aggregated {
1318 OKXWsChannel::TradesAll
1319 } else {
1320 OKXWsChannel::Trades
1321 };
1322
1323 let arg = OKXSubscriptionArg {
1324 channel,
1325 inst_type: None,
1326 inst_family: None,
1327 inst_id: Some(instrument_id.symbol.inner()),
1328 };
1329 self.subscribe(vec![arg]).await
1330 }
1331
1332 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1344 let arg = OKXSubscriptionArg {
1345 channel: OKXWsChannel::Tickers,
1346 inst_type: None,
1347 inst_family: None,
1348 inst_id: Some(instrument_id.symbol.inner()),
1349 };
1350 self.subscribe(vec![arg]).await
1351 }
1352
1353 pub async fn subscribe_mark_prices(
1365 &self,
1366 instrument_id: InstrumentId,
1367 ) -> Result<(), OKXWsError> {
1368 let arg = OKXSubscriptionArg {
1369 channel: OKXWsChannel::MarkPrice,
1370 inst_type: None,
1371 inst_family: None,
1372 inst_id: Some(instrument_id.symbol.inner()),
1373 };
1374 self.subscribe(vec![arg]).await
1375 }
1376
1377 pub async fn subscribe_index_prices(
1389 &self,
1390 instrument_id: InstrumentId,
1391 ) -> Result<(), OKXWsError> {
1392 let arg = OKXSubscriptionArg {
1393 channel: OKXWsChannel::IndexTickers,
1394 inst_type: None,
1395 inst_family: None,
1396 inst_id: Some(instrument_id.symbol.inner()),
1397 };
1398 self.subscribe(vec![arg]).await
1399 }
1400
1401 pub async fn subscribe_funding_rates(
1413 &self,
1414 instrument_id: InstrumentId,
1415 ) -> Result<(), OKXWsError> {
1416 let arg = OKXSubscriptionArg {
1417 channel: OKXWsChannel::FundingRate,
1418 inst_type: None,
1419 inst_family: None,
1420 inst_id: Some(instrument_id.symbol.inner()),
1421 };
1422 self.subscribe(vec![arg]).await
1423 }
1424
1425 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1437 let channel = bar_spec_as_okx_channel(bar_type.spec())
1439 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1440
1441 let arg = OKXSubscriptionArg {
1442 channel,
1443 inst_type: None,
1444 inst_family: None,
1445 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1446 };
1447 self.subscribe(vec![arg]).await
1448 }
1449
1450 pub async fn unsubscribe_instruments(
1456 &self,
1457 instrument_type: OKXInstrumentType,
1458 ) -> Result<(), OKXWsError> {
1459 let arg = OKXSubscriptionArg {
1460 channel: OKXWsChannel::Instruments,
1461 inst_type: Some(instrument_type),
1462 inst_family: None,
1463 inst_id: None,
1464 };
1465 self.unsubscribe(vec![arg]).await
1466 }
1467
1468 pub async fn unsubscribe_instrument(
1474 &self,
1475 instrument_id: InstrumentId,
1476 ) -> Result<(), OKXWsError> {
1477 let arg = OKXSubscriptionArg {
1478 channel: OKXWsChannel::Instruments,
1479 inst_type: None,
1480 inst_family: None,
1481 inst_id: Some(instrument_id.symbol.inner()),
1482 };
1483 self.unsubscribe(vec![arg]).await
1484 }
1485
1486 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1492 let arg = OKXSubscriptionArg {
1493 channel: OKXWsChannel::Books,
1494 inst_type: None,
1495 inst_family: None,
1496 inst_id: Some(instrument_id.symbol.inner()),
1497 };
1498 self.unsubscribe(vec![arg]).await
1499 }
1500
1501 pub async fn unsubscribe_book_depth5(
1507 &self,
1508 instrument_id: InstrumentId,
1509 ) -> Result<(), OKXWsError> {
1510 let arg = OKXSubscriptionArg {
1511 channel: OKXWsChannel::Books5,
1512 inst_type: None,
1513 inst_family: None,
1514 inst_id: Some(instrument_id.symbol.inner()),
1515 };
1516 self.unsubscribe(vec![arg]).await
1517 }
1518
1519 pub async fn unsubscribe_book50_l2_tbt(
1525 &self,
1526 instrument_id: InstrumentId,
1527 ) -> Result<(), OKXWsError> {
1528 let arg = OKXSubscriptionArg {
1529 channel: OKXWsChannel::Books50Tbt,
1530 inst_type: None,
1531 inst_family: None,
1532 inst_id: Some(instrument_id.symbol.inner()),
1533 };
1534 self.unsubscribe(vec![arg]).await
1535 }
1536
1537 pub async fn unsubscribe_book_l2_tbt(
1543 &self,
1544 instrument_id: InstrumentId,
1545 ) -> Result<(), OKXWsError> {
1546 let arg = OKXSubscriptionArg {
1547 channel: OKXWsChannel::BooksTbt,
1548 inst_type: None,
1549 inst_family: None,
1550 inst_id: Some(instrument_id.symbol.inner()),
1551 };
1552 self.unsubscribe(vec![arg]).await
1553 }
1554
1555 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1561 let arg = OKXSubscriptionArg {
1562 channel: OKXWsChannel::BboTbt,
1563 inst_type: None,
1564 inst_family: None,
1565 inst_id: Some(instrument_id.symbol.inner()),
1566 };
1567 self.unsubscribe(vec![arg]).await
1568 }
1569
1570 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1576 let arg = OKXSubscriptionArg {
1577 channel: OKXWsChannel::Tickers,
1578 inst_type: None,
1579 inst_family: None,
1580 inst_id: Some(instrument_id.symbol.inner()),
1581 };
1582 self.unsubscribe(vec![arg]).await
1583 }
1584
1585 pub async fn unsubscribe_mark_prices(
1591 &self,
1592 instrument_id: InstrumentId,
1593 ) -> Result<(), OKXWsError> {
1594 let arg = OKXSubscriptionArg {
1595 channel: OKXWsChannel::MarkPrice,
1596 inst_type: None,
1597 inst_family: None,
1598 inst_id: Some(instrument_id.symbol.inner()),
1599 };
1600 self.unsubscribe(vec![arg]).await
1601 }
1602
1603 pub async fn unsubscribe_index_prices(
1609 &self,
1610 instrument_id: InstrumentId,
1611 ) -> Result<(), OKXWsError> {
1612 let arg = OKXSubscriptionArg {
1613 channel: OKXWsChannel::IndexTickers,
1614 inst_type: None,
1615 inst_family: None,
1616 inst_id: Some(instrument_id.symbol.inner()),
1617 };
1618 self.unsubscribe(vec![arg]).await
1619 }
1620
1621 pub async fn unsubscribe_funding_rates(
1627 &self,
1628 instrument_id: InstrumentId,
1629 ) -> Result<(), OKXWsError> {
1630 let arg = OKXSubscriptionArg {
1631 channel: OKXWsChannel::FundingRate,
1632 inst_type: None,
1633 inst_family: None,
1634 inst_id: Some(instrument_id.symbol.inner()),
1635 };
1636 self.unsubscribe(vec![arg]).await
1637 }
1638
1639 pub async fn unsubscribe_trades(
1645 &self,
1646 instrument_id: InstrumentId,
1647 aggregated: bool,
1648 ) -> Result<(), OKXWsError> {
1649 let channel = if aggregated {
1650 OKXWsChannel::TradesAll
1651 } else {
1652 OKXWsChannel::Trades
1653 };
1654
1655 let arg = OKXSubscriptionArg {
1656 channel,
1657 inst_type: None,
1658 inst_family: None,
1659 inst_id: Some(instrument_id.symbol.inner()),
1660 };
1661 self.unsubscribe(vec![arg]).await
1662 }
1663
1664 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1670 let channel = bar_spec_as_okx_channel(bar_type.spec())
1672 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1673
1674 let arg = OKXSubscriptionArg {
1675 channel,
1676 inst_type: None,
1677 inst_family: None,
1678 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1679 };
1680 self.unsubscribe(vec![arg]).await
1681 }
1682
1683 pub async fn subscribe_orders(
1689 &self,
1690 instrument_type: OKXInstrumentType,
1691 ) -> Result<(), OKXWsError> {
1692 let arg = OKXSubscriptionArg {
1693 channel: OKXWsChannel::Orders,
1694 inst_type: Some(instrument_type),
1695 inst_family: None,
1696 inst_id: None,
1697 };
1698 self.subscribe(vec![arg]).await
1699 }
1700
1701 pub async fn unsubscribe_orders(
1707 &self,
1708 instrument_type: OKXInstrumentType,
1709 ) -> Result<(), OKXWsError> {
1710 let arg = OKXSubscriptionArg {
1711 channel: OKXWsChannel::Orders,
1712 inst_type: Some(instrument_type),
1713 inst_family: None,
1714 inst_id: None,
1715 };
1716 self.unsubscribe(vec![arg]).await
1717 }
1718
1719 pub async fn subscribe_orders_algo(
1725 &self,
1726 instrument_type: OKXInstrumentType,
1727 ) -> Result<(), OKXWsError> {
1728 let arg = OKXSubscriptionArg {
1729 channel: OKXWsChannel::OrdersAlgo,
1730 inst_type: Some(instrument_type),
1731 inst_family: None,
1732 inst_id: None,
1733 };
1734 self.subscribe(vec![arg]).await
1735 }
1736
1737 pub async fn unsubscribe_orders_algo(
1743 &self,
1744 instrument_type: OKXInstrumentType,
1745 ) -> Result<(), OKXWsError> {
1746 let arg = OKXSubscriptionArg {
1747 channel: OKXWsChannel::OrdersAlgo,
1748 inst_type: Some(instrument_type),
1749 inst_family: None,
1750 inst_id: None,
1751 };
1752 self.unsubscribe(vec![arg]).await
1753 }
1754
1755 pub async fn subscribe_fills(
1761 &self,
1762 instrument_type: OKXInstrumentType,
1763 ) -> Result<(), OKXWsError> {
1764 let arg = OKXSubscriptionArg {
1765 channel: OKXWsChannel::Fills,
1766 inst_type: Some(instrument_type),
1767 inst_family: None,
1768 inst_id: None,
1769 };
1770 self.subscribe(vec![arg]).await
1771 }
1772
1773 pub async fn unsubscribe_fills(
1779 &self,
1780 instrument_type: OKXInstrumentType,
1781 ) -> Result<(), OKXWsError> {
1782 let arg = OKXSubscriptionArg {
1783 channel: OKXWsChannel::Fills,
1784 inst_type: Some(instrument_type),
1785 inst_family: None,
1786 inst_id: None,
1787 };
1788 self.unsubscribe(vec![arg]).await
1789 }
1790
1791 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1797 let arg = OKXSubscriptionArg {
1798 channel: OKXWsChannel::Account,
1799 inst_type: None,
1800 inst_family: None,
1801 inst_id: None,
1802 };
1803 self.subscribe(vec![arg]).await
1804 }
1805
1806 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1812 let arg = OKXSubscriptionArg {
1813 channel: OKXWsChannel::Account,
1814 inst_type: None,
1815 inst_family: None,
1816 inst_id: None,
1817 };
1818 self.unsubscribe(vec![arg]).await
1819 }
1820
1821 pub async fn subscribe_positions(
1831 &self,
1832 inst_type: OKXInstrumentType,
1833 ) -> Result<(), OKXWsError> {
1834 let arg = OKXSubscriptionArg {
1835 channel: OKXWsChannel::Positions,
1836 inst_type: Some(inst_type),
1837 inst_family: None,
1838 inst_id: None,
1839 };
1840 self.subscribe(vec![arg]).await
1841 }
1842
1843 pub async fn unsubscribe_positions(
1849 &self,
1850 inst_type: OKXInstrumentType,
1851 ) -> Result<(), OKXWsError> {
1852 let arg = OKXSubscriptionArg {
1853 channel: OKXWsChannel::Positions,
1854 inst_type: Some(inst_type),
1855 inst_family: None,
1856 inst_id: None,
1857 };
1858 self.unsubscribe(vec![arg]).await
1859 }
1860
1861 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1867 let request_id = self.generate_unique_request_id();
1868 let cmd = HandlerCommand::BatchPlaceOrders { args, request_id };
1869
1870 self.send_cmd(cmd).await
1871 }
1872
1873 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1879 let request_id = self.generate_unique_request_id();
1880 let cmd = HandlerCommand::BatchCancelOrders { args, request_id };
1881
1882 self.send_cmd(cmd).await
1883 }
1884
1885 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1891 let request_id = self.generate_unique_request_id();
1892 let cmd = HandlerCommand::BatchAmendOrders { args, request_id };
1893
1894 self.send_cmd(cmd).await
1895 }
1896
1897 #[allow(clippy::too_many_arguments)]
1909 pub async fn submit_order(
1910 &self,
1911 trader_id: TraderId,
1912 strategy_id: StrategyId,
1913 instrument_id: InstrumentId,
1914 td_mode: OKXTradeMode,
1915 client_order_id: ClientOrderId,
1916 order_side: OrderSide,
1917 order_type: OrderType,
1918 quantity: Quantity,
1919 time_in_force: Option<TimeInForce>,
1920 price: Option<Price>,
1921 trigger_price: Option<Price>,
1922 post_only: Option<bool>,
1923 reduce_only: Option<bool>,
1924 quote_quantity: Option<bool>,
1925 position_side: Option<PositionSide>,
1926 ) -> Result<(), OKXWsError> {
1927 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1928 return Err(OKXWsError::ClientError(format!(
1929 "Unsupported order type: {order_type:?}",
1930 )));
1931 }
1932
1933 if let Some(tif) = time_in_force
1934 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1935 {
1936 return Err(OKXWsError::ClientError(format!(
1937 "Unsupported time in force: {tif:?}",
1938 )));
1939 }
1940
1941 let mut builder = WsPostOrderParamsBuilder::default();
1942
1943 builder.inst_id(instrument_id.symbol.as_str());
1944
1945 if let Some(inst_id_code) = self.get_inst_id_code(&instrument_id.symbol.inner()) {
1947 builder.inst_id_code(inst_id_code);
1948 }
1949
1950 builder.td_mode(td_mode);
1951 builder.cl_ord_id(client_order_id.as_str());
1952
1953 let instrument = self
1954 .instruments_cache
1955 .get(&instrument_id.symbol.inner())
1956 .ok_or_else(|| {
1957 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1958 })?;
1959
1960 let instrument_type =
1961 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1962 let quote_currency = instrument.quote_currency();
1963
1964 match instrument_type {
1965 OKXInstrumentType::Spot => {
1966 builder.ccy(quote_currency.to_string());
1968 }
1969 OKXInstrumentType::Margin => {
1970 builder.ccy(quote_currency.to_string());
1971
1972 if let Some(ro) = reduce_only
1973 && ro
1974 {
1975 builder.reduce_only(ro);
1976 }
1977 }
1978 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1979 builder.ccy(quote_currency.to_string());
1981
1982 if position_side.is_none() {
1985 builder.pos_side(OKXPositionSide::Net);
1986 }
1987 }
1988 _ => {
1989 builder.ccy(quote_currency.to_string());
1990
1991 if position_side.is_none() {
1993 builder.pos_side(OKXPositionSide::Net);
1994 }
1995
1996 if let Some(ro) = reduce_only
1997 && ro
1998 {
1999 builder.reduce_only(ro);
2000 }
2001 }
2002 };
2003
2004 if instrument_type == OKXInstrumentType::Spot
2011 && order_type == OrderType::Market
2012 && td_mode == OKXTradeMode::Cash
2013 {
2014 match quote_quantity {
2015 Some(true) => {
2016 builder.tgt_ccy(OKXTargetCurrency::QuoteCcy);
2018 }
2019 Some(false) => {
2020 if order_side == OrderSide::Buy {
2021 builder.tgt_ccy(OKXTargetCurrency::BaseCcy);
2023 }
2024 }
2026 None => {
2027 }
2029 }
2030 }
2031
2032 builder.side(order_side);
2033
2034 if let Some(pos_side) = position_side {
2035 builder.pos_side(pos_side);
2036 };
2037
2038 let (okx_ord_type, price) = if post_only.unwrap_or(false) {
2042 (OKXOrderType::PostOnly, price)
2043 } else if let Some(tif) = time_in_force {
2044 match (order_type, tif) {
2045 (OrderType::Market, TimeInForce::Fok) => {
2046 return Err(OKXWsError::ClientError(
2047 "Market orders with FOK time-in-force are not supported by OKX. Use Limit order with FOK instead.".to_string()
2048 ));
2049 }
2050 (OrderType::Market, TimeInForce::Ioc) => {
2051 if instrument_type == OKXInstrumentType::Spot {
2053 (OKXOrderType::Market, price)
2054 } else {
2055 (OKXOrderType::OptimalLimitIoc, price)
2056 }
2057 }
2058 (OrderType::Limit, TimeInForce::Fok) => (OKXOrderType::Fok, price),
2059 (OrderType::Limit, TimeInForce::Ioc) => (OKXOrderType::Ioc, price),
2060 _ => (OKXOrderType::from(order_type), price),
2061 }
2062 } else {
2063 (OKXOrderType::from(order_type), price)
2064 };
2065
2066 log::debug!(
2067 "Order type mapping: order_type={order_type:?}, time_in_force={time_in_force:?}, post_only={post_only:?} -> okx_ord_type={okx_ord_type:?}"
2068 );
2069
2070 builder.ord_type(okx_ord_type);
2071 builder.sz(quantity.to_string());
2072
2073 if let Some(tp) = trigger_price {
2074 builder.px(tp.to_string());
2075 } else if let Some(p) = price {
2076 builder.px(p.to_string());
2077 }
2078
2079 builder.tag(OKX_NAUTILUS_BROKER_ID);
2080
2081 let params = builder
2082 .build()
2083 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2084
2085 self.active_client_orders
2086 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2087
2088 let cmd = HandlerCommand::PlaceOrder {
2089 params,
2090 client_order_id,
2091 trader_id,
2092 strategy_id,
2093 instrument_id,
2094 };
2095
2096 self.send_cmd(cmd).await
2097 }
2098
2099 #[allow(clippy::too_many_arguments)]
2115 pub async fn modify_order(
2116 &self,
2117 trader_id: TraderId,
2118 strategy_id: StrategyId,
2119 instrument_id: InstrumentId,
2120 client_order_id: Option<ClientOrderId>,
2121 price: Option<Price>,
2122 quantity: Option<Quantity>,
2123 venue_order_id: Option<VenueOrderId>,
2124 ) -> Result<(), OKXWsError> {
2125 let mut builder = WsAmendOrderParamsBuilder::default();
2126
2127 builder.inst_id(instrument_id.symbol.as_str());
2128
2129 if let Some(inst_id_code) = self.get_inst_id_code(&instrument_id.symbol.inner()) {
2131 builder.inst_id_code(inst_id_code);
2132 }
2133
2134 if let Some(venue_order_id) = venue_order_id {
2135 builder.ord_id(venue_order_id.as_str());
2136 }
2137
2138 if let Some(client_order_id) = client_order_id {
2139 builder.cl_ord_id(client_order_id.as_str());
2140 }
2141
2142 if let Some(price) = price {
2143 builder.new_px(price.to_string());
2144 }
2145
2146 if let Some(quantity) = quantity {
2147 builder.new_sz(quantity.to_string());
2148 }
2149
2150 let params = builder
2151 .build()
2152 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2153
2154 if let Some(client_order_id) = client_order_id {
2157 let cmd = HandlerCommand::AmendOrder {
2158 params,
2159 client_order_id,
2160 trader_id,
2161 strategy_id,
2162 instrument_id,
2163 venue_order_id,
2164 };
2165
2166 self.send_cmd(cmd).await
2167 } else {
2168 Err(OKXWsError::ClientError(
2170 "Cannot amend order without client_order_id".to_string(),
2171 ))
2172 }
2173 }
2174
2175 #[allow(clippy::too_many_arguments)]
2186 pub async fn cancel_order(
2187 &self,
2188 trader_id: TraderId,
2189 strategy_id: StrategyId,
2190 instrument_id: InstrumentId,
2191 client_order_id: Option<ClientOrderId>,
2192 venue_order_id: Option<VenueOrderId>,
2193 ) -> Result<(), OKXWsError> {
2194 let cmd = HandlerCommand::CancelOrder {
2195 client_order_id,
2196 venue_order_id,
2197 instrument_id,
2198 trader_id,
2199 strategy_id,
2200 };
2201
2202 self.send_cmd(cmd).await
2203 }
2204
2205 pub async fn mass_cancel_orders(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
2215 let cmd = HandlerCommand::MassCancel { instrument_id };
2216
2217 self.send_cmd(cmd).await
2218 }
2219
2220 #[allow(clippy::type_complexity)]
2227 #[allow(clippy::too_many_arguments)]
2228 pub async fn batch_submit_orders(
2229 &self,
2230 orders: Vec<(
2231 OKXInstrumentType,
2232 InstrumentId,
2233 OKXTradeMode,
2234 ClientOrderId,
2235 OrderSide,
2236 Option<PositionSide>,
2237 OrderType,
2238 Quantity,
2239 Option<Price>,
2240 Option<Price>,
2241 Option<bool>,
2242 Option<bool>,
2243 )>,
2244 ) -> Result<(), OKXWsError> {
2245 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2246 for (
2247 inst_type,
2248 inst_id,
2249 td_mode,
2250 cl_ord_id,
2251 ord_side,
2252 pos_side,
2253 ord_type,
2254 qty,
2255 pr,
2256 tp,
2257 post_only,
2258 reduce_only,
2259 ) in orders
2260 {
2261 let mut builder = WsPostOrderParamsBuilder::default();
2262 builder.inst_type(inst_type);
2263 builder.inst_id(inst_id.symbol.inner());
2264
2265 if let Some(inst_id_code) = self.get_inst_id_code(&inst_id.symbol.inner()) {
2267 builder.inst_id_code(inst_id_code);
2268 }
2269
2270 builder.td_mode(td_mode);
2271 builder.cl_ord_id(cl_ord_id.as_str());
2272 builder.side(ord_side);
2273
2274 if let Some(ps) = pos_side {
2275 builder.pos_side(OKXPositionSide::from(ps));
2276 }
2277
2278 let okx_ord_type = if post_only.unwrap_or(false) {
2279 OKXOrderType::PostOnly
2280 } else {
2281 OKXOrderType::from(ord_type)
2282 };
2283
2284 builder.ord_type(okx_ord_type);
2285 builder.sz(qty.to_string());
2286
2287 if let Some(p) = pr {
2288 builder.px(p.to_string());
2289 } else if let Some(p) = tp {
2290 builder.px(p.to_string());
2291 }
2292
2293 if let Some(ro) = reduce_only {
2294 builder.reduce_only(ro);
2295 }
2296
2297 builder.tag(OKX_NAUTILUS_BROKER_ID);
2298
2299 let params = builder
2300 .build()
2301 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2302 let val =
2303 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2304 args.push(val);
2305 }
2306
2307 self.ws_batch_place_orders(args).await
2308 }
2309
2310 #[allow(clippy::type_complexity)]
2317 #[allow(clippy::too_many_arguments)]
2318 pub async fn batch_modify_orders(
2319 &self,
2320 orders: Vec<(
2321 OKXInstrumentType,
2322 InstrumentId,
2323 ClientOrderId,
2324 ClientOrderId,
2325 Option<Price>,
2326 Option<Quantity>,
2327 )>,
2328 ) -> Result<(), OKXWsError> {
2329 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2330 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2331 let mut builder = WsAmendOrderParamsBuilder::default();
2332 builder.inst_id(inst_id.symbol.inner());
2334
2335 if let Some(inst_id_code) = self.get_inst_id_code(&inst_id.symbol.inner()) {
2337 builder.inst_id_code(inst_id_code);
2338 }
2339
2340 builder.cl_ord_id(cl_ord_id.as_str());
2341 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2342
2343 if let Some(p) = pr {
2344 builder.new_px(p.to_string());
2345 }
2346
2347 if let Some(q) = sz {
2348 builder.new_sz(q.to_string());
2349 }
2350
2351 let params = builder.build().map_err(|e| {
2352 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2353 })?;
2354 let val =
2355 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2356 args.push(val);
2357 }
2358
2359 self.ws_batch_amend_orders(args).await
2360 }
2361
2362 #[allow(clippy::type_complexity)]
2375 pub async fn batch_cancel_orders(
2376 &self,
2377 orders: Vec<(InstrumentId, Option<ClientOrderId>, Option<VenueOrderId>)>,
2378 ) -> Result<(), OKXWsError> {
2379 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2380 for (inst_id, cl_ord_id, ord_id) in orders {
2381 let mut builder = WsCancelOrderParamsBuilder::default();
2382 builder.inst_id(inst_id.symbol.inner());
2384
2385 if let Some(inst_id_code) = self.get_inst_id_code(&inst_id.symbol.inner()) {
2387 builder.inst_id_code(inst_id_code);
2388 }
2389
2390 if let Some(c) = cl_ord_id {
2391 builder.cl_ord_id(c.as_str());
2392 }
2393
2394 if let Some(o) = ord_id {
2395 builder.ord_id(o.as_str());
2396 }
2397
2398 let params = builder.build().map_err(|e| {
2399 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2400 })?;
2401 let val =
2402 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2403 args.push(val);
2404 }
2405
2406 self.ws_batch_cancel_orders(args).await
2407 }
2408
2409 #[allow(clippy::too_many_arguments)]
2420 pub async fn submit_algo_order(
2421 &self,
2422 trader_id: TraderId,
2423 strategy_id: StrategyId,
2424 instrument_id: InstrumentId,
2425 td_mode: OKXTradeMode,
2426 client_order_id: ClientOrderId,
2427 order_side: OrderSide,
2428 order_type: OrderType,
2429 quantity: Quantity,
2430 trigger_price: Price,
2431 trigger_type: Option<TriggerType>,
2432 limit_price: Option<Price>,
2433 reduce_only: Option<bool>,
2434 ) -> Result<(), OKXWsError> {
2435 if !is_conditional_order(order_type) {
2436 return Err(OKXWsError::ClientError(format!(
2437 "Order type {order_type:?} is not a conditional order"
2438 )));
2439 }
2440
2441 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2442 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2443 return Err(OKXWsError::ClientError(
2444 "Invalid order side for OKX".to_string(),
2445 ));
2446 }
2447
2448 builder.inst_id(instrument_id.symbol.inner());
2449
2450 if let Some(inst_id_code) = self.get_inst_id_code(&instrument_id.symbol.inner()) {
2452 builder.inst_id_code(inst_id_code);
2453 }
2454
2455 builder.td_mode(td_mode);
2456 builder.cl_ord_id(client_order_id.as_str());
2457 builder.side(order_side);
2458 builder.ord_type(
2459 conditional_order_to_algo_type(order_type)
2460 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2461 );
2462 builder.sz(quantity.to_string());
2463 builder.trigger_px(trigger_price.to_string());
2464
2465 let okx_trigger_type = trigger_type.map_or(OKXTriggerType::Last, Into::into);
2467 builder.trigger_px_type(okx_trigger_type);
2468
2469 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2471 && let Some(price) = limit_price
2472 {
2473 builder.order_px(price.to_string());
2474 }
2475
2476 if let Some(reduce) = reduce_only {
2477 builder.reduce_only(reduce);
2478 }
2479
2480 builder.tag(OKX_NAUTILUS_BROKER_ID);
2481
2482 let params = builder
2483 .build()
2484 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2485
2486 self.active_client_orders
2487 .insert(client_order_id, (trader_id, strategy_id, instrument_id));
2488
2489 let cmd = HandlerCommand::PlaceAlgoOrder {
2490 params,
2491 client_order_id,
2492 trader_id,
2493 strategy_id,
2494 instrument_id,
2495 };
2496
2497 self.send_cmd(cmd).await
2498 }
2499
2500 pub async fn cancel_algo_order(
2511 &self,
2512 trader_id: TraderId,
2513 strategy_id: StrategyId,
2514 instrument_id: InstrumentId,
2515 client_order_id: Option<ClientOrderId>,
2516 algo_order_id: Option<String>,
2517 ) -> Result<(), OKXWsError> {
2518 let cmd = HandlerCommand::CancelAlgoOrder {
2519 client_order_id,
2520 algo_order_id: algo_order_id.map(|id| VenueOrderId::from(id.as_str())),
2521 instrument_id,
2522 trader_id,
2523 strategy_id,
2524 };
2525
2526 self.send_cmd(cmd).await
2527 }
2528
2529 async fn send_cmd(&self, cmd: HandlerCommand) -> Result<(), OKXWsError> {
2531 self.cmd_tx
2532 .read()
2533 .await
2534 .send(cmd)
2535 .map_err(|e| OKXWsError::ClientError(format!("Handler not available: {e}")))
2536 }
2537}
2538
2539#[cfg(test)]
2540mod tests {
2541 use nautilus_core::time::get_atomic_clock_realtime;
2542 use nautilus_network::RECONNECTED;
2543 use rstest::rstest;
2544 use tokio_tungstenite::tungstenite::Message;
2545
2546 use super::*;
2547 use crate::{
2548 common::{
2549 consts::OKX_POST_ONLY_CANCEL_SOURCE,
2550 enums::{OKXExecType, OKXOrderCategory, OKXOrderStatus, OKXSide},
2551 },
2552 websocket::{
2553 handler::OKXWsFeedHandler,
2554 messages::{OKXOrderMsg, OKXWebSocketError, OKXWsMessage},
2555 },
2556 };
2557
2558 #[rstest]
2559 fn test_timestamp_format_for_websocket_auth() {
2560 let timestamp = SystemTime::now()
2561 .duration_since(SystemTime::UNIX_EPOCH)
2562 .expect("System time should be after UNIX epoch")
2563 .as_secs()
2564 .to_string();
2565
2566 assert!(timestamp.parse::<u64>().is_ok());
2567 assert_eq!(timestamp.len(), 10);
2568 assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
2569 }
2570
2571 #[rstest]
2572 fn test_new_without_credentials() {
2573 let client = OKXWebSocketClient::default();
2574 assert!(client.credential.is_none());
2575 assert_eq!(client.api_key(), None);
2576 }
2577
2578 #[rstest]
2579 fn test_new_with_credentials() {
2580 let client = OKXWebSocketClient::new(
2581 None,
2582 Some("test_key".to_string()),
2583 Some("test_secret".to_string()),
2584 Some("test_passphrase".to_string()),
2585 None,
2586 None,
2587 )
2588 .unwrap();
2589 assert!(client.credential.is_some());
2590 assert_eq!(client.api_key(), Some("test_key"));
2591 }
2592
2593 #[rstest]
2594 fn test_new_partial_credentials_fails() {
2595 let result = OKXWebSocketClient::new(
2596 None,
2597 Some("test_key".to_string()),
2598 None,
2599 Some("test_passphrase".to_string()),
2600 None,
2601 None,
2602 );
2603 assert!(result.is_err());
2604 }
2605
2606 #[rstest]
2607 fn test_request_id_generation() {
2608 let client = OKXWebSocketClient::default();
2609
2610 let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
2611
2612 let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2613 let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2614
2615 assert_eq!(id1, initial_counter);
2616 assert_eq!(id2, initial_counter + 1);
2617 assert_eq!(
2618 client.request_id_counter.load(Ordering::SeqCst),
2619 initial_counter + 2
2620 );
2621 }
2622
2623 #[rstest]
2624 fn test_client_state_management() {
2625 let client = OKXWebSocketClient::default();
2626
2627 assert!(client.is_closed());
2628 assert!(!client.is_active());
2629
2630 let client_with_heartbeat =
2631 OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();
2632
2633 assert!(client_with_heartbeat.heartbeat.is_some());
2634 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2635 }
2636
2637 #[rstest]
2642 fn test_websocket_error_handling() {
2643 let clock = get_atomic_clock_realtime();
2644 let ts = clock.get_time_ns().as_u64();
2645
2646 let error = OKXWebSocketError {
2647 code: "60012".to_string(),
2648 message: "Invalid request".to_string(),
2649 conn_id: None,
2650 timestamp: ts,
2651 };
2652
2653 assert_eq!(error.code, "60012");
2654 assert_eq!(error.message, "Invalid request");
2655 assert_eq!(error.timestamp, ts);
2656
2657 let nautilus_msg = NautilusWsMessage::Error(error);
2658 match nautilus_msg {
2659 NautilusWsMessage::Error(e) => {
2660 assert_eq!(e.code, "60012");
2661 assert_eq!(e.message, "Invalid request");
2662 }
2663 _ => panic!("Expected Error variant"),
2664 }
2665 }
2666
2667 #[rstest]
2668 fn test_request_id_generation_sequence() {
2669 let client = OKXWebSocketClient::default();
2670
2671 let initial_counter = client
2672 .request_id_counter
2673 .load(std::sync::atomic::Ordering::SeqCst);
2674 let mut ids = Vec::new();
2675 for _ in 0..10 {
2676 let id = client
2677 .request_id_counter
2678 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2679 ids.push(id);
2680 }
2681
2682 for (i, &id) in ids.iter().enumerate() {
2683 assert_eq!(id, initial_counter + i as u64);
2684 }
2685
2686 assert_eq!(
2687 client
2688 .request_id_counter
2689 .load(std::sync::atomic::Ordering::SeqCst),
2690 initial_counter + 10
2691 );
2692 }
2693
2694 #[rstest]
2695 fn test_client_state_transitions() {
2696 let client = OKXWebSocketClient::default();
2697
2698 assert!(client.is_closed());
2699 assert!(!client.is_active());
2700
2701 let client_with_heartbeat = OKXWebSocketClient::new(
2702 None,
2703 None,
2704 None,
2705 None,
2706 None,
2707 Some(30), )
2709 .unwrap();
2710
2711 assert!(client_with_heartbeat.heartbeat.is_some());
2712 assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2713
2714 let account_id = AccountId::from("test-account-123");
2715 let client_with_account =
2716 OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();
2717
2718 assert_eq!(client_with_account.account_id, account_id);
2719 }
2720
2721 #[rstest]
2722 fn test_websocket_error_scenarios() {
2723 let clock = get_atomic_clock_realtime();
2724 let ts = clock.get_time_ns().as_u64();
2725
2726 let error_scenarios = vec![
2727 ("60012", "Invalid request", None),
2728 ("60009", "Invalid API key", Some("conn-123".to_string())),
2729 ("60014", "Too many requests", None),
2730 ("50001", "Order not found", None),
2731 ];
2732
2733 for (code, message, conn_id) in error_scenarios {
2734 let error = OKXWebSocketError {
2735 code: code.to_string(),
2736 message: message.to_string(),
2737 conn_id: conn_id.clone(),
2738 timestamp: ts,
2739 };
2740
2741 assert_eq!(error.code, code);
2742 assert_eq!(error.message, message);
2743 assert_eq!(error.conn_id, conn_id);
2744 assert_eq!(error.timestamp, ts);
2745
2746 let nautilus_msg = NautilusWsMessage::Error(error);
2747 match nautilus_msg {
2748 NautilusWsMessage::Error(e) => {
2749 assert_eq!(e.code, code);
2750 assert_eq!(e.message, message);
2751 assert_eq!(e.conn_id, conn_id);
2752 }
2753 _ => panic!("Expected Error variant"),
2754 }
2755 }
2756 }
2757
2758 #[rstest]
2759 fn test_feed_handler_reconnection_detection() {
2760 let msg = Message::Text(RECONNECTED.to_string().into());
2761 let result = OKXWsFeedHandler::parse_raw_message(msg);
2762 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2763 }
2764
2765 #[rstest]
2766 fn test_feed_handler_normal_message_processing() {
2767 let ping_msg = Message::Text(TEXT_PING.to_string().into());
2769 let result = OKXWsFeedHandler::parse_raw_message(ping_msg);
2770 assert!(matches!(result, Some(OKXWsMessage::Ping)));
2771
2772 let sub_msg = r#"{
2774 "event": "subscribe",
2775 "arg": {
2776 "channel": "tickers",
2777 "instType": "SPOT"
2778 },
2779 "connId": "a4d3ae55"
2780 }"#;
2781
2782 let sub_result =
2783 OKXWsFeedHandler::parse_raw_message(Message::Text(sub_msg.to_string().into()));
2784 assert!(matches!(
2785 sub_result,
2786 Some(OKXWsMessage::Subscription { .. })
2787 ));
2788 }
2789
2790 #[rstest]
2791 fn test_feed_handler_close_message() {
2792 let result = OKXWsFeedHandler::parse_raw_message(Message::Close(None));
2794 assert!(result.is_none());
2795 }
2796
2797 #[rstest]
2798 fn test_reconnection_message_constant() {
2799 assert_eq!(RECONNECTED, "__RECONNECTED__");
2800 }
2801
2802 #[rstest]
2803 fn test_multiple_reconnection_signals() {
2804 for _ in 0..3 {
2806 let msg = Message::Text(RECONNECTED.to_string().into());
2807 let result = OKXWsFeedHandler::parse_raw_message(msg);
2808 assert!(matches!(result, Some(OKXWsMessage::Reconnected)));
2809 }
2810 }
2811
2812 #[tokio::test]
2813 async fn test_wait_until_active_timeout() {
2814 let client = OKXWebSocketClient::new(
2815 None,
2816 Some("test_key".to_string()),
2817 Some("test_secret".to_string()),
2818 Some("test_passphrase".to_string()),
2819 Some(AccountId::from("test-account")),
2820 None,
2821 )
2822 .unwrap();
2823
2824 let result = client.wait_until_active(0.1).await;
2826
2827 assert!(result.is_err());
2828 assert!(!client.is_active());
2829 }
2830
2831 fn sample_canceled_order_msg() -> OKXOrderMsg {
2832 OKXOrderMsg {
2833 acc_fill_sz: Some("0".to_string()),
2834 avg_px: "0".to_string(),
2835 c_time: 0,
2836 cancel_source: None,
2837 cancel_source_reason: None,
2838 category: OKXOrderCategory::Normal,
2839 ccy: Ustr::from("USDT"),
2840 cl_ord_id: "order-1".to_string(),
2841 algo_cl_ord_id: None,
2842 fee: None,
2843 fee_ccy: Ustr::from("USDT"),
2844 fill_px: "0".to_string(),
2845 fill_sz: "0".to_string(),
2846 fill_time: 0,
2847 inst_id: Ustr::from("ETH-USDT-SWAP"),
2848 inst_type: OKXInstrumentType::Swap,
2849 lever: "1".to_string(),
2850 ord_id: Ustr::from("123456"),
2851 ord_type: OKXOrderType::Limit,
2852 pnl: "0".to_string(),
2853 pos_side: OKXPositionSide::Net,
2854 px: "0".to_string(),
2855 reduce_only: "false".to_string(),
2856 side: OKXSide::Buy,
2857 state: OKXOrderStatus::Canceled,
2858 exec_type: OKXExecType::None,
2859 sz: "1".to_string(),
2860 td_mode: OKXTradeMode::Cross,
2861 tgt_ccy: None,
2862 trade_id: String::new(),
2863 u_time: 0,
2864 }
2865 }
2866
2867 #[rstest]
2868 fn test_is_post_only_auto_cancel_detects_cancel_source() {
2869 let mut msg = sample_canceled_order_msg();
2870 msg.cancel_source = Some(OKX_POST_ONLY_CANCEL_SOURCE.to_string());
2871
2872 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2873 }
2874
2875 #[rstest]
2876 fn test_is_post_only_auto_cancel_detects_reason() {
2877 let mut msg = sample_canceled_order_msg();
2878 msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());
2879
2880 assert!(OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2881 }
2882
2883 #[rstest]
2884 fn test_is_post_only_auto_cancel_false_without_markers() {
2885 let msg = sample_canceled_order_msg();
2886
2887 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2888 }
2889
2890 #[rstest]
2891 fn test_is_post_only_auto_cancel_false_for_order_type_only() {
2892 let mut msg = sample_canceled_order_msg();
2893 msg.ord_type = OKXOrderType::PostOnly;
2894
2895 assert!(!OKXWsFeedHandler::is_post_only_auto_cancel(&msg));
2896 }
2897
2898 #[tokio::test]
2899 async fn test_batch_cancel_orders_with_multiple_orders() {
2900 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2901
2902 let client = OKXWebSocketClient::new(
2903 Some("wss://test.okx.com".to_string()),
2904 None,
2905 None,
2906 None,
2907 None,
2908 None,
2909 )
2910 .expect("Failed to create client");
2911
2912 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2913 let client_order_id1 = ClientOrderId::new("order1");
2914 let client_order_id2 = ClientOrderId::new("order2");
2915 let venue_order_id1 = VenueOrderId::new("venue1");
2916 let venue_order_id2 = VenueOrderId::new("venue2");
2917
2918 let orders = vec![
2919 (instrument_id, Some(client_order_id1), Some(venue_order_id1)),
2920 (instrument_id, Some(client_order_id2), Some(venue_order_id2)),
2921 ];
2922
2923 let result = client.batch_cancel_orders(orders).await;
2925
2926 assert!(result.is_err());
2928 }
2929
2930 #[tokio::test]
2931 async fn test_batch_cancel_orders_with_only_client_order_id() {
2932 use nautilus_model::identifiers::{ClientOrderId, InstrumentId};
2933
2934 let client = OKXWebSocketClient::new(
2935 Some("wss://test.okx.com".to_string()),
2936 None,
2937 None,
2938 None,
2939 None,
2940 None,
2941 )
2942 .expect("Failed to create client");
2943
2944 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2945 let client_order_id = ClientOrderId::new("order1");
2946
2947 let orders = vec![(instrument_id, Some(client_order_id), None)];
2948
2949 let result = client.batch_cancel_orders(orders).await;
2950
2951 assert!(result.is_err());
2953 }
2954
2955 #[tokio::test]
2956 async fn test_batch_cancel_orders_with_only_venue_order_id() {
2957 use nautilus_model::identifiers::{InstrumentId, VenueOrderId};
2958
2959 let client = OKXWebSocketClient::new(
2960 Some("wss://test.okx.com".to_string()),
2961 None,
2962 None,
2963 None,
2964 None,
2965 None,
2966 )
2967 .expect("Failed to create client");
2968
2969 let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2970 let venue_order_id = VenueOrderId::new("venue1");
2971
2972 let orders = vec![(instrument_id, None, Some(venue_order_id))];
2973
2974 let result = client.batch_cancel_orders(orders).await;
2975
2976 assert!(result.is_err());
2978 }
2979
2980 #[tokio::test]
2981 async fn test_batch_cancel_orders_with_both_ids() {
2982 use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};
2983
2984 let client = OKXWebSocketClient::new(
2985 Some("wss://test.okx.com".to_string()),
2986 None,
2987 None,
2988 None,
2989 None,
2990 None,
2991 )
2992 .expect("Failed to create client");
2993
2994 let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");
2995 let client_order_id = ClientOrderId::new("order1");
2996 let venue_order_id = VenueOrderId::new("venue1");
2997
2998 let orders = vec![(instrument_id, Some(client_order_id), Some(venue_order_id))];
2999
3000 let result = client.batch_cancel_orders(orders).await;
3001
3002 assert!(result.is_err());
3004 }
3005
3006 #[rstest]
3007 fn test_race_unsubscribe_failure_recovery() {
3008 let client = OKXWebSocketClient::new(
3014 Some("wss://test.okx.com".to_string()),
3015 None,
3016 None,
3017 None,
3018 None,
3019 None,
3020 )
3021 .expect("Failed to create client");
3022
3023 let topic = "trades:BTC-USDT-SWAP";
3024
3025 client.subscriptions_state.mark_subscribe(topic);
3027 client.subscriptions_state.confirm_subscribe(topic);
3028 assert_eq!(client.subscriptions_state.len(), 1);
3029
3030 client.subscriptions_state.mark_unsubscribe(topic);
3032 assert_eq!(client.subscriptions_state.len(), 0);
3033 assert_eq!(
3034 client.subscriptions_state.pending_unsubscribe_topics(),
3035 vec![topic]
3036 );
3037
3038 client.subscriptions_state.confirm_unsubscribe(topic); client.subscriptions_state.mark_subscribe(topic); client.subscriptions_state.confirm_subscribe(topic); assert_eq!(client.subscriptions_state.len(), 1);
3046 assert!(
3047 client
3048 .subscriptions_state
3049 .pending_unsubscribe_topics()
3050 .is_empty()
3051 );
3052 assert!(
3053 client
3054 .subscriptions_state
3055 .pending_subscribe_topics()
3056 .is_empty()
3057 );
3058
3059 let all = client.subscriptions_state.all_topics();
3061 assert_eq!(all.len(), 1);
3062 assert!(all.contains(&topic.to_string()));
3063 }
3064
3065 #[rstest]
3066 fn test_race_resubscribe_before_unsubscribe_ack() {
3067 let client = OKXWebSocketClient::new(
3071 Some("wss://test.okx.com".to_string()),
3072 None,
3073 None,
3074 None,
3075 None,
3076 None,
3077 )
3078 .expect("Failed to create client");
3079
3080 let topic = "books:BTC-USDT";
3081
3082 client.subscriptions_state.mark_subscribe(topic);
3084 client.subscriptions_state.confirm_subscribe(topic);
3085 assert_eq!(client.subscriptions_state.len(), 1);
3086
3087 client.subscriptions_state.mark_unsubscribe(topic);
3089 assert_eq!(client.subscriptions_state.len(), 0);
3090 assert_eq!(
3091 client.subscriptions_state.pending_unsubscribe_topics(),
3092 vec![topic]
3093 );
3094
3095 client.subscriptions_state.mark_subscribe(topic);
3097 assert_eq!(
3098 client.subscriptions_state.pending_subscribe_topics(),
3099 vec![topic]
3100 );
3101
3102 client.subscriptions_state.confirm_unsubscribe(topic);
3104 assert!(
3105 client
3106 .subscriptions_state
3107 .pending_unsubscribe_topics()
3108 .is_empty()
3109 );
3110 assert_eq!(
3111 client.subscriptions_state.pending_subscribe_topics(),
3112 vec![topic]
3113 );
3114
3115 client.subscriptions_state.confirm_subscribe(topic);
3117 assert_eq!(client.subscriptions_state.len(), 1);
3118 assert!(
3119 client
3120 .subscriptions_state
3121 .pending_subscribe_topics()
3122 .is_empty()
3123 );
3124
3125 let all = client.subscriptions_state.all_topics();
3127 assert_eq!(all.len(), 1);
3128 assert!(all.contains(&topic.to_string()));
3129 }
3130
3131 #[rstest]
3132 fn test_race_late_subscribe_confirmation_after_unsubscribe() {
3133 let client = OKXWebSocketClient::new(
3136 Some("wss://test.okx.com".to_string()),
3137 None,
3138 None,
3139 None,
3140 None,
3141 None,
3142 )
3143 .expect("Failed to create client");
3144
3145 let topic = "tickers:ETH-USDT";
3146
3147 client.subscriptions_state.mark_subscribe(topic);
3149 assert_eq!(
3150 client.subscriptions_state.pending_subscribe_topics(),
3151 vec![topic]
3152 );
3153
3154 client.subscriptions_state.mark_unsubscribe(topic);
3156 assert!(
3157 client
3158 .subscriptions_state
3159 .pending_subscribe_topics()
3160 .is_empty()
3161 ); assert_eq!(
3163 client.subscriptions_state.pending_unsubscribe_topics(),
3164 vec![topic]
3165 );
3166
3167 client.subscriptions_state.confirm_subscribe(topic);
3169 assert_eq!(client.subscriptions_state.len(), 0); assert_eq!(
3171 client.subscriptions_state.pending_unsubscribe_topics(),
3172 vec![topic]
3173 );
3174
3175 client.subscriptions_state.confirm_unsubscribe(topic);
3177
3178 assert!(client.subscriptions_state.is_empty());
3180 assert!(client.subscriptions_state.all_topics().is_empty());
3181 }
3182
3183 #[rstest]
3184 fn test_race_reconnection_with_pending_states() {
3185 let client = OKXWebSocketClient::new(
3187 Some("wss://test.okx.com".to_string()),
3188 Some("test_key".to_string()),
3189 Some("test_secret".to_string()),
3190 Some("test_passphrase".to_string()),
3191 Some(AccountId::new("OKX-TEST")),
3192 None,
3193 )
3194 .expect("Failed to create client");
3195
3196 let trade_btc = "trades:BTC-USDT-SWAP";
3199 client.subscriptions_state.mark_subscribe(trade_btc);
3200 client.subscriptions_state.confirm_subscribe(trade_btc);
3201
3202 let trade_eth = "trades:ETH-USDT-SWAP";
3204 client.subscriptions_state.mark_subscribe(trade_eth);
3205
3206 let book_btc = "books:BTC-USDT";
3208 client.subscriptions_state.mark_subscribe(book_btc);
3209 client.subscriptions_state.confirm_subscribe(book_btc);
3210 client.subscriptions_state.mark_unsubscribe(book_btc);
3211
3212 let topics_to_restore = client.subscriptions_state.all_topics();
3214
3215 assert_eq!(topics_to_restore.len(), 2);
3217 assert!(topics_to_restore.contains(&trade_btc.to_string()));
3218 assert!(topics_to_restore.contains(&trade_eth.to_string()));
3219 assert!(!topics_to_restore.contains(&book_btc.to_string())); }
3221
3222 #[rstest]
3223 fn test_race_duplicate_subscribe_messages_idempotent() {
3224 let client = OKXWebSocketClient::new(
3227 Some("wss://test.okx.com".to_string()),
3228 None,
3229 None,
3230 None,
3231 None,
3232 None,
3233 )
3234 .expect("Failed to create client");
3235
3236 let topic = "trades:BTC-USDT-SWAP";
3237
3238 client.subscriptions_state.mark_subscribe(topic);
3240 client.subscriptions_state.confirm_subscribe(topic);
3241 assert_eq!(client.subscriptions_state.len(), 1);
3242
3243 client.subscriptions_state.mark_subscribe(topic);
3245 assert!(
3246 client
3247 .subscriptions_state
3248 .pending_subscribe_topics()
3249 .is_empty()
3250 ); assert_eq!(client.subscriptions_state.len(), 1); client.subscriptions_state.confirm_subscribe(topic);
3255 assert_eq!(client.subscriptions_state.len(), 1);
3256
3257 let all = client.subscriptions_state.all_topics();
3259 assert_eq!(all.len(), 1);
3260 assert_eq!(all[0], topic);
3261 }
3262}