1use std::{
24 fmt::Debug,
25 num::NonZeroU32,
26 sync::{
27 Arc, LazyLock,
28 atomic::{AtomicBool, AtomicU64, Ordering},
29 },
30 time::{Duration, SystemTime},
31};
32
33use ahash::{AHashMap, AHashSet};
34use dashmap::DashMap;
35use futures_util::Stream;
36use nautilus_common::runtime::get_runtime;
37use nautilus_core::{
38 UUID4, consts::NAUTILUS_USER_AGENT, env::get_env_var, time::get_atomic_clock_realtime,
39};
40use nautilus_model::{
41 data::BarType,
42 enums::{OrderSide, OrderStatus, OrderType, PositionSide, TimeInForce},
43 events::{AccountState, OrderCancelRejected, OrderModifyRejected, OrderRejected},
44 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
45 instruments::{Instrument, InstrumentAny},
46 types::{Money, Price, Quantity},
47};
48use nautilus_network::{
49 RECONNECTED,
50 ratelimiter::quota::Quota,
51 retry::{RetryManager, create_websocket_retry_manager},
52 websocket::{WebSocketClient, WebSocketConfig, channel_message_handler},
53};
54use reqwest::header::USER_AGENT;
55use serde_json::Value;
56use tokio::sync::mpsc::UnboundedReceiver;
57use tokio_tungstenite::tungstenite::{Error, Message};
58use tokio_util::sync::CancellationToken;
59use ustr::Ustr;
60
61use super::{
62 enums::{OKXWsChannel, OKXWsOperation},
63 error::OKXWsError,
64 messages::{
65 ExecutionReport, NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg,
66 OKXSubscription, OKXSubscriptionArg, OKXWebSocketError, OKXWebSocketEvent, OKXWsRequest,
67 WsAmendOrderParams, WsAmendOrderParamsBuilder, WsCancelOrderParams,
68 WsCancelOrderParamsBuilder, WsMassCancelParams, WsPostOrderParams,
69 WsPostOrderParamsBuilder,
70 },
71 parse::{parse_book_msg_vec, parse_ws_message_data},
72};
73use crate::common::consts::should_retry_error_code;
74
75fn should_retry_okx_error(error: &OKXWsError) -> bool {
77 match error {
78 OKXWsError::OkxError { error_code, .. } => should_retry_error_code(error_code),
79 OKXWsError::TungsteniteError(_) => true, OKXWsError::ClientError(msg) => {
81 let msg_lower = msg.to_lowercase();
83 msg_lower.contains("timeout")
84 || msg_lower.contains("timed out")
85 || msg_lower.contains("connection")
86 || msg_lower.contains("network")
87 }
88 OKXWsError::JsonError(_) | OKXWsError::ParsingError(_) => false, }
90}
91
/// Builds the error value used to report a timeout.
///
/// The message is wrapped unchanged in [`OKXWsError::ClientError`].
fn create_okx_timeout_error(msg: String) -> OKXWsError {
    OKXWsError::ClientError(msg)
}
96use crate::{
97 common::{
98 consts::{
99 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
100 OKX_WS_PUBLIC_URL,
101 },
102 credential::Credential,
103 enums::{OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXSide, OKXTradeMode},
104 parse::{bar_spec_as_okx_channel, okx_instrument_type, parse_account_state},
105 },
106 http::models::OKXAccount,
107 websocket::{messages::OKXOrderMsg, parse::parse_order_msg_vec},
108};
109
/// Value stored in `pending_place_requests` (keyed by a `String`, presumably the request ID).
type PlaceRequestData = (ClientOrderId, TraderId, StrategyId, InstrumentId);
/// Value stored in `pending_cancel_requests`; the venue order ID is optional.
type CancelRequestData = (
    ClientOrderId,
    TraderId,
    StrategyId,
    InstrumentId,
    Option<VenueOrderId>,
);
/// Value stored in `pending_amend_requests`; the venue order ID is optional.
type AmendRequestData = (
    ClientOrderId,
    TraderId,
    StrategyId,
    InstrumentId,
    Option<VenueOrderId>,
);
/// Value stored in `pending_mass_cancel_requests`.
type MassCancelRequestData = InstrumentId;
126
/// Rate-limit quota of 3 per second.
///
/// `connect` registers it under the `"subscription"` key and also passes it as the
/// final (presumably default) quota argument to `WebSocketClient::connect`.
pub static OKX_WS_QUOTA: LazyLock<Quota> =
    LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
136
/// Rate-limit quota of 250 per second.
///
/// `connect` registers it under the `"order"`, `"cancel"` and `"amend"` keys.
pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
    LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
143
/// WebSocket client for the OKX venue.
///
/// Holds the connection, authentication state, tracked subscriptions and pending
/// request bookkeeping. The type is `Clone`; most state sits behind `Arc`, so
/// clones share it.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
pub struct OKXWebSocketClient {
    /// WebSocket endpoint URL used by `connect`.
    url: String,
    /// Account identifier handed to the message handler.
    account_id: AccountId,
    /// API credential used to sign login requests; `None` for unauthenticated use.
    credential: Option<Credential>,
    /// Heartbeat setting forwarded to the WebSocket config (units not visible here — TODO confirm).
    heartbeat: Option<u64>,
    /// Underlying WebSocket client; `None` until `connect` has succeeded.
    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
    /// Sender side of the authentication flag (reset to `false` before login).
    auth_state: Arc<tokio::sync::watch::Sender<bool>>,
    /// Receiver side of the authentication flag, awaited by `authenticate`.
    auth_state_rx: tokio::sync::watch::Receiver<bool>,
    /// Receiver of parsed messages; taken by `stream`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Stop flag, set to `true` by `close`.
    signal: Arc<AtomicBool>,
    /// Handle of the message-processing task spawned by `connect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Tracked subscriptions keyed by channel, by instrument type.
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    /// Tracked subscriptions keyed by channel, by instrument family.
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Tracked subscriptions keyed by channel, by instrument ID (symbol).
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Tracked subscriptions that carry no instrument qualifier at all.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    /// Monotonic counter backing `generate_unique_request_id` (starts at 1).
    request_id_counter: Arc<AtomicU64>,
    /// In-flight place-order requests (shared with the message handler).
    pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
    /// In-flight cancel-order requests (shared with the message handler).
    pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
    /// In-flight amend-order requests (shared with the message handler).
    pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
    /// In-flight mass-cancel requests (shared with the message handler).
    pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
    /// Instruments keyed by symbol; replaced wholesale by `initialize_instruments_cache`.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
    /// Retry manager built by `create_websocket_retry_manager`.
    retry_manager: Arc<RetryManager<OKXWsError>>,
    /// Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
174
175impl Default for OKXWebSocketClient {
176 fn default() -> Self {
177 Self::new(None, None, None, None, None, None).unwrap()
178 }
179}
180
181impl Debug for OKXWebSocketClient {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct(stringify!(OKXWebSocketClient))
184 .field("url", &self.url)
185 .field(
186 "credential",
187 &self.credential.as_ref().map(|_| "<redacted>"),
188 )
189 .field("heartbeat", &self.heartbeat)
190 .finish_non_exhaustive()
191 }
192}
193
194impl OKXWebSocketClient {
195 pub fn new(
197 url: Option<String>,
198 api_key: Option<String>,
199 api_secret: Option<String>,
200 api_passphrase: Option<String>,
201 account_id: Option<AccountId>,
202 heartbeat: Option<u64>,
203 ) -> anyhow::Result<Self> {
204 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
205 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
206
207 let credential = match (api_key, api_secret, api_passphrase) {
208 (Some(key), Some(secret), Some(passphrase)) => {
209 Some(Credential::new(key, secret, passphrase))
210 }
211 (None, None, None) => None,
212 _ => anyhow::bail!(
213 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
214 ),
215 };
216
217 let signal = Arc::new(AtomicBool::new(false));
218 let subscriptions_inst_type = Arc::new(DashMap::new());
219 let subscriptions_inst_family = Arc::new(DashMap::new());
220 let subscriptions_inst_id = Arc::new(DashMap::new());
221 let subscriptions_bare = Arc::new(DashMap::new());
222 let (auth_tx, auth_rx) = tokio::sync::watch::channel(false);
223
224 Ok(Self {
225 url,
226 account_id,
227 credential,
228 heartbeat,
229 inner: Arc::new(tokio::sync::RwLock::new(None)),
230 auth_state: Arc::new(auth_tx),
231 auth_state_rx: auth_rx,
232 rx: None,
233 signal,
234 task_handle: None,
235 subscriptions_inst_type,
236 subscriptions_inst_family,
237 subscriptions_inst_id,
238 subscriptions_bare,
239 request_id_counter: Arc::new(AtomicU64::new(1)),
240 pending_place_requests: Arc::new(DashMap::new()),
241 pending_cancel_requests: Arc::new(DashMap::new()),
242 pending_amend_requests: Arc::new(DashMap::new()),
243 pending_mass_cancel_requests: Arc::new(DashMap::new()),
244 instruments_cache: Arc::new(AHashMap::new()),
245 retry_manager: Arc::new(create_websocket_retry_manager()?),
246 cancellation_token: CancellationToken::new(),
247 })
248 }
249
250 pub fn with_credentials(
252 url: Option<String>,
253 api_key: Option<String>,
254 api_secret: Option<String>,
255 api_passphrase: Option<String>,
256 account_id: Option<AccountId>,
257 heartbeat: Option<u64>,
258 ) -> anyhow::Result<Self> {
259 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
260 let api_key = api_key.unwrap_or(get_env_var("OKX_API_KEY")?);
261 let api_secret = api_secret.unwrap_or(get_env_var("OKX_API_SECRET")?);
262 let api_passphrase = api_passphrase.unwrap_or(get_env_var("OKX_API_PASSPHRASE")?);
263
264 Self::new(
265 Some(url),
266 Some(api_key),
267 Some(api_secret),
268 Some(api_passphrase),
269 account_id,
270 heartbeat,
271 )
272 }
273
274 pub fn from_env() -> anyhow::Result<Self> {
276 let url = get_env_var("OKX_WS_URL")?;
277 let api_key = get_env_var("OKX_API_KEY")?;
278 let api_secret = get_env_var("OKX_API_SECRET")?;
279 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
280
281 Self::new(
282 Some(url),
283 Some(api_key),
284 Some(api_secret),
285 Some(api_passphrase),
286 None,
287 None,
288 )
289 }
290
/// Cancels the client's cancellation token.
///
/// Anything observing [`Self::cancellation_token`] sees the cancellation.
pub fn cancel_all_requests(&self) {
    self.cancellation_token.cancel();
}
295
/// Returns a reference to the client's cancellation token.
pub fn cancellation_token(&self) -> &CancellationToken {
    &self.cancellation_token
}
300
301 pub fn url(&self) -> &str {
303 self.url.as_str()
304 }
305
306 pub fn api_key(&self) -> Option<&str> {
308 self.credential.clone().map(|c| c.api_key.as_str())
309 }
310
311 pub fn is_active(&self) -> bool {
314 match self.inner.try_read() {
316 Ok(guard) => match &*guard {
317 Some(inner) => inner.is_active(),
318 None => false,
319 },
320 Err(_) => false, }
322 }
323
324 pub fn is_closed(&self) -> bool {
326 match self.inner.try_read() {
328 Ok(guard) => match &*guard {
329 Some(inner) => inner.is_closed(),
330 None => true,
331 },
332 Err(_) => true, }
334 }
335
336 pub fn initialize_instruments_cache(&mut self, instruments: Vec<InstrumentAny>) {
338 let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
339 for inst in instruments {
340 instruments_cache.insert(inst.symbol().inner(), inst.clone());
341 }
342
343 self.instruments_cache = Arc::new(instruments_cache)
344 }
345
346 pub async fn connect(&mut self) -> anyhow::Result<()> {
352 let (message_handler, reader) = channel_message_handler();
353
354 let config = WebSocketConfig {
355 url: self.url.clone(),
356 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
357 heartbeat: self.heartbeat,
358 heartbeat_msg: None,
359 message_handler: Some(message_handler),
360 ping_handler: None,
361 reconnect_timeout_ms: Some(5_000),
362 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, };
367 let keyed_quotas = vec![
369 ("subscription".to_string(), *OKX_WS_QUOTA),
370 ("order".to_string(), *OKX_WS_ORDER_QUOTA),
371 ("cancel".to_string(), *OKX_WS_ORDER_QUOTA),
372 ("amend".to_string(), *OKX_WS_ORDER_QUOTA),
373 ];
374
375 let client = WebSocketClient::connect(
376 config,
377 None, keyed_quotas,
379 Some(*OKX_WS_QUOTA), )
381 .await?;
382
383 {
385 let mut inner_guard = self.inner.write().await;
386 *inner_guard = Some(client);
387 }
388
389 let account_id = self.account_id;
390 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
391
392 self.rx = Some(Arc::new(rx));
393 let signal = self.signal.clone();
394 let pending_place_requests = self.pending_place_requests.clone();
395 let pending_cancel_requests = self.pending_cancel_requests.clone();
396 let pending_amend_requests = self.pending_amend_requests.clone();
397 let pending_mass_cancel_requests = self.pending_mass_cancel_requests.clone();
398 let auth_state = self.auth_state.clone();
399
400 let instruments_cache = self.instruments_cache.clone();
401 let inner_client = self.inner.clone();
402 let credential_clone = self.credential.clone();
403 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
404 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
405 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
406 let subscriptions_bare = self.subscriptions_bare.clone();
407 let auth_state_clone = auth_state.clone();
408 let stream_handle = get_runtime().spawn(async move {
409 let mut handler = OKXWsMessageHandler::new(
410 account_id,
411 instruments_cache,
412 reader,
413 signal,
414 tx,
415 pending_place_requests,
416 pending_cancel_requests,
417 pending_amend_requests,
418 pending_mass_cancel_requests,
419 auth_state,
420 );
421
422 loop {
424 match handler.next().await {
425 Some(NautilusWsMessage::Reconnected) => {
426 tracing::info!("Handling WebSocket reconnection");
427
428 let inner_guard = inner_client.read().await;
430 if let Some(cred) = &credential_clone
431 && let Some(client) = &*inner_guard {
432 let timestamp = SystemTime::now()
433 .duration_since(SystemTime::UNIX_EPOCH)
434 .expect("System time should be after UNIX epoch")
435 .as_secs()
436 .to_string();
437 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
438
439 let auth_message = OKXAuthentication {
440 op: "login",
441 args: vec![OKXAuthenticationArg {
442 api_key: cred.api_key.to_string(),
443 passphrase: cred.api_passphrase.clone(),
444 timestamp,
445 sign: signature,
446 }],
447 };
448
449 if let Err(e) = client.send_text(serde_json::to_string(&auth_message).unwrap(), None).await {
450 tracing::error!("Failed to send re-authentication request: {e}");
451 } else {
453 tracing::info!("Sent re-authentication request, waiting for response before resubscribing");
454
455 let mut auth_rx = auth_state_clone.subscribe();
457 match tokio::time::timeout(Duration::from_secs(5), auth_rx.wait_for(|&auth| auth)).await {
458 Ok(Ok(_)) => {
459 tracing::info!("Authentication successful after reconnect, proceeding with resubscription");
460 }
463 Ok(Err(e)) => {
464 tracing::error!("Auth watch channel error after reconnect: {e}");
465 }
467 Err(_) => {
468 tracing::error!("Timeout waiting for authentication after reconnect");
469 }
471 }
472 }
473 }
474
475 let inner_guard = inner_client.read().await;
478 if let Some(client) = &*inner_guard {
479 let mut inst_type_args = Vec::new();
481 for entry in subscriptions_inst_type.iter() {
482 let (channel, inst_types) = entry.pair();
483 for inst_type in inst_types.iter() {
484 inst_type_args.push(OKXSubscriptionArg {
485 channel: channel.clone(),
486 inst_type: Some(*inst_type),
487 inst_family: None,
488 inst_id: None,
489 });
490 }
491 }
492 if !inst_type_args.is_empty() {
493 let sub_request = OKXSubscription {
494 op: OKXWsOperation::Subscribe,
495 args: inst_type_args,
496 };
497 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
498 tracing::error!("Failed to re-subscribe inst_type channels: {e}");
499 }
500 }
501
502 let mut inst_family_args = Vec::new();
504 for entry in subscriptions_inst_family.iter() {
505 let (channel, inst_families) = entry.pair();
506 for inst_family in inst_families.iter() {
507 inst_family_args.push(OKXSubscriptionArg {
508 channel: channel.clone(),
509 inst_type: None,
510 inst_family: Some(*inst_family),
511 inst_id: None,
512 });
513 }
514 }
515 if !inst_family_args.is_empty() {
516 let sub_request = OKXSubscription {
517 op: OKXWsOperation::Subscribe,
518 args: inst_family_args,
519 };
520 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
521 tracing::error!("Failed to re-subscribe inst_family channels: {e}");
522 }
523 }
524
525 let mut inst_id_args = Vec::new();
527 for entry in subscriptions_inst_id.iter() {
528 let (channel, inst_ids) = entry.pair();
529 for inst_id in inst_ids.iter() {
530 inst_id_args.push(OKXSubscriptionArg {
531 channel: channel.clone(),
532 inst_type: None,
533 inst_family: None,
534 inst_id: Some(*inst_id),
535 });
536 }
537 }
538 if !inst_id_args.is_empty() {
539 let sub_request = OKXSubscription {
540 op: OKXWsOperation::Subscribe,
541 args: inst_id_args,
542 };
543 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
544 tracing::error!("Failed to re-subscribe inst_id channels: {e}");
545 }
546 }
547
548 let mut bare_args = Vec::new();
550 for entry in subscriptions_bare.iter() {
551 let channel = entry.key();
552 bare_args.push(OKXSubscriptionArg {
553 channel: channel.clone(),
554 inst_type: None,
555 inst_family: None,
556 inst_id: None,
557 });
558 }
559 if !bare_args.is_empty() {
560 let sub_request = OKXSubscription {
561 op: OKXWsOperation::Subscribe,
562 args: bare_args,
563 };
564 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
565 tracing::error!("Failed to re-subscribe bare channels: {e}");
566 }
567 }
568
569 tracing::info!("Completed re-subscription after reconnect");
570 }
571 }
572 Some(msg) => {
573 if handler.tx.send(msg).is_err() {
575 tracing::error!("Failed to send message through channel: receiver dropped");
576 break;
577 }
578
579 }
580 None => {
581 if handler.is_stopped() {
583 tracing::debug!("Stop signal received, ending message processing");
584 break;
585 }
586 tracing::warn!("WebSocket stream ended unexpectedly");
588 break;
589 }
590 }
591 }
592 });
593
594 self.task_handle = Some(Arc::new(stream_handle));
595
596 if self.credential.is_some() {
597 if self.auth_state.send(false).is_err() {
598 tracing::error!("Failed to reset auth state, receiver dropped.");
599 };
600 self.authenticate().await?;
601 }
602
603 Ok(())
604 }
605
606 async fn authenticate(&self) -> Result<(), Error> {
608 let credential = match &self.credential {
609 Some(credential) => credential,
610 None => {
611 panic!("API credentials not available to authenticate");
612 }
613 };
614
615 let timestamp = SystemTime::now()
616 .duration_since(SystemTime::UNIX_EPOCH)
617 .expect("System time should be after UNIX epoch")
618 .as_secs()
619 .to_string();
620 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
621
622 let auth_message = OKXAuthentication {
623 op: "login",
624 args: vec![OKXAuthenticationArg {
625 api_key: credential.api_key.to_string(),
626 passphrase: credential.api_passphrase.clone(),
627 timestamp,
628 sign: signature,
629 }],
630 };
631
632 {
633 let inner_guard = self.inner.read().await;
634 if let Some(inner) = &*inner_guard {
635 if let Err(e) = inner
636 .send_text(serde_json::to_string(&auth_message).unwrap(), None)
637 .await
638 {
639 tracing::error!("Error sending auth message: {e:?}");
640 return Err(Error::Io(std::io::Error::other(e.to_string())));
641 }
642 } else {
643 log::error!("Cannot authenticate: not connected");
644 return Err(Error::ConnectionClosed);
645 }
646 }
647
648 let mut rx = self.auth_state_rx.clone();
650 match tokio::time::timeout(Duration::from_secs(10), rx.wait_for(|&auth| auth)).await {
651 Ok(Ok(_)) => {
652 tracing::info!("Authentication confirmed by client");
653 Ok(())
654 }
655 Ok(Err(e)) => {
656 tracing::error!("Authentication watch channel closed unexpectedly: {e}");
657 Err(Error::Io(std::io::Error::other(
658 "Authentication watch channel closed",
659 )))
660 }
661 Err(_) => {
662 tracing::error!("Timeout waiting for authentication response");
663 Err(Error::Io(std::io::Error::other(
664 "Timeout waiting for authentication",
665 )))
666 }
667 }
668 }
669
670 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
678 let rx = self
679 .rx
680 .take()
681 .expect("Data stream receiver already taken or not connected");
682 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
683 async_stream::stream! {
684 while let Some(data) = rx.recv().await {
685 yield data;
686 }
687 }
688 }
689
690 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
696 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
697
698 tokio::time::timeout(timeout, async {
699 while !self.is_active() {
700 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
701 }
702 })
703 .await
704 .map_err(|_| {
705 OKXWsError::ClientError(format!(
706 "WebSocket connection timeout after {timeout_secs} seconds"
707 ))
708 })?;
709
710 Ok(())
711 }
712
713 pub async fn close(&mut self) -> Result<(), Error> {
715 log::debug!("Starting close process");
716
717 self.signal.store(true, Ordering::Relaxed);
718
719 {
720 let inner_guard = self.inner.read().await;
721 if let Some(inner) = &*inner_guard {
722 log::debug!("Disconnecting websocket");
723
724 match tokio::time::timeout(Duration::from_secs(3), inner.disconnect()).await {
725 Ok(()) => log::debug!("Websocket disconnected successfully"),
726 Err(_) => {
727 log::warn!(
728 "Timeout waiting for websocket disconnect, continuing with cleanup"
729 )
730 }
731 }
732 } else {
733 log::debug!("No active connection to disconnect");
734 }
735 }
736
737 if let Some(stream_handle) = self.task_handle.take() {
739 match Arc::try_unwrap(stream_handle) {
740 Ok(handle) => {
741 log::debug!("Waiting for stream handle to complete");
742 match tokio::time::timeout(Duration::from_secs(2), handle).await {
743 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
744 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
745 Err(_) => {
746 log::warn!(
747 "Timeout waiting for stream handle, task may still be running"
748 );
749 }
751 }
752 }
753 Err(arc_handle) => {
754 log::debug!(
755 "Cannot take ownership of stream handle - other references exist, aborting task"
756 );
757 arc_handle.abort();
758 }
759 }
760 } else {
761 log::debug!("No stream handle to await");
762 }
763
764 log::debug!("Close process completed");
765
766 Ok(())
767 }
768
769 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
771 let symbol = instrument_id.symbol.inner();
772 let mut channels = Vec::new();
773
774 for entry in self.subscriptions_inst_id.iter() {
775 let (channel, instruments) = entry.pair();
776 if instruments.contains(&symbol) {
777 channels.push(channel.clone());
778 }
779 }
780
781 channels
782 }
783
784 fn generate_unique_request_id(&self) -> String {
785 self.request_id_counter
786 .fetch_add(1, Ordering::SeqCst)
787 .to_string()
788 }
789
790 #[allow(
791 clippy::result_large_err,
792 reason = "OKXWsError contains large tungstenite::Error variant"
793 )]
794 fn get_instrument_type_and_family(
795 &self,
796 symbol: Ustr,
797 ) -> Result<(OKXInstrumentType, String), OKXWsError> {
798 let instrument = self.instruments_cache.get(&symbol).ok_or_else(|| {
800 OKXWsError::ClientError(format!("Instrument not found in cache: {symbol}"))
801 })?;
802
803 let inst_type =
804 okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
805
806 let inst_family = match instrument {
808 InstrumentAny::CurrencyPair(_) => symbol.as_str().to_string(),
809 InstrumentAny::CryptoPerpetual(_) => {
810 symbol
812 .as_str()
813 .strip_suffix("-SWAP")
814 .unwrap_or(symbol.as_str())
815 .to_string()
816 }
817 InstrumentAny::CryptoFuture(_) => {
818 let parts: Vec<&str> = symbol.as_str().split('-').collect();
820 if parts.len() >= 2 {
821 format!("{}-{}", parts[0], parts[1])
822 } else {
823 return Err(OKXWsError::ClientError(format!(
824 "Unable to parse futures instrument family from symbol: {symbol}",
825 )));
826 }
827 }
828 _ => {
829 return Err(OKXWsError::ClientError(format!(
830 "Unsupported instrument type for mass cancel: {instrument:?}",
831 )));
832 }
833 };
834
835 Ok((inst_type, inst_family))
836 }
837
838 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
839 for arg in &args {
840 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
842 self.subscriptions_bare.insert(arg.channel.clone(), true);
844 } else {
845 if let Some(inst_type) = &arg.inst_type {
847 self.subscriptions_inst_type
848 .entry(arg.channel.clone())
849 .or_default()
850 .insert(*inst_type);
851 }
852
853 if let Some(inst_family) = &arg.inst_family {
855 self.subscriptions_inst_family
856 .entry(arg.channel.clone())
857 .or_default()
858 .insert(*inst_family);
859 }
860
861 if let Some(inst_id) = &arg.inst_id {
863 self.subscriptions_inst_id
864 .entry(arg.channel.clone())
865 .or_default()
866 .insert(*inst_id);
867 }
868 }
869 }
870
871 let message = OKXSubscription {
872 op: OKXWsOperation::Subscribe,
873 args,
874 };
875
876 let json_txt =
877 serde_json::to_string(&message).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
878
879 {
880 let inner_guard = self.inner.read().await;
881 if let Some(inner) = &*inner_guard {
882 if let Err(e) = inner
883 .send_text(json_txt, Some(vec!["subscription".to_string()]))
884 .await
885 {
886 tracing::error!("Error sending message: {e:?}")
887 }
888 } else {
889 return Err(OKXWsError::ClientError(
890 "Cannot send message: not connected".to_string(),
891 ));
892 }
893 }
894
895 Ok(())
896 }
897
898 #[allow(clippy::collapsible_if)] async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
900 for arg in &args {
901 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
903 self.subscriptions_bare.remove(&arg.channel);
905 } else {
906 if let Some(inst_type) = &arg.inst_type {
908 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
909 entry.remove(inst_type);
910 if entry.is_empty() {
911 drop(entry);
912 self.subscriptions_inst_type.remove(&arg.channel);
913 }
914 }
915 }
916
917 if let Some(inst_family) = &arg.inst_family {
919 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
920 entry.remove(inst_family);
921 if entry.is_empty() {
922 drop(entry);
923 self.subscriptions_inst_family.remove(&arg.channel);
924 }
925 }
926 }
927
928 if let Some(inst_id) = &arg.inst_id {
930 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
931 entry.remove(inst_id);
932 if entry.is_empty() {
933 drop(entry);
934 self.subscriptions_inst_id.remove(&arg.channel);
935 }
936 }
937 }
938 }
939 }
940
941 let message = OKXSubscription {
942 op: OKXWsOperation::Unsubscribe,
943 args,
944 };
945
946 let json_txt = serde_json::to_string(&message).expect("Must be valid JSON");
947
948 {
949 let inner_guard = self.inner.read().await;
950 if let Some(inner) = &*inner_guard {
951 if let Err(e) = inner
952 .send_text(json_txt, Some(vec!["subscription".to_string()]))
953 .await
954 {
955 tracing::error!("Error sending message: {e:?}")
956 }
957 } else {
958 log::error!("Cannot send message: not connected");
959 }
960 }
961
962 Ok(())
963 }
964
965 #[allow(dead_code)]
966 async fn resubscribe_all(&self) {
967 let mut subs_bare = Vec::new();
969 for entry in self.subscriptions_bare.iter() {
970 let channel = entry.key();
971 subs_bare.push(channel.clone());
972 }
973
974 let mut subs_inst_type = Vec::new();
975 for entry in self.subscriptions_inst_type.iter() {
976 let (channel, inst_types) = entry.pair();
977 if !inst_types.is_empty() {
978 subs_inst_type.push((channel.clone(), inst_types.clone()));
979 }
980 }
981
982 let mut subs_inst_family = Vec::new();
983 for entry in self.subscriptions_inst_family.iter() {
984 let (channel, inst_families) = entry.pair();
985 if !inst_families.is_empty() {
986 subs_inst_family.push((channel.clone(), inst_families.clone()));
987 }
988 }
989
990 let mut subs_inst_id = Vec::new();
991 for entry in self.subscriptions_inst_id.iter() {
992 let (channel, inst_ids) = entry.pair();
993 if !inst_ids.is_empty() {
994 subs_inst_id.push((channel.clone(), inst_ids.clone()));
995 }
996 }
997
998 for (channel, inst_types) in subs_inst_type {
1000 if inst_types.is_empty() {
1001 continue;
1002 }
1003
1004 tracing::debug!("Resubscribing: channel={channel}, instrument_types={inst_types:?}");
1005
1006 for inst_type in inst_types {
1007 let arg = OKXSubscriptionArg {
1008 channel: channel.clone(),
1009 inst_type: Some(inst_type),
1010 inst_family: None,
1011 inst_id: None,
1012 };
1013
1014 if let Err(e) = self.subscribe(vec![arg]).await {
1015 tracing::error!(
1016 "Failed to resubscribe to channel {channel} with instrument type: {e}"
1017 );
1018 }
1019 }
1020 }
1021
1022 for (channel, inst_families) in subs_inst_family {
1024 if inst_families.is_empty() {
1025 continue;
1026 }
1027
1028 tracing::debug!(
1029 "Resubscribing: channel={channel}, instrument_families={inst_families:?}"
1030 );
1031
1032 for inst_family in inst_families {
1033 let arg = OKXSubscriptionArg {
1034 channel: channel.clone(),
1035 inst_type: None,
1036 inst_family: Some(inst_family),
1037 inst_id: None,
1038 };
1039
1040 if let Err(e) = self.subscribe(vec![arg]).await {
1041 tracing::error!(
1042 "Failed to resubscribe to channel {channel} with instrument family: {e}"
1043 );
1044 }
1045 }
1046 }
1047
1048 for (channel, inst_ids) in subs_inst_id {
1050 if inst_ids.is_empty() {
1051 continue;
1052 }
1053
1054 tracing::debug!("Resubscribing: channel={channel}, instrument_ids={inst_ids:?}");
1055
1056 for inst_id in inst_ids {
1057 let arg = OKXSubscriptionArg {
1058 channel: channel.clone(),
1059 inst_type: None,
1060 inst_family: None,
1061 inst_id: Some(inst_id),
1062 };
1063
1064 if let Err(e) = self.subscribe(vec![arg]).await {
1065 tracing::error!(
1066 "Failed to resubscribe to channel {channel} with instrument ID: {e}"
1067 );
1068 }
1069 }
1070 }
1071
1072 for channel in subs_bare {
1074 tracing::debug!("Resubscribing to bare channel: {channel}");
1075
1076 let arg = OKXSubscriptionArg {
1077 channel,
1078 inst_type: None,
1079 inst_family: None,
1080 inst_id: None,
1081 };
1082
1083 if let Err(e) = self.subscribe(vec![arg]).await {
1084 tracing::error!("Failed to resubscribe to bare channel: {e}");
1085 }
1086 }
1087 }
1088
1089 pub async fn subscribe_instruments(
1097 &self,
1098 instrument_type: OKXInstrumentType,
1099 ) -> Result<(), OKXWsError> {
1100 let arg = OKXSubscriptionArg {
1101 channel: OKXWsChannel::Instruments,
1102 inst_type: Some(instrument_type),
1103 inst_family: None,
1104 inst_id: None,
1105 };
1106 self.subscribe(vec![arg]).await
1107 }
1108
1109 pub async fn subscribe_instrument(
1117 &self,
1118 instrument_id: InstrumentId,
1119 ) -> Result<(), OKXWsError> {
1120 let arg = OKXSubscriptionArg {
1121 channel: OKXWsChannel::Instruments,
1122 inst_type: None,
1123 inst_family: None,
1124 inst_id: Some(instrument_id.symbol.inner()),
1125 };
1126 self.subscribe(vec![arg]).await
1127 }
1128
1129 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1135 let arg = OKXSubscriptionArg {
1136 channel: OKXWsChannel::Books,
1137 inst_type: None,
1138 inst_family: None,
1139 inst_id: Some(instrument_id.symbol.inner()),
1140 };
1141 self.subscribe(vec![arg]).await
1142 }
1143
1144 pub async fn subscribe_book_depth5(
1152 &self,
1153 instrument_id: InstrumentId,
1154 ) -> Result<(), OKXWsError> {
1155 let arg = OKXSubscriptionArg {
1156 channel: OKXWsChannel::Books5,
1157 inst_type: None,
1158 inst_family: None,
1159 inst_id: Some(instrument_id.symbol.inner()),
1160 };
1161 self.subscribe(vec![arg]).await
1162 }
1163
1164 pub async fn subscribe_books50_l2_tbt(
1172 &self,
1173 instrument_id: InstrumentId,
1174 ) -> Result<(), OKXWsError> {
1175 let arg = OKXSubscriptionArg {
1176 channel: OKXWsChannel::Books50Tbt,
1177 inst_type: None,
1178 inst_family: None,
1179 inst_id: Some(instrument_id.symbol.inner()),
1180 };
1181 self.subscribe(vec![arg]).await
1182 }
1183
1184 pub async fn subscribe_book_l2_tbt(
1192 &self,
1193 instrument_id: InstrumentId,
1194 ) -> Result<(), OKXWsError> {
1195 let arg = OKXSubscriptionArg {
1196 channel: OKXWsChannel::BooksTbt,
1197 inst_type: None,
1198 inst_family: None,
1199 inst_id: Some(instrument_id.symbol.inner()),
1200 };
1201 self.subscribe(vec![arg]).await
1202 }
1203
1204 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1212 let arg = OKXSubscriptionArg {
1214 channel: OKXWsChannel::BboTbt,
1215 inst_type: None,
1216 inst_family: None,
1217 inst_id: Some(instrument_id.symbol.inner()),
1218 };
1219 self.subscribe(vec![arg]).await
1220 }
1221
1222 pub async fn subscribe_trades(
1228 &self,
1229 instrument_id: InstrumentId,
1230 _aggregated: bool, ) -> Result<(), OKXWsError> {
1232 let channel = OKXWsChannel::Trades;
1237
1238 let arg = OKXSubscriptionArg {
1239 channel,
1240 inst_type: None,
1241 inst_family: None,
1242 inst_id: Some(instrument_id.symbol.inner()),
1243 };
1244 self.subscribe(vec![arg]).await
1245 }
1246
1247 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1255 let arg = OKXSubscriptionArg {
1256 channel: OKXWsChannel::Tickers,
1257 inst_type: None,
1258 inst_family: None,
1259 inst_id: Some(instrument_id.symbol.inner()),
1260 };
1261 self.subscribe(vec![arg]).await
1262 }
1263
1264 pub async fn subscribe_mark_prices(
1272 &self,
1273 instrument_id: InstrumentId,
1274 ) -> Result<(), OKXWsError> {
1275 let arg = OKXSubscriptionArg {
1276 channel: OKXWsChannel::MarkPrice,
1277 inst_type: None,
1278 inst_family: None,
1279 inst_id: Some(instrument_id.symbol.inner()),
1280 };
1281 self.subscribe(vec![arg]).await
1282 }
1283
1284 pub async fn subscribe_index_prices(
1292 &self,
1293 instrument_id: InstrumentId,
1294 ) -> Result<(), OKXWsError> {
1295 let arg = OKXSubscriptionArg {
1296 channel: OKXWsChannel::IndexTickers,
1297 inst_type: None,
1298 inst_family: None,
1299 inst_id: Some(instrument_id.symbol.inner()),
1300 };
1301 self.subscribe(vec![arg]).await
1302 }
1303
1304 pub async fn subscribe_funding_rates(
1312 &self,
1313 instrument_id: InstrumentId,
1314 ) -> Result<(), OKXWsError> {
1315 let arg = OKXSubscriptionArg {
1316 channel: OKXWsChannel::FundingRate,
1317 inst_type: None,
1318 inst_family: None,
1319 inst_id: Some(instrument_id.symbol.inner()),
1320 };
1321 self.subscribe(vec![arg]).await
1322 }
1323
1324 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1332 let channel = bar_spec_as_okx_channel(bar_type.spec())
1334 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1335
1336 let arg = OKXSubscriptionArg {
1337 channel,
1338 inst_type: None,
1339 inst_family: None,
1340 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1341 };
1342 self.subscribe(vec![arg]).await
1343 }
1344
1345 pub async fn unsubscribe_instruments(
1347 &self,
1348 instrument_type: OKXInstrumentType,
1349 ) -> Result<(), OKXWsError> {
1350 let arg = OKXSubscriptionArg {
1351 channel: OKXWsChannel::Instruments,
1352 inst_type: Some(instrument_type),
1353 inst_family: None,
1354 inst_id: None,
1355 };
1356 self.unsubscribe(vec![arg]).await
1357 }
1358
1359 pub async fn unsubscribe_instrument(
1361 &self,
1362 instrument_id: InstrumentId,
1363 ) -> Result<(), OKXWsError> {
1364 let arg = OKXSubscriptionArg {
1365 channel: OKXWsChannel::Instruments,
1366 inst_type: None,
1367 inst_family: None,
1368 inst_id: Some(instrument_id.symbol.inner()),
1369 };
1370 self.unsubscribe(vec![arg]).await
1371 }
1372
1373 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1375 let arg = OKXSubscriptionArg {
1376 channel: OKXWsChannel::Books,
1377 inst_type: None,
1378 inst_family: None,
1379 inst_id: Some(instrument_id.symbol.inner()),
1380 };
1381 self.unsubscribe(vec![arg]).await
1382 }
1383
1384 pub async fn unsubscribe_book_depth5(
1386 &self,
1387 instrument_id: InstrumentId,
1388 ) -> Result<(), OKXWsError> {
1389 let arg = OKXSubscriptionArg {
1390 channel: OKXWsChannel::Books5,
1391 inst_type: None,
1392 inst_family: None,
1393 inst_id: Some(instrument_id.symbol.inner()),
1394 };
1395 self.unsubscribe(vec![arg]).await
1396 }
1397
1398 pub async fn unsubscribe_book50_l2_tbt(
1400 &self,
1401 instrument_id: InstrumentId,
1402 ) -> Result<(), OKXWsError> {
1403 let arg = OKXSubscriptionArg {
1404 channel: OKXWsChannel::Books50Tbt,
1405 inst_type: None,
1406 inst_family: None,
1407 inst_id: Some(instrument_id.symbol.inner()),
1408 };
1409 self.unsubscribe(vec![arg]).await
1410 }
1411
1412 pub async fn unsubscribe_book_l2_tbt(
1414 &self,
1415 instrument_id: InstrumentId,
1416 ) -> Result<(), OKXWsError> {
1417 let arg = OKXSubscriptionArg {
1418 channel: OKXWsChannel::BooksTbt,
1419 inst_type: None,
1420 inst_family: None,
1421 inst_id: Some(instrument_id.symbol.inner()),
1422 };
1423 self.unsubscribe(vec![arg]).await
1424 }
1425
1426 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1428 let arg = OKXSubscriptionArg {
1429 channel: OKXWsChannel::BboTbt,
1430 inst_type: None,
1431 inst_family: None,
1432 inst_id: Some(instrument_id.symbol.inner()),
1433 };
1434 self.unsubscribe(vec![arg]).await
1435 }
1436
1437 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1439 let arg = OKXSubscriptionArg {
1440 channel: OKXWsChannel::Tickers,
1441 inst_type: None,
1442 inst_family: None,
1443 inst_id: Some(instrument_id.symbol.inner()),
1444 };
1445 self.unsubscribe(vec![arg]).await
1446 }
1447
1448 pub async fn unsubscribe_mark_prices(
1450 &self,
1451 instrument_id: InstrumentId,
1452 ) -> Result<(), OKXWsError> {
1453 let arg = OKXSubscriptionArg {
1454 channel: OKXWsChannel::MarkPrice,
1455 inst_type: None,
1456 inst_family: None,
1457 inst_id: Some(instrument_id.symbol.inner()),
1458 };
1459 self.unsubscribe(vec![arg]).await
1460 }
1461
1462 pub async fn unsubscribe_index_prices(
1464 &self,
1465 instrument_id: InstrumentId,
1466 ) -> Result<(), OKXWsError> {
1467 let arg = OKXSubscriptionArg {
1468 channel: OKXWsChannel::IndexTickers,
1469 inst_type: None,
1470 inst_family: None,
1471 inst_id: Some(instrument_id.symbol.inner()),
1472 };
1473 self.unsubscribe(vec![arg]).await
1474 }
1475
1476 pub async fn unsubscribe_funding_rates(
1478 &self,
1479 instrument_id: InstrumentId,
1480 ) -> Result<(), OKXWsError> {
1481 let arg = OKXSubscriptionArg {
1482 channel: OKXWsChannel::FundingRate,
1483 inst_type: None,
1484 inst_family: None,
1485 inst_id: Some(instrument_id.symbol.inner()),
1486 };
1487 self.unsubscribe(vec![arg]).await
1488 }
1489
1490 pub async fn unsubscribe_trades(
1492 &self,
1493 instrument_id: InstrumentId,
1494 _aggregated: bool,
1495 ) -> Result<(), OKXWsError> {
1496 let channel = OKXWsChannel::Trades;
1498
1499 let arg = OKXSubscriptionArg {
1500 channel,
1501 inst_type: None,
1502 inst_family: None,
1503 inst_id: Some(instrument_id.symbol.inner()),
1504 };
1505 self.unsubscribe(vec![arg]).await
1506 }
1507
1508 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1510 let channel = bar_spec_as_okx_channel(bar_type.spec())
1512 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1513
1514 let arg = OKXSubscriptionArg {
1515 channel,
1516 inst_type: None,
1517 inst_family: None,
1518 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1519 };
1520 self.unsubscribe(vec![arg]).await
1521 }
1522
1523 pub async fn subscribe_orders(
1525 &self,
1526 instrument_type: OKXInstrumentType,
1527 ) -> Result<(), OKXWsError> {
1528 let arg = OKXSubscriptionArg {
1529 channel: OKXWsChannel::Orders,
1530 inst_type: Some(instrument_type),
1531 inst_family: None,
1532 inst_id: None,
1533 };
1534 self.subscribe(vec![arg]).await
1535 }
1536
1537 pub async fn unsubscribe_orders(
1539 &self,
1540 instrument_type: OKXInstrumentType,
1541 ) -> Result<(), OKXWsError> {
1542 let arg = OKXSubscriptionArg {
1543 channel: OKXWsChannel::Orders,
1544 inst_type: Some(instrument_type),
1545 inst_family: None,
1546 inst_id: None,
1547 };
1548 self.unsubscribe(vec![arg]).await
1549 }
1550
1551 pub async fn subscribe_fills(
1553 &self,
1554 instrument_type: OKXInstrumentType,
1555 ) -> Result<(), OKXWsError> {
1556 let arg = OKXSubscriptionArg {
1557 channel: OKXWsChannel::Fills,
1558 inst_type: Some(instrument_type),
1559 inst_family: None,
1560 inst_id: None,
1561 };
1562 self.subscribe(vec![arg]).await
1563 }
1564
1565 pub async fn unsubscribe_fills(
1567 &self,
1568 instrument_type: OKXInstrumentType,
1569 ) -> Result<(), OKXWsError> {
1570 let arg = OKXSubscriptionArg {
1571 channel: OKXWsChannel::Fills,
1572 inst_type: Some(instrument_type),
1573 inst_family: None,
1574 inst_id: None,
1575 };
1576 self.unsubscribe(vec![arg]).await
1577 }
1578
1579 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1581 let arg = OKXSubscriptionArg {
1582 channel: OKXWsChannel::Account,
1583 inst_type: None,
1584 inst_family: None,
1585 inst_id: None,
1586 };
1587 self.subscribe(vec![arg]).await
1588 }
1589
1590 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1592 let arg = OKXSubscriptionArg {
1593 channel: OKXWsChannel::Account,
1594 inst_type: None,
1595 inst_family: None,
1596 inst_id: None,
1597 };
1598 self.unsubscribe(vec![arg]).await
1599 }
1600
1601 async fn ws_cancel_order(
1607 &self,
1608 params: WsCancelOrderParams,
1609 request_id: Option<String>,
1610 ) -> Result<(), OKXWsError> {
1611 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1612
1613 let req = OKXWsRequest {
1614 id: Some(request_id),
1615 op: OKXWsOperation::CancelOrder,
1616 args: vec![params],
1617 exp_time: None,
1618 };
1619
1620 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1621
1622 {
1623 let inner_guard = self.inner.read().await;
1624 if let Some(inner) = &*inner_guard {
1625 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1626 tracing::error!("Error sending message: {e:?}");
1627 }
1628 Ok(())
1629 } else {
1630 Err(OKXWsError::ClientError("Not connected".to_string()))
1631 }
1632 }
1633 }
1634
1635 async fn ws_mass_cancel_with_id(
1641 &self,
1642 args: Vec<Value>,
1643 request_id: String,
1644 ) -> Result<(), OKXWsError> {
1645 let req = OKXWsRequest {
1646 id: Some(request_id),
1647 op: OKXWsOperation::MassCancel,
1648 args,
1649 exp_time: None,
1650 };
1651
1652 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1653
1654 {
1655 let inner_guard = self.inner.read().await;
1656 if let Some(inner) = &*inner_guard {
1657 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1658 tracing::error!("Error sending message: {e:?}");
1659 }
1660 Ok(())
1661 } else {
1662 Err(OKXWsError::ClientError("Not connected".to_string()))
1663 }
1664 }
1665 }
1666
1667 async fn ws_amend_order(
1673 &self,
1674 params: WsAmendOrderParams,
1675 request_id: Option<String>,
1676 ) -> Result<(), OKXWsError> {
1677 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1678
1679 let req = OKXWsRequest {
1680 id: Some(request_id),
1681 op: OKXWsOperation::AmendOrder,
1682 args: vec![params],
1683 exp_time: None,
1684 };
1685
1686 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1687
1688 {
1689 let inner_guard = self.inner.read().await;
1690 if let Some(inner) = &*inner_guard {
1691 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1692 tracing::error!("Error sending message: {e:?}");
1693 }
1694 Ok(())
1695 } else {
1696 Err(OKXWsError::ClientError("Not connected".to_string()))
1697 }
1698 }
1699 }
1700
1701 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1707 let request_id = self.generate_unique_request_id();
1708
1709 let req = OKXWsRequest {
1710 id: Some(request_id),
1711 op: OKXWsOperation::BatchOrders,
1712 args,
1713 exp_time: None,
1714 };
1715
1716 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1717
1718 {
1719 let inner_guard = self.inner.read().await;
1720 if let Some(inner) = &*inner_guard {
1721 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
1722 tracing::error!("Error sending message: {e:?}");
1723 }
1724 Ok(())
1725 } else {
1726 Err(OKXWsError::ClientError("Not connected".to_string()))
1727 }
1728 }
1729 }
1730
1731 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1737 let request_id = self.generate_unique_request_id();
1738
1739 let req = OKXWsRequest {
1740 id: Some(request_id),
1741 op: OKXWsOperation::BatchCancelOrders,
1742 args,
1743 exp_time: None,
1744 };
1745
1746 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1747
1748 {
1749 let inner_guard = self.inner.read().await;
1750 if let Some(inner) = &*inner_guard {
1751 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1752 tracing::error!("Error sending message: {e:?}");
1753 }
1754 Ok(())
1755 } else {
1756 Err(OKXWsError::ClientError("Not connected".to_string()))
1757 }
1758 }
1759 }
1760
1761 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1767 let request_id = self.generate_unique_request_id();
1768
1769 let req = OKXWsRequest {
1770 id: Some(request_id),
1771 op: OKXWsOperation::BatchAmendOrders,
1772 args,
1773 exp_time: None,
1774 };
1775
1776 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1777
1778 {
1779 let inner_guard = self.inner.read().await;
1780 if let Some(inner) = &*inner_guard {
1781 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1782 tracing::error!("Error sending message: {e:?}");
1783 }
1784 Ok(())
1785 } else {
1786 Err(OKXWsError::ClientError("Not connected".to_string()))
1787 }
1788 }
1789 }
1790
1791 #[allow(clippy::too_many_arguments)]
1797 pub async fn submit_order(
1798 &self,
1799 trader_id: TraderId,
1800 strategy_id: StrategyId,
1801 instrument_id: InstrumentId,
1802 td_mode: OKXTradeMode,
1803 client_order_id: ClientOrderId,
1804 order_side: OrderSide,
1805 order_type: OrderType,
1806 quantity: Quantity,
1807 time_in_force: Option<TimeInForce>,
1808 price: Option<Price>,
1809 trigger_price: Option<Price>,
1810 post_only: Option<bool>,
1811 reduce_only: Option<bool>,
1812 quote_quantity: Option<bool>,
1813 position_side: Option<PositionSide>,
1814 ) -> Result<(), OKXWsError> {
1815 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1816 return Err(OKXWsError::ClientError(format!(
1817 "Unsupported order type: {order_type:?}",
1818 )));
1819 }
1820
1821 if let Some(tif) = time_in_force
1822 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1823 {
1824 return Err(OKXWsError::ClientError(format!(
1825 "Unsupported time in force: {tif:?}",
1826 )));
1827 }
1828
1829 let mut builder = WsPostOrderParamsBuilder::default();
1830
1831 builder.inst_id(instrument_id.symbol.as_str());
1832 builder.td_mode(td_mode);
1833 builder.cl_ord_id(client_order_id.as_str());
1834
1835 let instrument = self
1836 .instruments_cache
1837 .get(&instrument_id.symbol.inner())
1838 .ok_or_else(|| {
1839 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1840 })?;
1841
1842 let instrument_type =
1843 okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1844 let quote_currency = instrument.quote_currency();
1845
1846 match instrument_type {
1847 OKXInstrumentType::Spot => {
1848 }
1850 OKXInstrumentType::Margin => {
1851 builder.ccy(quote_currency.to_string());
1853
1854 if let Some(ro) = reduce_only
1856 && ro
1857 {
1858 builder.reduce_only(ro);
1859 }
1860 }
1861 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1862 builder.ccy(quote_currency.to_string());
1864 }
1865 _ => {
1866 builder.ccy(quote_currency.to_string());
1868 builder.tgt_ccy(quote_currency.to_string());
1869
1870 if let Some(ro) = reduce_only
1872 && ro
1873 {
1874 builder.reduce_only(ro);
1875 }
1876 }
1877 };
1878
1879 if let Some(is_quote_quantity) = quote_quantity
1880 && is_quote_quantity
1881 {
1882 builder.tgt_ccy(quote_currency.to_string());
1883 }
1884 builder.side(OKXSide::from(order_side));
1887
1888 if let Some(pos_side) = position_side {
1889 builder.pos_side(pos_side);
1890 };
1891
1892 let okx_ord_type = if post_only.unwrap_or(false) {
1894 OKXOrderType::PostOnly
1895 } else {
1896 OKXOrderType::from(order_type)
1897 };
1898
1899 log::debug!(
1900 "Order type mapping: order_type={:?}, time_in_force={:?}, post_only={:?} -> okx_ord_type={:?}",
1901 order_type,
1902 time_in_force,
1903 post_only,
1904 okx_ord_type
1905 );
1906
1907 builder.ord_type(okx_ord_type);
1908 builder.sz(quantity.to_string());
1909
1910 if let Some(tp) = trigger_price {
1911 builder.px(tp.to_string());
1912 } else if let Some(p) = price {
1913 builder.px(p.to_string());
1914 }
1915
1916 builder.tag(OKX_NAUTILUS_BROKER_ID);
1917
1918 let params = builder
1919 .build()
1920 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
1921
1922 log::debug!("Sending order params to OKX: {:?}", params);
1924
1925 let request_id = self.generate_unique_request_id();
1926
1927 self.pending_place_requests.insert(
1928 request_id.clone(),
1929 (client_order_id, trader_id, strategy_id, instrument_id),
1930 );
1931
1932 self.retry_manager
1933 .execute_with_retry_with_cancel(
1934 "submit_order",
1935 || {
1936 let params = params.clone();
1937 let request_id = request_id.clone();
1938 async move { self.ws_place_order(params, Some(request_id)).await }
1939 },
1940 should_retry_okx_error,
1941 create_okx_timeout_error,
1942 &self.cancellation_token,
1943 )
1944 .await
1945 }
1946
1947 #[allow(clippy::too_many_arguments)]
1953 pub async fn cancel_order(
1954 &self,
1955 trader_id: TraderId,
1956 strategy_id: StrategyId,
1957 instrument_id: InstrumentId,
1958 client_order_id: Option<ClientOrderId>,
1959 venue_order_id: Option<VenueOrderId>,
1960 ) -> Result<(), OKXWsError> {
1961 let mut builder = WsCancelOrderParamsBuilder::default();
1962 builder.inst_id(instrument_id.symbol.as_str());
1965
1966 if let Some(venue_order_id) = venue_order_id {
1967 builder.ord_id(venue_order_id.as_str());
1968 }
1969
1970 if let Some(client_order_id) = client_order_id {
1972 builder.cl_ord_id(client_order_id.as_str());
1973 }
1974
1975 let params = builder
1976 .build()
1977 .map_err(|e| OKXWsError::ClientError(format!("Build cancel params error: {e}")))?;
1978
1979 let request_id = self.generate_unique_request_id();
1980
1981 if let Some(client_order_id) = client_order_id {
1984 self.pending_cancel_requests.insert(
1985 request_id.clone(),
1986 (
1987 client_order_id,
1988 trader_id,
1989 strategy_id,
1990 instrument_id,
1991 venue_order_id,
1992 ),
1993 );
1994 }
1995
1996 self.retry_manager
1997 .execute_with_retry_with_cancel(
1998 "cancel_order",
1999 || {
2000 let params = params.clone();
2001 let request_id = request_id.clone();
2002 async move { self.ws_cancel_order(params, Some(request_id)).await }
2003 },
2004 should_retry_okx_error,
2005 create_okx_timeout_error,
2006 &self.cancellation_token,
2007 )
2008 .await
2009 }
2010
2011 async fn ws_place_order(
2017 &self,
2018 params: WsPostOrderParams,
2019 request_id: Option<String>,
2020 ) -> Result<(), OKXWsError> {
2021 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
2022
2023 let req = OKXWsRequest {
2024 id: Some(request_id),
2025 op: OKXWsOperation::Order,
2026 exp_time: None,
2027 args: vec![params],
2028 };
2029
2030 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2031
2032 {
2033 let inner_guard = self.inner.read().await;
2034 if let Some(inner) = &*inner_guard {
2035 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
2036 tracing::error!("Error sending message: {e:?}");
2037 }
2038 Ok(())
2039 } else {
2040 Err(OKXWsError::ClientError("Not connected".to_string()))
2041 }
2042 }
2043 }
2044
2045 #[allow(clippy::too_many_arguments)]
2051 pub async fn modify_order(
2052 &self,
2053 trader_id: TraderId,
2054 strategy_id: StrategyId,
2055 instrument_id: InstrumentId,
2056 client_order_id: Option<ClientOrderId>,
2057 price: Option<Price>,
2058 quantity: Option<Quantity>,
2059 venue_order_id: Option<VenueOrderId>,
2060 ) -> Result<(), OKXWsError> {
2061 let mut builder = WsAmendOrderParamsBuilder::default();
2062
2063 builder.inst_id(instrument_id.symbol.as_str());
2064
2065 if let Some(venue_order_id) = venue_order_id {
2066 builder.ord_id(venue_order_id.as_str());
2067 }
2068
2069 if let Some(client_order_id) = client_order_id {
2070 builder.cl_ord_id(client_order_id.as_str());
2071 }
2072
2073 if let Some(price) = price {
2074 builder.new_px(price.to_string());
2075 }
2076
2077 if let Some(quantity) = quantity {
2078 builder.new_sz(quantity.to_string());
2079 }
2080
2081 let params = builder
2082 .build()
2083 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2084
2085 let request_id = self
2087 .request_id_counter
2088 .fetch_add(1, Ordering::SeqCst)
2089 .to_string();
2090
2091 if let Some(client_order_id) = client_order_id {
2094 self.pending_amend_requests.insert(
2095 request_id.clone(),
2096 (
2097 client_order_id,
2098 trader_id,
2099 strategy_id,
2100 instrument_id,
2101 venue_order_id,
2102 ),
2103 );
2104 }
2105
2106 self.retry_manager
2107 .execute_with_retry_with_cancel(
2108 "modify_order",
2109 || {
2110 let params = params.clone();
2111 let request_id = request_id.clone();
2112 async move { self.ws_amend_order(params, Some(request_id)).await }
2113 },
2114 should_retry_okx_error,
2115 create_okx_timeout_error,
2116 &self.cancellation_token,
2117 )
2118 .await
2119 }
2120
2121 #[allow(clippy::type_complexity)]
2123 #[allow(clippy::too_many_arguments)]
2124 pub async fn batch_submit_orders(
2125 &self,
2126 orders: Vec<(
2127 OKXInstrumentType,
2128 InstrumentId,
2129 OKXTradeMode,
2130 ClientOrderId,
2131 OrderSide,
2132 Option<PositionSide>,
2133 OrderType,
2134 Quantity,
2135 Option<Price>,
2136 Option<Price>,
2137 Option<bool>,
2138 Option<bool>,
2139 )>,
2140 ) -> Result<(), OKXWsError> {
2141 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2142 for (
2143 inst_type,
2144 inst_id,
2145 td_mode,
2146 cl_ord_id,
2147 ord_side,
2148 pos_side,
2149 ord_type,
2150 qty,
2151 pr,
2152 tp,
2153 post_only,
2154 reduce_only,
2155 ) in orders
2156 {
2157 let mut builder = WsPostOrderParamsBuilder::default();
2158 builder.inst_type(inst_type);
2159 builder.inst_id(inst_id.symbol.inner());
2160 builder.td_mode(td_mode);
2161 builder.cl_ord_id(cl_ord_id.as_str());
2162 builder.side(OKXSide::from(ord_side));
2163
2164 if let Some(ps) = pos_side {
2165 builder.pos_side(OKXPositionSide::from(ps));
2166 }
2167
2168 let okx_ord_type = if post_only.unwrap_or(false) {
2169 OKXOrderType::PostOnly
2170 } else {
2171 OKXOrderType::from(ord_type)
2172 };
2173
2174 builder.ord_type(okx_ord_type);
2175 builder.sz(qty.to_string());
2176
2177 if let Some(p) = pr {
2178 builder.px(p.to_string());
2179 } else if let Some(p) = tp {
2180 builder.px(p.to_string());
2181 }
2182
2183 if let Some(ro) = reduce_only {
2184 builder.reduce_only(ro);
2185 }
2186
2187 builder.tag(OKX_NAUTILUS_BROKER_ID);
2188
2189 let params = builder
2190 .build()
2191 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2192 let val =
2193 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2194 args.push(val);
2195 }
2196
2197 self.ws_batch_place_orders(args).await
2198 }
2199
2200 #[allow(clippy::type_complexity)]
2202 pub async fn batch_cancel_orders(
2203 &self,
2204 orders: Vec<(
2205 OKXInstrumentType,
2206 InstrumentId,
2207 Option<ClientOrderId>,
2208 Option<String>,
2209 )>,
2210 ) -> Result<(), OKXWsError> {
2211 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2212 for (_inst_type, inst_id, cl_ord_id, ord_id) in orders {
2213 let mut builder = WsCancelOrderParamsBuilder::default();
2214 builder.inst_id(inst_id.symbol.inner());
2216
2217 if let Some(c) = cl_ord_id {
2218 builder.cl_ord_id(c.as_str());
2219 }
2220
2221 if let Some(o) = ord_id {
2222 builder.ord_id(o);
2223 }
2224
2225 let params = builder.build().map_err(|e| {
2226 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2227 })?;
2228 let val =
2229 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2230 args.push(val);
2231 }
2232
2233 self.ws_batch_cancel_orders(args).await
2234 }
2235
2236 pub async fn mass_cancel_orders(&self, inst_id: InstrumentId) -> Result<(), OKXWsError> {
2245 let (inst_type, inst_family) =
2246 self.get_instrument_type_and_family(inst_id.symbol.inner())?;
2247
2248 let params = WsMassCancelParams {
2249 inst_type,
2250 inst_family: Ustr::from(&inst_family),
2251 };
2252
2253 let args =
2254 vec![serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?];
2255
2256 let request_id = self.generate_unique_request_id();
2257
2258 self.pending_mass_cancel_requests
2259 .insert(request_id.clone(), inst_id);
2260
2261 self.retry_manager
2262 .execute_with_retry_with_cancel(
2263 "mass_cancel_orders",
2264 || {
2265 let args = args.clone();
2266 let request_id = request_id.clone();
2267 async move { self.ws_mass_cancel_with_id(args, request_id).await }
2268 },
2269 should_retry_okx_error,
2270 create_okx_timeout_error,
2271 &self.cancellation_token,
2272 )
2273 .await
2274 }
2275
2276 #[allow(clippy::type_complexity)]
2278 #[allow(clippy::too_many_arguments)]
2279 pub async fn batch_modify_orders(
2280 &self,
2281 orders: Vec<(
2282 OKXInstrumentType,
2283 InstrumentId,
2284 ClientOrderId,
2285 ClientOrderId,
2286 Option<Price>,
2287 Option<Quantity>,
2288 )>,
2289 ) -> Result<(), OKXWsError> {
2290 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2291 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2292 let mut builder = WsAmendOrderParamsBuilder::default();
2293 builder.inst_id(inst_id.symbol.inner());
2295 builder.cl_ord_id(cl_ord_id.as_str());
2296 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2297
2298 if let Some(p) = pr {
2299 builder.new_px(p.to_string());
2300 }
2301
2302 if let Some(q) = sz {
2303 builder.new_sz(q.to_string());
2304 }
2305
2306 let params = builder.build().map_err(|e| {
2307 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2308 })?;
2309 let val =
2310 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2311 args.push(val);
2312 }
2313
2314 self.ws_batch_amend_orders(args).await
2315 }
2316}
2317
/// Reads raw WebSocket `Message`s from a channel and turns them into
/// `OKXWebSocketEvent`s.
struct OKXFeedHandler {
    /// Source of raw WebSocket messages consumed by `next`.
    receiver: UnboundedReceiver<Message>,
    /// Stop flag; `next` returns `None` once it observes this set to `true`.
    signal: Arc<AtomicBool>,
}
2322
impl OKXFeedHandler {
    /// Creates a new handler from a message receiver and a shared stop flag.
    pub fn new(receiver: UnboundedReceiver<Message>, signal: Arc<AtomicBool>) -> Self {
        Self { receiver, signal }
    }

    /// Waits for the next event worth surfacing to the caller.
    ///
    /// Returns `Some(event)` for reconnection signals, errors, login results, data,
    /// book data and order responses. Subscription acknowledgements and channel
    /// connection counts are logged and skipped.
    ///
    /// Returns `None` when the channel is closed, a close frame is received, the
    /// stop flag is set, or a text message fails to deserialize.
    async fn next(&mut self) -> Option<OKXWebSocketEvent> {
        loop {
            tokio::select! {
                msg = self.receiver.recv() => match msg {
                    Some(msg) => match msg {
                        Message::Text(text) => {
                            // The reconnection marker is plain text, not JSON, so it is
                            // checked before any parsing.
                            if text == RECONNECTED {
                                tracing::info!("Received WebSocket reconnection signal");
                                return Some(OKXWebSocketEvent::Reconnected);
                            }
                            tracing::trace!("Received WebSocket message: {text}");

                            match serde_json::from_str(&text) {
                                Ok(ws_event) => match &ws_event {
                                    OKXWebSocketEvent::Error { code, msg } => {
                                        tracing::error!("WebSocket error: {code} - {msg}");
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Login {
                                        event,
                                        code,
                                        msg,
                                        conn_id,
                                    } => {
                                        // Code "0" is treated as a successful login.
                                        if code == "0" {
                                            tracing::info!(
                                                "Successfully authenticated with OKX WebSocket, conn_id={conn_id}"
                                            );
                                        } else {
                                            tracing::error!(
                                                "Authentication failed: {event} {code} - {msg}"
                                            );
                                        }
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Subscription {
                                        event,
                                        arg,
                                        conn_id,
                                    } => {
                                        // Serialize the channel and strip the JSON quotes to
                                        // get a bare name for logging.
                                        let channel_str = serde_json::to_string(&arg.channel)
                                            .expect("Invalid OKX websocket channel")
                                            .trim_matches('"')
                                            .to_string();
                                        tracing::debug!(
                                            "{event}d: channel={channel_str}, conn_id={conn_id}"
                                        );
                                        // Log only; not surfaced to the caller.
                                        continue;
                                    }
                                    OKXWebSocketEvent::ChannelConnCount {
                                        event: _,
                                        channel,
                                        conn_count,
                                        conn_id,
                                    } => {
                                        let channel_str = serde_json::to_string(&channel)
                                            .expect("Invalid OKX websocket channel")
                                            .trim_matches('"')
                                            .to_string();
                                        tracing::debug!(
                                            "Channel connection status: channel={channel_str}, connections={conn_count}, conn_id={conn_id}",
                                        );
                                        // Log only; not surfaced to the caller.
                                        continue;
                                    }
                                    OKXWebSocketEvent::Data { .. } => return Some(ws_event),
                                    OKXWebSocketEvent::BookData { .. } => return Some(ws_event),
                                    OKXWebSocketEvent::OrderResponse {
                                        id,
                                        op,
                                        code,
                                        msg,
                                        data,
                                    } => {
                                        // Code "0" is treated as success; the event is
                                        // returned in both cases after logging.
                                        if code == "0" {
                                            tracing::debug!(
                                                "Order operation successful: id={:?}, op={op}, code={code}",
                                                id
                                            );

                                            if let Some(order_data) = data.first() {
                                                let success_msg = order_data
                                                    .get("sMsg")
                                                    .and_then(|s| s.as_str())
                                                    .unwrap_or("Order operation successful");
                                                tracing::debug!("Order success details: {success_msg}");
                                            }
                                        } else {
                                            // Prefer the "sMsg" of the first data entry and
                                            // fall back to the top-level message.
                                            let error_msg = data
                                                .first()
                                                .and_then(|d| d.get("sMsg"))
                                                .and_then(|s| s.as_str())
                                                .unwrap_or(msg.as_str());
                                            tracing::error!(
                                                "Order operation failed: id={id:?}, op={op}, code={code}, error={error_msg}",
                                            );
                                        }
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Reconnected => {
                                        tracing::warn!("Unexpected Reconnected event from deserialization");
                                        continue;
                                    }
                                },
                                Err(e) => {
                                    // NOTE(review): a single unparseable text frame yields
                                    // `None`, the same value used for channel closure and
                                    // the stop signal — confirm that ending the feed here
                                    // is intended rather than skipping the message.
                                    tracing::error!("Failed to parse message: {e}: {text}");
                                    return None;
                                }
                            }
                        }
                        Message::Binary(msg) => {
                            // Binary frames are logged and otherwise ignored.
                            tracing::debug!("Raw binary: {msg:?}");
                        }
                        Message::Close(_) => {
                            tracing::debug!("Received close message");
                            return None;
                        }
                        msg => {
                            tracing::warn!("Unexpected message: {msg}");
                        }
                    }
                    None => {
                        // `recv` returned `None`; no further messages are read.
                        tracing::info!("WebSocket stream closed");
                        return None;
                    }
                },
                // Wake up after 1 ms without a message to poll the stop flag.
                _ = tokio::time::sleep(Duration::from_millis(1)) => {
                    if self.signal.load(std::sync::atomic::Ordering::Relaxed) {
                        tracing::debug!("Stop signal received");
                        return None;
                    }
                }
            }
        }
    }
}
2469
/// Consumes `OKXWebSocketEvent`s from an `OKXFeedHandler` and produces
/// `NautilusWsMessage`s.
struct OKXWsMessageHandler {
    account_id: AccountId,
    /// Underlying feed handler that yields `OKXWebSocketEvent`s.
    handler: OKXFeedHandler,
    /// Outbound channel that `run` forwards produced messages to.
    #[allow(dead_code)]
    tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
    // The four `pending_*` maps are handed in through `new`. Presumably they are
    // shared with the client and keyed by request ID — TODO confirm against the caller.
    pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
    pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
    pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
    /// Entries are removed by the `id` of a successful mass-cancel order response.
    pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
    /// Instruments looked up by the `inst_id` of incoming book data.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
    /// Starts as `None`.
    last_account_state: Option<AccountState>,
    /// Starts empty.
    fee_cache: AHashMap<Ustr, Money>,
    /// Starts empty.
    funding_rate_cache: AHashMap<Ustr, (Ustr, u64)>,
    /// Receives `true` on a successful login event and `false` on a failed one.
    auth_state: Arc<tokio::sync::watch::Sender<bool>>,
}
2485
2486impl OKXWsMessageHandler {
2487 #[allow(clippy::too_many_arguments)]
2489 pub fn new(
2490 account_id: AccountId,
2491 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
2492 reader: UnboundedReceiver<Message>,
2493 signal: Arc<AtomicBool>,
2494 tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
2495 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
2496 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
2497 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
2498 pending_mass_cancel_requests: Arc<DashMap<String, MassCancelRequestData>>,
2499 auth_state: Arc<tokio::sync::watch::Sender<bool>>,
2500 ) -> Self {
2501 Self {
2502 account_id,
2503 handler: OKXFeedHandler::new(reader, signal),
2504 tx,
2505 pending_place_requests,
2506 pending_cancel_requests,
2507 pending_amend_requests,
2508 pending_mass_cancel_requests,
2509 instruments_cache,
2510 last_account_state: None,
2511 fee_cache: AHashMap::new(),
2512 funding_rate_cache: AHashMap::new(),
2513 auth_state,
2514 }
2515 }
2516
2517 fn is_stopped(&self) -> bool {
2518 self.handler
2519 .signal
2520 .load(std::sync::atomic::Ordering::Relaxed)
2521 }
2522
2523 #[allow(dead_code)]
2524 async fn run(&mut self) {
2525 while let Some(data) = self.next().await {
2526 if let Err(e) = self.tx.send(data) {
2527 tracing::error!("Error sending data: {e}");
2528 break; }
2530 }
2531 }
2532
2533 async fn next(&mut self) -> Option<NautilusWsMessage> {
2534 let clock = get_atomic_clock_realtime();
2535
2536 while let Some(event) = self.handler.next().await {
2537 let ts_init = clock.get_time_ns();
2538
2539 if let OKXWebSocketEvent::Login { code, msg, .. } = event {
2540 if code == "0" {
2541 if self.auth_state.send(true).is_err() {
2542 tracing::error!(
2543 "Failed to send authentication success signal: receiver dropped"
2544 );
2545 }
2546 } else {
2547 tracing::error!("Authentication failed: {msg}");
2548 if self.auth_state.send(false).is_err() {
2549 tracing::error!(
2550 "Failed to send authentication failure signal: receiver dropped"
2551 );
2552 }
2553 }
2554 continue; }
2556
2557 if let OKXWebSocketEvent::BookData { arg, action, data } = event {
2558 let inst = match arg.inst_id {
2559 Some(inst_id) => match self.instruments_cache.get(&inst_id) {
2560 Some(inst_ref) => inst_ref.clone(),
2561 None => continue,
2562 },
2563 None => {
2564 tracing::error!("Instrument ID missing for book data event");
2565 continue;
2566 }
2567 };
2568
2569 let instrument_id = inst.id();
2570 let price_precision = inst.price_precision();
2571 let size_precision = inst.size_precision();
2572
2573 match parse_book_msg_vec(
2574 data,
2575 &instrument_id,
2576 price_precision,
2577 size_precision,
2578 action,
2579 ts_init,
2580 ) {
2581 Ok(data) => return Some(NautilusWsMessage::Data(data)),
2582 Err(e) => {
2583 tracing::error!("Failed to parse book message: {e}");
2584 continue;
2585 }
2586 }
2587 }
2588
2589 if let OKXWebSocketEvent::OrderResponse {
2590 id,
2591 op,
2592 code,
2593 msg,
2594 data,
2595 } = event
2596 {
2597 if code == "0" {
2598 tracing::debug!(
2599 "Order operation successful: id={:?} op={op} code={code}",
2600 id
2601 );
2602
2603 if op == OKXWsOperation::MassCancel
2605 && let Some(id) = &id
2606 && let Some((_, instrument_id)) =
2607 self.pending_mass_cancel_requests.remove(id)
2608 {
2609 tracing::info!(
2610 "Mass cancel operation successful for instrument: {}",
2611 instrument_id
2612 );
2613 }
2615
2616 if let Some(data) = data.first() {
2617 let success_msg = data
2618 .get("sMsg")
2619 .and_then(|s| s.as_str())
2620 .unwrap_or("Order operation successful");
2621 tracing::debug!("Order details: {success_msg}");
2622
2623 }
2627 } else {
2628 let error_msg = data
2630 .first()
2631 .and_then(|d| d.get("sMsg"))
2632 .and_then(|s| s.as_str())
2633 .unwrap_or(&msg);
2634
2635 if let Some(data_obj) = data.first() {
2637 tracing::debug!(
2638 "Error data fields: {}",
2639 serde_json::to_string_pretty(data_obj)
2640 .unwrap_or_else(|_| "unable to serialize".to_string())
2641 );
2642 }
2643
2644 tracing::error!(
2645 "Order operation failed: id={:?} op={op} code={code} msg={msg}",
2646 id
2647 );
2648
2649 if let Some(id) = &id {
2651 match op {
2652 OKXWsOperation::Order => {
2653 if let Some((
2654 _,
2655 (client_order_id, trader_id, strategy_id, instrument_id),
2656 )) = self.pending_place_requests.remove(id)
2657 {
2658 let ts_event = clock.get_time_ns();
2659 let rejected = OrderRejected::new(
2660 trader_id,
2661 strategy_id,
2662 instrument_id,
2663 client_order_id,
2664 self.account_id,
2665 Ustr::from(error_msg), UUID4::new(),
2667 ts_event,
2668 ts_init,
2669 false, false, );
2672
2673 return Some(NautilusWsMessage::OrderRejected(rejected));
2674 }
2675 }
2676 OKXWsOperation::CancelOrder => {
2677 if let Some((
2678 _,
2679 (
2680 client_order_id,
2681 trader_id,
2682 strategy_id,
2683 instrument_id,
2684 venue_order_id,
2685 ),
2686 )) = self.pending_cancel_requests.remove(id)
2687 {
2688 let ts_event = clock.get_time_ns();
2689 let rejected = OrderCancelRejected::new(
2690 trader_id,
2691 strategy_id,
2692 instrument_id,
2693 client_order_id,
2694 Ustr::from(error_msg), UUID4::new(),
2696 ts_event,
2697 ts_init,
2698 false, venue_order_id,
2700 Some(self.account_id),
2701 );
2702
2703 return Some(NautilusWsMessage::OrderCancelRejected(rejected));
2704 }
2705 }
2706 OKXWsOperation::AmendOrder => {
2707 if let Some((
2708 _,
2709 (
2710 client_order_id,
2711 trader_id,
2712 strategy_id,
2713 instrument_id,
2714 venue_order_id,
2715 ),
2716 )) = self.pending_amend_requests.remove(id)
2717 {
2718 let ts_event = clock.get_time_ns();
2719 let rejected = OrderModifyRejected::new(
2720 trader_id,
2721 strategy_id,
2722 instrument_id,
2723 client_order_id,
2724 Ustr::from(error_msg), UUID4::new(),
2726 ts_event,
2727 ts_init,
2728 false, venue_order_id,
2730 Some(self.account_id),
2731 );
2732
2733 return Some(NautilusWsMessage::OrderModifyRejected(rejected));
2734 }
2735 }
2736 OKXWsOperation::MassCancel => {
2737 if let Some((_, instrument_id)) =
2738 self.pending_mass_cancel_requests.remove(id)
2739 {
2740 tracing::error!(
2741 "Mass cancel operation failed for {}: code={code} msg={error_msg}",
2742 instrument_id
2743 );
2744 let error = OKXWebSocketError {
2746 code: code.clone(),
2747 message: format!(
2748 "Mass cancel failed for {}: {}",
2749 instrument_id, error_msg
2750 ),
2751 conn_id: None,
2752 timestamp: clock.get_time_ns().as_u64(),
2753 };
2754 return Some(NautilusWsMessage::Error(error));
2755 } else {
2756 tracing::error!(
2757 "Mass cancel operation failed: code={code} msg={error_msg}"
2758 );
2759 }
2760 }
2761 _ => {
2762 tracing::warn!("Unhandled operation type for rejection: {op}");
2763 }
2764 }
2765 }
2766
2767 let error = OKXWebSocketError {
2769 code: code.clone(),
2770 message: error_msg.to_string(),
2771 conn_id: None, timestamp: clock.get_time_ns().as_u64(),
2773 };
2774 return Some(NautilusWsMessage::Error(error));
2775 }
2776 continue;
2777 }
2778
2779 if let OKXWebSocketEvent::Data { ref arg, ref data } = event {
2780 if arg.channel == OKXWsChannel::Account {
2781 match serde_json::from_value::<Vec<OKXAccount>>(data.clone()) {
2782 Ok(accounts) => {
2783 if let Some(account) = accounts.first() {
2784 match parse_account_state(account, self.account_id, ts_init) {
2786 Ok(account_state) => {
2787 if let Some(last_account_state) = &self.last_account_state
2789 && account_state
2790 .has_same_balances_and_margins(last_account_state)
2791 {
2792 continue; }
2794 self.last_account_state = Some(account_state.clone());
2795 return Some(NautilusWsMessage::AccountUpdate(
2796 account_state,
2797 ));
2798 }
2799 Err(e) => {
2800 tracing::error!("Failed to parse account state: {e}");
2801 }
2802 }
2803 }
2804 }
2805 Err(e) => {
2806 tracing::error!(
2807 "Failed to parse account data: {e}, raw data: {}",
2808 data
2809 );
2810 }
2811 }
2812 continue;
2813 }
2814
2815 if arg.channel == OKXWsChannel::Orders {
2816 tracing::debug!("Received orders channel message: {data}");
2817
2818 let data: Vec<OKXOrderMsg> = serde_json::from_value(data.clone()).unwrap();
2819
2820 let mut exec_reports = Vec::with_capacity(data.len());
2821
2822 for msg in data {
2823 match parse_order_msg_vec(
2824 vec![msg],
2825 self.account_id,
2826 &self.instruments_cache,
2827 &self.fee_cache,
2828 ts_init,
2829 ) {
2830 Ok(mut reports) => {
2831 for report in &reports {
2833 match report {
2834 ExecutionReport::Fill(fill_report) => {
2835 let order_id = fill_report.venue_order_id.inner();
2836 let current_fee = self
2837 .fee_cache
2838 .get(&order_id)
2839 .copied()
2840 .unwrap_or_else(|| {
2841 Money::new(0.0, fill_report.commission.currency)
2842 });
2843 let total_fee = current_fee + fill_report.commission;
2844 self.fee_cache.insert(order_id, total_fee);
2845 }
2846 ExecutionReport::Order(status_report) => {
2847 if matches!(
2848 status_report.order_status,
2849 OrderStatus::Filled,
2850 ) {
2851 self.fee_cache
2852 .remove(&status_report.venue_order_id.inner());
2853 }
2854 }
2855 }
2856 }
2857 exec_reports.append(&mut reports);
2858 }
2859 Err(e) => {
2860 tracing::error!("Failed to parse order message: {e}");
2861 continue;
2862 }
2863 }
2864 }
2865
2866 if !exec_reports.is_empty() {
2867 return Some(NautilusWsMessage::ExecutionReports(exec_reports));
2868 }
2869 }
2870
2871 let inst = match arg.inst_id.and_then(|id| self.instruments_cache.get(&id)) {
2872 Some(inst) => inst,
2873 None => {
2874 tracing::error!(
2875 "No instrument for channel {:?}, inst_id {:?}",
2876 arg.channel,
2877 arg.inst_id
2878 );
2879 continue;
2880 }
2881 };
2882 let instrument_id = inst.id();
2883 let price_precision = inst.price_precision();
2884 let size_precision = inst.size_precision();
2885
2886 match parse_ws_message_data(
2887 &arg.channel,
2888 data.clone(),
2889 &instrument_id,
2890 price_precision,
2891 size_precision,
2892 ts_init,
2893 &mut self.funding_rate_cache,
2894 ) {
2895 Ok(Some(msg)) => return Some(msg),
2896 Ok(None) => {
2897 continue;
2899 }
2900 Err(e) => {
2901 tracing::error!("Error parsing message for channel {:?}: {e}", arg.channel)
2902 }
2903 }
2904 }
2905
2906 if let OKXWebSocketEvent::Login {
2908 code, msg, conn_id, ..
2909 } = &event
2910 && code != "0"
2911 {
2912 let error = OKXWebSocketError {
2913 code: code.clone(),
2914 message: msg.clone(),
2915 conn_id: Some(conn_id.clone()),
2916 timestamp: clock.get_time_ns().as_u64(),
2917 };
2918 return Some(NautilusWsMessage::Error(error));
2919 }
2920
2921 if let OKXWebSocketEvent::Error { code, msg } = &event {
2923 let error = OKXWebSocketError {
2924 code: code.clone(),
2925 message: msg.clone(),
2926 conn_id: None,
2927 timestamp: clock.get_time_ns().as_u64(),
2928 };
2929 return Some(NautilusWsMessage::Error(error));
2930 }
2931
2932 if matches!(&event, OKXWebSocketEvent::Reconnected) {
2934 return Some(NautilusWsMessage::Reconnected);
2935 }
2936 }
2937 None }
2939}
2940
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// The auth timestamp is the Unix time in whole seconds rendered as decimal digits.
    #[rstest]
    fn test_timestamp_format_for_websocket_auth() {
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
            .to_string();

        assert!(timestamp.parse::<u64>().is_ok());
        assert_eq!(timestamp.len(), 10);
        assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
    }

    /// A default client carries no credential.
    #[rstest]
    fn test_new_without_credentials() {
        let client = OKXWebSocketClient::default();
        assert!(client.credential.is_none());
        assert_eq!(client.api_key(), None);
    }

    /// Supplying key, secret and passphrase yields a client with a credential.
    #[rstest]
    fn test_new_with_credentials() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            None,
            None,
        )
        .unwrap();
        assert!(client.credential.is_some());
        assert_eq!(client.api_key(), Some("test_key"));
    }

    /// Omitting the secret while passing key and passphrase is rejected.
    #[rstest]
    fn test_new_partial_credentials_fails() {
        let result = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            None,
            Some("test_passphrase".to_string()),
            None,
            None,
        );
        assert!(result.is_err());
    }

    /// `fetch_add` on the request counter hands out consecutive IDs.
    #[rstest]
    fn test_request_id_generation() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);

        let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, initial_counter);
        assert_eq!(id2, initial_counter + 1);
        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 2
        );
    }

    /// A fresh client is closed and inactive; the heartbeat argument is stored.
    #[rstest]
    fn test_client_state_management() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert_eq!(client_with_heartbeat.heartbeat, Some(30));
    }

    /// Pending place requests can be inserted, looked up and removed by request ID.
    #[rstest]
    fn test_request_cache_operations() {
        let client = OKXWebSocketClient::default();

        assert_eq!(client.pending_place_requests.len(), 0);
        assert_eq!(client.pending_cancel_requests.len(), 0);
        assert_eq!(client.pending_amend_requests.len(), 0);

        let client_order_id = ClientOrderId::from("test-order-123");
        let trader_id = TraderId::from("test-trader-001");
        let strategy_id = StrategyId::from("test-strategy-001");
        let instrument_id = InstrumentId::from("BTC-USDT.OKX");

        client.pending_place_requests.insert(
            "place-123".to_string(),
            (client_order_id, trader_id, strategy_id, instrument_id),
        );

        assert_eq!(client.pending_place_requests.len(), 1);
        assert!(client.pending_place_requests.contains_key("place-123"));

        let removed = client.pending_place_requests.remove("place-123");
        assert!(removed.is_some());
        assert_eq!(client.pending_place_requests.len(), 0);
    }

    /// An `OKXWebSocketError` keeps its fields when wrapped in `NautilusWsMessage::Error`.
    #[rstest]
    fn test_websocket_error_handling() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error = OKXWebSocketError {
            code: "60012".to_string(),
            message: "Invalid request".to_string(),
            conn_id: None,
            timestamp: ts,
        };

        assert_eq!(error.code, "60012");
        assert_eq!(error.message, "Invalid request");
        assert_eq!(error.timestamp, ts);

        let nautilus_msg = NautilusWsMessage::Error(error);
        match nautilus_msg {
            NautilusWsMessage::Error(err) => {
                assert_eq!(err.code, "60012");
                assert_eq!(err.message, "Invalid request");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    /// Ten successive `fetch_add` calls return ten consecutive IDs.
    #[rstest]
    fn test_request_id_generation_sequence() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
        let mut ids = Vec::new();
        for _ in 0..10 {
            let id = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
            ids.push(id);
        }

        for (i, &id) in ids.iter().enumerate() {
            assert_eq!(id, initial_counter + i as u64);
        }

        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 10
        );
    }

    /// Constructor arguments for heartbeat and account ID are stored on the client.
    #[rstest]
    fn test_client_state_transitions() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert_eq!(client_with_heartbeat.heartbeat, Some(30));

        let account_id = AccountId::from("test-account-123");
        let client_with_account =
            OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();

        assert_eq!(client_with_account.account_id, account_id);
    }

    /// Ten tasks each reserve an ID, insert a pending request and remove it again.
    #[tokio::test]
    async fn test_concurrent_request_handling() {
        let client = Arc::new(OKXWebSocketClient::default());

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
        let mut handles = Vec::new();

        for i in 0..10 {
            let client_clone = Arc::clone(&client);
            let handle = tokio::spawn(async move {
                let client_order_id = ClientOrderId::from(format!("order-{i}").as_str());
                let trader_id = TraderId::from("trader-001");
                let strategy_id = StrategyId::from("strategy-001");
                let instrument_id = InstrumentId::from("BTC-USDT.OKX");

                let request_id = client_clone
                    .request_id_counter
                    .fetch_add(1, Ordering::SeqCst);
                let request_id_str = request_id.to_string();

                client_clone.pending_place_requests.insert(
                    request_id_str.clone(),
                    (client_order_id, trader_id, strategy_id, instrument_id),
                );

                // Yield briefly so the tasks interleave.
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

                let removed = client_clone.pending_place_requests.remove(&request_id_str);
                assert!(removed.is_some());

                request_id
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures_util::future::join_all(handles).await;

        assert_eq!(results.len(), 10);
        for result in results {
            assert!(result.is_ok());
        }

        assert_eq!(client.pending_place_requests.len(), 0);

        let final_counter = client.request_id_counter.load(Ordering::SeqCst);
        assert_eq!(final_counter, initial_counter + 10);
    }

    /// Several code/message/conn_id combinations survive wrapping in `NautilusWsMessage::Error`.
    #[rstest]
    fn test_websocket_error_scenarios() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error_scenarios = vec![
            ("60012", "Invalid request", None),
            ("60009", "Invalid API key", Some("conn-123".to_string())),
            ("60014", "Too many requests", None),
            ("50001", "Order not found", None),
        ];

        for (code, message, conn_id) in error_scenarios {
            let error = OKXWebSocketError {
                code: code.to_string(),
                message: message.to_string(),
                conn_id: conn_id.clone(),
                timestamp: ts,
            };

            assert_eq!(error.code, code);
            assert_eq!(error.message, message);
            assert_eq!(error.conn_id, conn_id);
            assert_eq!(error.timestamp, ts);

            let nautilus_msg = NautilusWsMessage::Error(error);
            match nautilus_msg {
                NautilusWsMessage::Error(err) => {
                    assert_eq!(err.code, code);
                    assert_eq!(err.message, message);
                    assert_eq!(err.conn_id, conn_id);
                }
                _ => panic!("Expected Error variant"),
            }
        }
    }

    /// The `RECONNECTED` sentinel text frame is surfaced as a `Reconnected` event.
    #[tokio::test]
    async fn test_feed_handler_reconnection_detection() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        tx.send(Message::Text(RECONNECTED.to_string().into()))
            .unwrap();

        let result = handler.next().await;
        assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
    }

    /// With the stop flag set, queued ping and subscribe frames yield no event.
    #[tokio::test]
    async fn test_feed_handler_normal_message_processing() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        let ping_msg = "ping";
        tx.send(Message::Text(ping_msg.to_string().into())).unwrap();

        let sub_msg = r#"{
            "event": "subscribe",
            "arg": {
                "channel": "tickers",
                "instType": "SPOT"
            },
            "connId": "a4d3ae55"
        }"#;

        tx.send(Message::Text(sub_msg.to_string().into())).unwrap();

        // Raise the stop flag before polling so `next` is observed terminating.
        signal.store(true, Ordering::Relaxed);

        let result = handler.next().await;
        assert!(result.is_none());
    }

    /// A handler created with the stop flag already set returns `None`.
    #[tokio::test]
    async fn test_feed_handler_stop_signal() {
        let (_tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(true));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        let result = handler.next().await;
        assert!(result.is_none());
    }

    /// A close frame ends the feed.
    #[tokio::test]
    async fn test_feed_handler_close_message() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        tx.send(Message::Close(None)).unwrap();

        let result = handler.next().await;
        assert!(result.is_none());
    }

    /// Pins the sentinel string the handler matches on.
    #[rstest]
    fn test_reconnection_message_constant() {
        assert_eq!(RECONNECTED, "__RECONNECTED__");
    }

    /// Each of several consecutive sentinel frames produces its own `Reconnected` event.
    #[tokio::test]
    async fn test_multiple_reconnection_signals() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal.clone());

        for _ in 0..3 {
            tx.send(Message::Text(RECONNECTED.to_string().into()))
                .unwrap();

            let result = handler.next().await;
            assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
        }
    }

    /// Waiting on a client that was never connected errors out and leaves it inactive.
    #[tokio::test]
    async fn test_wait_until_active_timeout() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            Some(AccountId::from("test-account")),
            None,
        )
        .unwrap();

        let result = client.wait_until_active(0.1).await;

        assert!(result.is_err());
        assert!(!client.is_active());
    }
}