1use std::{
24 fmt::Debug,
25 num::NonZeroU32,
26 sync::{
27 Arc, LazyLock,
28 atomic::{AtomicBool, AtomicU64, Ordering},
29 },
30 time::{Duration, SystemTime},
31};
32
33use ahash::{AHashMap, AHashSet};
34use dashmap::DashMap;
35use futures_util::Stream;
36use nautilus_common::runtime::get_runtime;
37use nautilus_core::{
38 UUID4, consts::NAUTILUS_USER_AGENT, env::get_env_var, time::get_atomic_clock_realtime,
39};
40use nautilus_model::{
41 data::BarType,
42 enums::{OrderSide, OrderStatus, OrderType, PositionSide, TimeInForce},
43 events::{AccountState, OrderCancelRejected, OrderModifyRejected, OrderRejected},
44 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
45 instruments::{Instrument, InstrumentAny},
46 types::{Money, Price, Quantity},
47};
48use nautilus_network::{
49 RECONNECTED,
50 ratelimiter::quota::Quota,
51 websocket::{WebSocketClient, WebSocketConfig, channel_message_handler},
52};
53use reqwest::header::USER_AGENT;
54use serde_json::Value;
55use tokio::sync::mpsc::UnboundedReceiver;
56use tokio_tungstenite::tungstenite::{Error, Message};
57use ustr::Ustr;
58
59use super::{
60 enums::{OKXWsChannel, OKXWsOperation},
61 error::OKXWsError,
62 messages::{
63 ExecutionReport, NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg,
64 OKXSubscription, OKXSubscriptionArg, OKXWebSocketError, OKXWebSocketEvent, OKXWsRequest,
65 WsAmendOrderParams, WsAmendOrderParamsBuilder, WsCancelOrderParams,
66 WsCancelOrderParamsBuilder, WsPostOrderParams, WsPostOrderParamsBuilder,
67 },
68 parse::{parse_book_msg_vec, parse_ws_message_data},
69};
70use crate::{
71 common::{
72 consts::{
73 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
74 OKX_WS_PUBLIC_URL,
75 },
76 credential::Credential,
77 enums::{OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXSide, OKXTradeMode},
78 parse::{bar_spec_as_okx_channel, okx_instrument_type, parse_account_state},
79 },
80 http::models::OKXAccount,
81 websocket::{messages::OKXOrderMsg, parse::parse_order_msg_vec},
82};
83
84type PlaceRequestData = (ClientOrderId, TraderId, StrategyId, InstrumentId);
85type CancelRequestData = (
86 ClientOrderId,
87 TraderId,
88 StrategyId,
89 InstrumentId,
90 Option<VenueOrderId>,
91);
92type AmendRequestData = (
93 ClientOrderId,
94 TraderId,
95 StrategyId,
96 InstrumentId,
97 Option<VenueOrderId>,
98);
99
100pub static OKX_WS_QUOTA: LazyLock<Quota> =
108 LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
109
110pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
115 LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
116
117#[derive(Clone)]
119#[cfg_attr(
120 feature = "python",
121 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
122)]
123pub struct OKXWebSocketClient {
124 url: String,
125 account_id: AccountId,
126 credential: Option<Credential>,
127 heartbeat: Option<u64>,
128 inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
129 auth_state: Arc<tokio::sync::watch::Sender<bool>>,
130 auth_state_rx: tokio::sync::watch::Receiver<bool>,
131 rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
132 signal: Arc<AtomicBool>,
133 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
134 subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
135 subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
136 subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
137 subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>, request_id_counter: Arc<AtomicU64>,
139 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
140 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
141 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
142 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
143}
144
145impl Default for OKXWebSocketClient {
146 fn default() -> Self {
147 Self::new(None, None, None, None, None, None).unwrap()
148 }
149}
150
151impl Debug for OKXWebSocketClient {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct(stringify!(OKXWebSocketClient))
154 .field("url", &self.url)
155 .field(
156 "credential",
157 &self.credential.as_ref().map(|_| "<redacted>"),
158 )
159 .field("heartbeat", &self.heartbeat)
160 .finish_non_exhaustive()
161 }
162}
163
164impl OKXWebSocketClient {
165 pub fn new(
167 url: Option<String>,
168 api_key: Option<String>,
169 api_secret: Option<String>,
170 api_passphrase: Option<String>,
171 account_id: Option<AccountId>,
172 heartbeat: Option<u64>,
173 ) -> anyhow::Result<Self> {
174 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
175 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
176
177 let credential = match (api_key, api_secret, api_passphrase) {
178 (Some(key), Some(secret), Some(passphrase)) => {
179 Some(Credential::new(key, secret, passphrase))
180 }
181 (None, None, None) => None,
182 _ => anyhow::bail!(
183 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
184 ),
185 };
186
187 let signal = Arc::new(AtomicBool::new(false));
188 let subscriptions_inst_type = Arc::new(DashMap::new());
189 let subscriptions_inst_family = Arc::new(DashMap::new());
190 let subscriptions_inst_id = Arc::new(DashMap::new());
191 let subscriptions_bare = Arc::new(DashMap::new());
192 let (auth_tx, auth_rx) = tokio::sync::watch::channel(false);
193
194 Ok(Self {
195 url,
196 account_id,
197 credential,
198 heartbeat,
199 inner: Arc::new(tokio::sync::RwLock::new(None)),
200 auth_state: Arc::new(auth_tx),
201 auth_state_rx: auth_rx,
202 rx: None,
203 signal,
204 task_handle: None,
205 subscriptions_inst_type,
206 subscriptions_inst_family,
207 subscriptions_inst_id,
208 subscriptions_bare,
209 request_id_counter: Arc::new(AtomicU64::new(1)),
210 pending_place_requests: Arc::new(DashMap::new()),
211 pending_cancel_requests: Arc::new(DashMap::new()),
212 pending_amend_requests: Arc::new(DashMap::new()),
213 instruments_cache: Arc::new(AHashMap::new()),
214 })
215 }
216
217 pub fn with_credentials(
219 url: Option<String>,
220 api_key: Option<String>,
221 api_secret: Option<String>,
222 api_passphrase: Option<String>,
223 account_id: Option<AccountId>,
224 heartbeat: Option<u64>,
225 ) -> anyhow::Result<Self> {
226 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
227 let api_key = api_key.unwrap_or(get_env_var("OKX_API_KEY")?);
228 let api_secret = api_secret.unwrap_or(get_env_var("OKX_API_SECRET")?);
229 let api_passphrase = api_passphrase.unwrap_or(get_env_var("OKX_API_PASSPHRASE")?);
230
231 Self::new(
232 Some(url),
233 Some(api_key),
234 Some(api_secret),
235 Some(api_passphrase),
236 account_id,
237 heartbeat,
238 )
239 }
240
241 pub fn from_env() -> anyhow::Result<Self> {
243 let url = get_env_var("OKX_WS_URL")?;
244 let api_key = get_env_var("OKX_API_KEY")?;
245 let api_secret = get_env_var("OKX_API_SECRET")?;
246 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
247
248 Self::new(
249 Some(url),
250 Some(api_key),
251 Some(api_secret),
252 Some(api_passphrase),
253 None,
254 None,
255 )
256 }
257
258 pub fn url(&self) -> &str {
260 self.url.as_str()
261 }
262
263 pub fn api_key(&self) -> Option<&str> {
265 self.credential.clone().map(|c| c.api_key.as_str())
266 }
267
268 pub fn is_active(&self) -> bool {
271 match self.inner.try_read() {
273 Ok(guard) => match &*guard {
274 Some(inner) => inner.is_active(),
275 None => false,
276 },
277 Err(_) => false, }
279 }
280
281 pub fn is_closed(&self) -> bool {
283 match self.inner.try_read() {
285 Ok(guard) => match &*guard {
286 Some(inner) => inner.is_closed(),
287 None => true,
288 },
289 Err(_) => true, }
291 }
292
293 pub fn initialize_instruments_cache(&mut self, instruments: Vec<InstrumentAny>) {
295 let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
296 for inst in instruments {
297 instruments_cache.insert(inst.symbol().inner(), inst.clone());
298 }
299
300 self.instruments_cache = Arc::new(instruments_cache)
301 }
302
303 pub async fn connect(&mut self) -> anyhow::Result<()> {
309 let (message_handler, reader) = channel_message_handler();
310
311 let config = WebSocketConfig {
312 url: self.url.clone(),
313 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
314 heartbeat: self.heartbeat,
315 heartbeat_msg: None,
316 message_handler: Some(message_handler),
317 ping_handler: None,
318 reconnect_timeout_ms: Some(5_000),
319 reconnect_delay_initial_ms: None, reconnect_delay_max_ms: None, reconnect_backoff_factor: None, reconnect_jitter_ms: None, };
324 let keyed_quotas = vec![
326 ("subscription".to_string(), *OKX_WS_QUOTA),
327 ("order".to_string(), *OKX_WS_ORDER_QUOTA),
328 ("cancel".to_string(), *OKX_WS_ORDER_QUOTA),
329 ("amend".to_string(), *OKX_WS_ORDER_QUOTA),
330 ];
331
332 let client = WebSocketClient::connect(
333 config,
334 None, keyed_quotas,
336 Some(*OKX_WS_QUOTA), )
338 .await?;
339
340 {
342 let mut inner_guard = self.inner.write().await;
343 *inner_guard = Some(client);
344 }
345
346 let account_id = self.account_id;
347 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
348
349 self.rx = Some(Arc::new(rx));
350 let signal = self.signal.clone();
351 let pending_place_requests = self.pending_place_requests.clone();
352 let pending_cancel_requests = self.pending_cancel_requests.clone();
353 let pending_amend_requests = self.pending_amend_requests.clone();
354 let auth_state = self.auth_state.clone();
355
356 let instruments_cache = self.instruments_cache.clone();
357 let inner_client = self.inner.clone();
358 let credential_clone = self.credential.clone();
359 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
360 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
361 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
362 let subscriptions_bare = self.subscriptions_bare.clone();
363 let auth_state_clone = auth_state.clone();
364 let stream_handle = get_runtime().spawn(async move {
365 let mut handler = OKXWsMessageHandler::new(
366 account_id,
367 instruments_cache,
368 reader,
369 signal,
370 tx,
371 pending_place_requests,
372 pending_cancel_requests,
373 pending_amend_requests,
374 auth_state,
375 );
376
377 loop {
379 match handler.next().await {
380 Some(NautilusWsMessage::Reconnected) => {
381 tracing::info!("Handling WebSocket reconnection");
382
383 let inner_guard = inner_client.read().await;
385 if let Some(cred) = &credential_clone
386 && let Some(client) = &*inner_guard {
387 let timestamp = SystemTime::now()
388 .duration_since(SystemTime::UNIX_EPOCH)
389 .expect("System time should be after UNIX epoch")
390 .as_secs()
391 .to_string();
392 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
393
394 let auth_message = OKXAuthentication {
395 op: "login",
396 args: vec![OKXAuthenticationArg {
397 api_key: cred.api_key.to_string(),
398 passphrase: cred.api_passphrase.clone(),
399 timestamp,
400 sign: signature,
401 }],
402 };
403
404 if let Err(e) = client.send_text(serde_json::to_string(&auth_message).unwrap(), None).await {
405 tracing::error!("Failed to send re-authentication request: {e}");
406 } else {
408 tracing::info!("Sent re-authentication request, waiting for response before resubscribing");
409
410 let mut auth_rx = auth_state_clone.subscribe();
412 match tokio::time::timeout(Duration::from_secs(5), auth_rx.wait_for(|&auth| auth)).await {
413 Ok(Ok(_)) => {
414 tracing::info!("Authentication successful after reconnect, proceeding with resubscription");
415 }
418 Ok(Err(e)) => {
419 tracing::error!("Auth watch channel error after reconnect: {e}");
420 }
422 Err(_) => {
423 tracing::error!("Timeout waiting for authentication after reconnect");
424 }
426 }
427 }
428 }
429
430 let inner_guard = inner_client.read().await;
433 if let Some(client) = &*inner_guard {
434 let mut inst_type_args = Vec::new();
436 for entry in subscriptions_inst_type.iter() {
437 let (channel, inst_types) = entry.pair();
438 for inst_type in inst_types.iter() {
439 inst_type_args.push(OKXSubscriptionArg {
440 channel: channel.clone(),
441 inst_type: Some(*inst_type),
442 inst_family: None,
443 inst_id: None,
444 });
445 }
446 }
447 if !inst_type_args.is_empty() {
448 let sub_request = OKXSubscription {
449 op: OKXWsOperation::Subscribe,
450 args: inst_type_args,
451 };
452 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
453 tracing::error!("Failed to re-subscribe inst_type channels: {e}");
454 }
455 }
456
457 let mut inst_family_args = Vec::new();
459 for entry in subscriptions_inst_family.iter() {
460 let (channel, inst_families) = entry.pair();
461 for inst_family in inst_families.iter() {
462 inst_family_args.push(OKXSubscriptionArg {
463 channel: channel.clone(),
464 inst_type: None,
465 inst_family: Some(*inst_family),
466 inst_id: None,
467 });
468 }
469 }
470 if !inst_family_args.is_empty() {
471 let sub_request = OKXSubscription {
472 op: OKXWsOperation::Subscribe,
473 args: inst_family_args,
474 };
475 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
476 tracing::error!("Failed to re-subscribe inst_family channels: {e}");
477 }
478 }
479
480 let mut inst_id_args = Vec::new();
482 for entry in subscriptions_inst_id.iter() {
483 let (channel, inst_ids) = entry.pair();
484 for inst_id in inst_ids.iter() {
485 inst_id_args.push(OKXSubscriptionArg {
486 channel: channel.clone(),
487 inst_type: None,
488 inst_family: None,
489 inst_id: Some(*inst_id),
490 });
491 }
492 }
493 if !inst_id_args.is_empty() {
494 let sub_request = OKXSubscription {
495 op: OKXWsOperation::Subscribe,
496 args: inst_id_args,
497 };
498 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
499 tracing::error!("Failed to re-subscribe inst_id channels: {e}");
500 }
501 }
502
503 let mut bare_args = Vec::new();
505 for entry in subscriptions_bare.iter() {
506 let channel = entry.key();
507 bare_args.push(OKXSubscriptionArg {
508 channel: channel.clone(),
509 inst_type: None,
510 inst_family: None,
511 inst_id: None,
512 });
513 }
514 if !bare_args.is_empty() {
515 let sub_request = OKXSubscription {
516 op: OKXWsOperation::Subscribe,
517 args: bare_args,
518 };
519 if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
520 tracing::error!("Failed to re-subscribe bare channels: {e}");
521 }
522 }
523
524 tracing::info!("Completed re-subscription after reconnect");
525 }
526 }
527 Some(msg) => {
528 if handler.tx.send(msg).is_err() {
530 tracing::error!("Failed to send message through channel: receiver dropped");
531 break;
532 }
533
534 }
535 None => {
536 if handler.is_stopped() {
538 tracing::debug!("Stop signal received, ending message processing");
539 break;
540 }
541 tracing::warn!("WebSocket stream ended unexpectedly");
543 break;
544 }
545 }
546 }
547 });
548
549 self.task_handle = Some(Arc::new(stream_handle));
550
551 if self.credential.is_some() {
552 if self.auth_state.send(false).is_err() {
553 tracing::error!("Failed to reset auth state, receiver dropped.");
554 };
555 self.authenticate().await?;
556 }
557
558 Ok(())
559 }
560
561 async fn authenticate(&self) -> Result<(), Error> {
563 let credential = match &self.credential {
564 Some(credential) => credential,
565 None => {
566 panic!("API credentials not available to authenticate");
567 }
568 };
569
570 let timestamp = SystemTime::now()
571 .duration_since(SystemTime::UNIX_EPOCH)
572 .expect("System time should be after UNIX epoch")
573 .as_secs()
574 .to_string();
575 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
576
577 let auth_message = OKXAuthentication {
578 op: "login",
579 args: vec![OKXAuthenticationArg {
580 api_key: credential.api_key.to_string(),
581 passphrase: credential.api_passphrase.clone(),
582 timestamp,
583 sign: signature,
584 }],
585 };
586
587 {
588 let inner_guard = self.inner.read().await;
589 if let Some(inner) = &*inner_guard {
590 if let Err(e) = inner
591 .send_text(serde_json::to_string(&auth_message).unwrap(), None)
592 .await
593 {
594 tracing::error!("Error sending auth message: {e:?}");
595 return Err(Error::Io(std::io::Error::other(e.to_string())));
596 }
597 } else {
598 log::error!("Cannot authenticate: not connected");
599 return Err(Error::ConnectionClosed);
600 }
601 }
602
603 let mut rx = self.auth_state_rx.clone();
605 match tokio::time::timeout(Duration::from_secs(10), rx.wait_for(|&auth| auth)).await {
606 Ok(Ok(_)) => {
607 tracing::info!("Authentication confirmed by client");
608 Ok(())
609 }
610 Ok(Err(e)) => {
611 tracing::error!("Authentication watch channel closed unexpectedly: {e}");
612 Err(Error::Io(std::io::Error::other(
613 "Authentication watch channel closed",
614 )))
615 }
616 Err(_) => {
617 tracing::error!("Timeout waiting for authentication response");
618 Err(Error::Io(std::io::Error::other(
619 "Timeout waiting for authentication",
620 )))
621 }
622 }
623 }
624
625 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
633 let rx = self
634 .rx
635 .take()
636 .expect("Data stream receiver already taken or not connected");
637 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
638 async_stream::stream! {
639 while let Some(data) = rx.recv().await {
640 yield data;
641 }
642 }
643 }
644
645 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
651 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
652
653 tokio::time::timeout(timeout, async {
654 while !self.is_active() {
655 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
656 }
657 })
658 .await
659 .map_err(|_| {
660 OKXWsError::ClientError(format!(
661 "WebSocket connection timeout after {timeout_secs} seconds"
662 ))
663 })?;
664
665 Ok(())
666 }
667
668 pub async fn close(&mut self) -> Result<(), Error> {
670 log::debug!("Starting close process");
671
672 self.signal.store(true, Ordering::Relaxed);
673
674 {
675 let inner_guard = self.inner.read().await;
676 if let Some(inner) = &*inner_guard {
677 log::debug!("Disconnecting websocket");
678
679 match tokio::time::timeout(Duration::from_secs(3), inner.disconnect()).await {
680 Ok(()) => log::debug!("Websocket disconnected successfully"),
681 Err(_) => {
682 log::warn!(
683 "Timeout waiting for websocket disconnect, continuing with cleanup"
684 )
685 }
686 }
687 } else {
688 log::debug!("No active connection to disconnect");
689 }
690 }
691
692 if let Some(stream_handle) = self.task_handle.take() {
694 match Arc::try_unwrap(stream_handle) {
695 Ok(handle) => {
696 log::debug!("Waiting for stream handle to complete");
697 match tokio::time::timeout(Duration::from_secs(2), handle).await {
698 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
699 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
700 Err(_) => {
701 log::warn!(
702 "Timeout waiting for stream handle, task may still be running"
703 );
704 }
706 }
707 }
708 Err(arc_handle) => {
709 log::debug!(
710 "Cannot take ownership of stream handle - other references exist, aborting task"
711 );
712 arc_handle.abort();
713 }
714 }
715 } else {
716 log::debug!("No stream handle to await");
717 }
718
719 log::debug!("Close process completed");
720
721 Ok(())
722 }
723
724 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
726 let symbol = instrument_id.symbol.inner();
727 let mut channels = Vec::new();
728
729 for entry in self.subscriptions_inst_id.iter() {
730 let (channel, instruments) = entry.pair();
731 if instruments.contains(&symbol) {
732 channels.push(channel.clone());
733 }
734 }
735
736 channels
737 }
738
739 fn generate_unique_request_id(&self) -> String {
740 self.request_id_counter
741 .fetch_add(1, Ordering::SeqCst)
742 .to_string()
743 }
744
745 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
746 for arg in &args {
747 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
749 self.subscriptions_bare.insert(arg.channel.clone(), true);
751 } else {
752 if let Some(inst_type) = &arg.inst_type {
754 self.subscriptions_inst_type
755 .entry(arg.channel.clone())
756 .or_default()
757 .insert(*inst_type);
758 }
759
760 if let Some(inst_family) = &arg.inst_family {
762 self.subscriptions_inst_family
763 .entry(arg.channel.clone())
764 .or_default()
765 .insert(*inst_family);
766 }
767
768 if let Some(inst_id) = &arg.inst_id {
770 self.subscriptions_inst_id
771 .entry(arg.channel.clone())
772 .or_default()
773 .insert(*inst_id);
774 }
775 }
776 }
777
778 let message = OKXSubscription {
779 op: OKXWsOperation::Subscribe,
780 args,
781 };
782
783 let json_txt =
784 serde_json::to_string(&message).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
785
786 {
787 let inner_guard = self.inner.read().await;
788 if let Some(inner) = &*inner_guard {
789 if let Err(e) = inner
790 .send_text(json_txt, Some(vec!["subscription".to_string()]))
791 .await
792 {
793 tracing::error!("Error sending message: {e:?}")
794 }
795 } else {
796 return Err(OKXWsError::ClientError(
797 "Cannot send message: not connected".to_string(),
798 ));
799 }
800 }
801
802 Ok(())
803 }
804
805 #[allow(clippy::collapsible_if)]
806 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
807 for arg in &args {
808 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
810 self.subscriptions_bare.remove(&arg.channel);
812 } else {
813 if let Some(inst_type) = &arg.inst_type {
815 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
816 entry.remove(inst_type);
817 if entry.is_empty() {
818 drop(entry);
819 self.subscriptions_inst_type.remove(&arg.channel);
820 }
821 }
822 }
823
824 if let Some(inst_family) = &arg.inst_family {
826 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
827 entry.remove(inst_family);
828 if entry.is_empty() {
829 drop(entry);
830 self.subscriptions_inst_family.remove(&arg.channel);
831 }
832 }
833 }
834
835 if let Some(inst_id) = &arg.inst_id {
837 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
838 entry.remove(inst_id);
839 if entry.is_empty() {
840 drop(entry);
841 self.subscriptions_inst_id.remove(&arg.channel);
842 }
843 }
844 }
845 }
846 }
847
848 let message = OKXSubscription {
849 op: OKXWsOperation::Unsubscribe,
850 args,
851 };
852
853 let json_txt = serde_json::to_string(&message).expect("Must be valid JSON");
854
855 {
856 let inner_guard = self.inner.read().await;
857 if let Some(inner) = &*inner_guard {
858 if let Err(e) = inner
859 .send_text(json_txt, Some(vec!["subscription".to_string()]))
860 .await
861 {
862 tracing::error!("Error sending message: {e:?}")
863 }
864 } else {
865 log::error!("Cannot send message: not connected");
866 }
867 }
868
869 Ok(())
870 }
871
872 #[allow(dead_code)]
873 async fn resubscribe_all(&self) {
874 let mut subs_bare = Vec::new();
876 for entry in self.subscriptions_bare.iter() {
877 let channel = entry.key();
878 subs_bare.push(channel.clone());
879 }
880
881 let mut subs_inst_type = Vec::new();
882 for entry in self.subscriptions_inst_type.iter() {
883 let (channel, inst_types) = entry.pair();
884 if !inst_types.is_empty() {
885 subs_inst_type.push((channel.clone(), inst_types.clone()));
886 }
887 }
888
889 let mut subs_inst_family = Vec::new();
890 for entry in self.subscriptions_inst_family.iter() {
891 let (channel, inst_families) = entry.pair();
892 if !inst_families.is_empty() {
893 subs_inst_family.push((channel.clone(), inst_families.clone()));
894 }
895 }
896
897 let mut subs_inst_id = Vec::new();
898 for entry in self.subscriptions_inst_id.iter() {
899 let (channel, inst_ids) = entry.pair();
900 if !inst_ids.is_empty() {
901 subs_inst_id.push((channel.clone(), inst_ids.clone()));
902 }
903 }
904
905 for (channel, inst_types) in subs_inst_type {
907 if inst_types.is_empty() {
908 continue;
909 }
910
911 tracing::debug!("Resubscribing: channel={channel}, instrument_types={inst_types:?}");
912
913 for inst_type in inst_types {
914 let arg = OKXSubscriptionArg {
915 channel: channel.clone(),
916 inst_type: Some(inst_type),
917 inst_family: None,
918 inst_id: None,
919 };
920
921 if let Err(e) = self.subscribe(vec![arg]).await {
922 tracing::error!(
923 "Failed to resubscribe to channel {channel} with instrument type: {e}"
924 );
925 }
926 }
927 }
928
929 for (channel, inst_families) in subs_inst_family {
931 if inst_families.is_empty() {
932 continue;
933 }
934
935 tracing::debug!(
936 "Resubscribing: channel={channel}, instrument_families={inst_families:?}"
937 );
938
939 for inst_family in inst_families {
940 let arg = OKXSubscriptionArg {
941 channel: channel.clone(),
942 inst_type: None,
943 inst_family: Some(inst_family),
944 inst_id: None,
945 };
946
947 if let Err(e) = self.subscribe(vec![arg]).await {
948 tracing::error!(
949 "Failed to resubscribe to channel {channel} with instrument family: {e}"
950 );
951 }
952 }
953 }
954
955 for (channel, inst_ids) in subs_inst_id {
957 if inst_ids.is_empty() {
958 continue;
959 }
960
961 tracing::debug!("Resubscribing: channel={channel}, instrument_ids={inst_ids:?}");
962
963 for inst_id in inst_ids {
964 let arg = OKXSubscriptionArg {
965 channel: channel.clone(),
966 inst_type: None,
967 inst_family: None,
968 inst_id: Some(inst_id),
969 };
970
971 if let Err(e) = self.subscribe(vec![arg]).await {
972 tracing::error!(
973 "Failed to resubscribe to channel {channel} with instrument ID: {e}"
974 );
975 }
976 }
977 }
978
979 for channel in subs_bare {
981 tracing::debug!("Resubscribing to bare channel: {channel}");
982
983 let arg = OKXSubscriptionArg {
984 channel,
985 inst_type: None,
986 inst_family: None,
987 inst_id: None,
988 };
989
990 if let Err(e) = self.subscribe(vec![arg]).await {
991 tracing::error!("Failed to resubscribe to bare channel: {e}");
992 }
993 }
994 }
995
996 pub async fn subscribe_instruments(
1004 &self,
1005 instrument_type: OKXInstrumentType,
1006 ) -> Result<(), OKXWsError> {
1007 let arg = OKXSubscriptionArg {
1008 channel: OKXWsChannel::Instruments,
1009 inst_type: Some(instrument_type),
1010 inst_family: None,
1011 inst_id: None,
1012 };
1013 self.subscribe(vec![arg]).await
1014 }
1015
1016 pub async fn subscribe_instrument(
1024 &self,
1025 instrument_id: InstrumentId,
1026 ) -> Result<(), OKXWsError> {
1027 let arg = OKXSubscriptionArg {
1028 channel: OKXWsChannel::Instruments,
1029 inst_type: None,
1030 inst_family: None,
1031 inst_id: Some(instrument_id.symbol.inner()),
1032 };
1033 self.subscribe(vec![arg]).await
1034 }
1035
1036 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1042 let arg = OKXSubscriptionArg {
1043 channel: OKXWsChannel::Books,
1044 inst_type: None,
1045 inst_family: None,
1046 inst_id: Some(instrument_id.symbol.inner()),
1047 };
1048 self.subscribe(vec![arg]).await
1049 }
1050
1051 pub async fn subscribe_book_depth5(
1059 &self,
1060 instrument_id: InstrumentId,
1061 ) -> Result<(), OKXWsError> {
1062 let arg = OKXSubscriptionArg {
1063 channel: OKXWsChannel::Books5,
1064 inst_type: None,
1065 inst_family: None,
1066 inst_id: Some(instrument_id.symbol.inner()),
1067 };
1068 self.subscribe(vec![arg]).await
1069 }
1070
1071 pub async fn subscribe_books50_l2_tbt(
1079 &self,
1080 instrument_id: InstrumentId,
1081 ) -> Result<(), OKXWsError> {
1082 let arg = OKXSubscriptionArg {
1083 channel: OKXWsChannel::Books50Tbt,
1084 inst_type: None,
1085 inst_family: None,
1086 inst_id: Some(instrument_id.symbol.inner()),
1087 };
1088 self.subscribe(vec![arg]).await
1089 }
1090
1091 pub async fn subscribe_book_l2_tbt(
1099 &self,
1100 instrument_id: InstrumentId,
1101 ) -> Result<(), OKXWsError> {
1102 let arg = OKXSubscriptionArg {
1103 channel: OKXWsChannel::BooksTbt,
1104 inst_type: None,
1105 inst_family: None,
1106 inst_id: Some(instrument_id.symbol.inner()),
1107 };
1108 self.subscribe(vec![arg]).await
1109 }
1110
1111 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1119 let arg = OKXSubscriptionArg {
1121 channel: OKXWsChannel::BboTbt,
1122 inst_type: None,
1123 inst_family: None,
1124 inst_id: Some(instrument_id.symbol.inner()),
1125 };
1126 self.subscribe(vec![arg]).await
1127 }
1128
1129 pub async fn subscribe_trades(
1135 &self,
1136 instrument_id: InstrumentId,
1137 _aggregated: bool, ) -> Result<(), OKXWsError> {
1139 let channel = OKXWsChannel::Trades;
1144
1145 let arg = OKXSubscriptionArg {
1146 channel,
1147 inst_type: None,
1148 inst_family: None,
1149 inst_id: Some(instrument_id.symbol.inner()),
1150 };
1151 self.subscribe(vec![arg]).await
1152 }
1153
1154 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1162 let arg = OKXSubscriptionArg {
1163 channel: OKXWsChannel::Tickers,
1164 inst_type: None,
1165 inst_family: None,
1166 inst_id: Some(instrument_id.symbol.inner()),
1167 };
1168 self.subscribe(vec![arg]).await
1169 }
1170
1171 pub async fn subscribe_mark_prices(
1179 &self,
1180 instrument_id: InstrumentId,
1181 ) -> Result<(), OKXWsError> {
1182 let arg = OKXSubscriptionArg {
1183 channel: OKXWsChannel::MarkPrice,
1184 inst_type: None,
1185 inst_family: None,
1186 inst_id: Some(instrument_id.symbol.inner()),
1187 };
1188 self.subscribe(vec![arg]).await
1189 }
1190
1191 pub async fn subscribe_index_prices(
1199 &self,
1200 instrument_id: InstrumentId,
1201 ) -> Result<(), OKXWsError> {
1202 let arg = OKXSubscriptionArg {
1203 channel: OKXWsChannel::IndexTickers,
1204 inst_type: None,
1205 inst_family: None,
1206 inst_id: Some(instrument_id.symbol.inner()),
1207 };
1208 self.subscribe(vec![arg]).await
1209 }
1210
1211 pub async fn subscribe_funding_rates(
1219 &self,
1220 instrument_id: InstrumentId,
1221 ) -> Result<(), OKXWsError> {
1222 let arg = OKXSubscriptionArg {
1223 channel: OKXWsChannel::FundingRate,
1224 inst_type: None,
1225 inst_family: None,
1226 inst_id: Some(instrument_id.symbol.inner()),
1227 };
1228 self.subscribe(vec![arg]).await
1229 }
1230
1231 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1239 let channel = bar_spec_as_okx_channel(bar_type.spec())
1241 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1242
1243 let arg = OKXSubscriptionArg {
1244 channel,
1245 inst_type: None,
1246 inst_family: None,
1247 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1248 };
1249 self.subscribe(vec![arg]).await
1250 }
1251
1252 pub async fn unsubscribe_instruments(
1254 &self,
1255 instrument_type: OKXInstrumentType,
1256 ) -> Result<(), OKXWsError> {
1257 let arg = OKXSubscriptionArg {
1258 channel: OKXWsChannel::Instruments,
1259 inst_type: Some(instrument_type),
1260 inst_family: None,
1261 inst_id: None,
1262 };
1263 self.unsubscribe(vec![arg]).await
1264 }
1265
1266 pub async fn unsubscribe_instrument(
1268 &self,
1269 instrument_id: InstrumentId,
1270 ) -> Result<(), OKXWsError> {
1271 let arg = OKXSubscriptionArg {
1272 channel: OKXWsChannel::Instruments,
1273 inst_type: None,
1274 inst_family: None,
1275 inst_id: Some(instrument_id.symbol.inner()),
1276 };
1277 self.unsubscribe(vec![arg]).await
1278 }
1279
1280 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1282 let arg = OKXSubscriptionArg {
1283 channel: OKXWsChannel::Books,
1284 inst_type: None,
1285 inst_family: None,
1286 inst_id: Some(instrument_id.symbol.inner()),
1287 };
1288 self.unsubscribe(vec![arg]).await
1289 }
1290
1291 pub async fn unsubscribe_book_depth5(
1293 &self,
1294 instrument_id: InstrumentId,
1295 ) -> Result<(), OKXWsError> {
1296 let arg = OKXSubscriptionArg {
1297 channel: OKXWsChannel::Books5,
1298 inst_type: None,
1299 inst_family: None,
1300 inst_id: Some(instrument_id.symbol.inner()),
1301 };
1302 self.unsubscribe(vec![arg]).await
1303 }
1304
1305 pub async fn unsubscribe_book50_l2_tbt(
1307 &self,
1308 instrument_id: InstrumentId,
1309 ) -> Result<(), OKXWsError> {
1310 let arg = OKXSubscriptionArg {
1311 channel: OKXWsChannel::Books50Tbt,
1312 inst_type: None,
1313 inst_family: None,
1314 inst_id: Some(instrument_id.symbol.inner()),
1315 };
1316 self.unsubscribe(vec![arg]).await
1317 }
1318
1319 pub async fn unsubscribe_book_l2_tbt(
1321 &self,
1322 instrument_id: InstrumentId,
1323 ) -> Result<(), OKXWsError> {
1324 let arg = OKXSubscriptionArg {
1325 channel: OKXWsChannel::BooksTbt,
1326 inst_type: None,
1327 inst_family: None,
1328 inst_id: Some(instrument_id.symbol.inner()),
1329 };
1330 self.unsubscribe(vec![arg]).await
1331 }
1332
1333 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1335 let arg = OKXSubscriptionArg {
1336 channel: OKXWsChannel::BboTbt,
1337 inst_type: None,
1338 inst_family: None,
1339 inst_id: Some(instrument_id.symbol.inner()),
1340 };
1341 self.unsubscribe(vec![arg]).await
1342 }
1343
1344 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1346 let arg = OKXSubscriptionArg {
1347 channel: OKXWsChannel::Tickers,
1348 inst_type: None,
1349 inst_family: None,
1350 inst_id: Some(instrument_id.symbol.inner()),
1351 };
1352 self.unsubscribe(vec![arg]).await
1353 }
1354
1355 pub async fn unsubscribe_mark_prices(
1357 &self,
1358 instrument_id: InstrumentId,
1359 ) -> Result<(), OKXWsError> {
1360 let arg = OKXSubscriptionArg {
1361 channel: OKXWsChannel::MarkPrice,
1362 inst_type: None,
1363 inst_family: None,
1364 inst_id: Some(instrument_id.symbol.inner()),
1365 };
1366 self.unsubscribe(vec![arg]).await
1367 }
1368
1369 pub async fn unsubscribe_index_prices(
1371 &self,
1372 instrument_id: InstrumentId,
1373 ) -> Result<(), OKXWsError> {
1374 let arg = OKXSubscriptionArg {
1375 channel: OKXWsChannel::IndexTickers,
1376 inst_type: None,
1377 inst_family: None,
1378 inst_id: Some(instrument_id.symbol.inner()),
1379 };
1380 self.unsubscribe(vec![arg]).await
1381 }
1382
1383 pub async fn unsubscribe_funding_rates(
1385 &self,
1386 instrument_id: InstrumentId,
1387 ) -> Result<(), OKXWsError> {
1388 let arg = OKXSubscriptionArg {
1389 channel: OKXWsChannel::FundingRate,
1390 inst_type: None,
1391 inst_family: None,
1392 inst_id: Some(instrument_id.symbol.inner()),
1393 };
1394 self.unsubscribe(vec![arg]).await
1395 }
1396
1397 pub async fn unsubscribe_trades(
1399 &self,
1400 instrument_id: InstrumentId,
1401 _aggregated: bool,
1402 ) -> Result<(), OKXWsError> {
1403 let channel = OKXWsChannel::Trades;
1405
1406 let arg = OKXSubscriptionArg {
1407 channel,
1408 inst_type: None,
1409 inst_family: None,
1410 inst_id: Some(instrument_id.symbol.inner()),
1411 };
1412 self.unsubscribe(vec![arg]).await
1413 }
1414
1415 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1417 let channel = bar_spec_as_okx_channel(bar_type.spec())
1419 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1420
1421 let arg = OKXSubscriptionArg {
1422 channel,
1423 inst_type: None,
1424 inst_family: None,
1425 inst_id: Some(bar_type.instrument_id().symbol.inner()),
1426 };
1427 self.unsubscribe(vec![arg]).await
1428 }
1429
1430 pub async fn subscribe_orders(
1432 &self,
1433 instrument_type: OKXInstrumentType,
1434 ) -> Result<(), OKXWsError> {
1435 let arg = OKXSubscriptionArg {
1436 channel: OKXWsChannel::Orders,
1437 inst_type: Some(instrument_type),
1438 inst_family: None,
1439 inst_id: None,
1440 };
1441 self.subscribe(vec![arg]).await
1442 }
1443
1444 pub async fn unsubscribe_orders(
1446 &self,
1447 instrument_type: OKXInstrumentType,
1448 ) -> Result<(), OKXWsError> {
1449 let arg = OKXSubscriptionArg {
1450 channel: OKXWsChannel::Orders,
1451 inst_type: Some(instrument_type),
1452 inst_family: None,
1453 inst_id: None,
1454 };
1455 self.unsubscribe(vec![arg]).await
1456 }
1457
1458 pub async fn subscribe_fills(
1460 &self,
1461 instrument_type: OKXInstrumentType,
1462 ) -> Result<(), OKXWsError> {
1463 let arg = OKXSubscriptionArg {
1464 channel: OKXWsChannel::Fills,
1465 inst_type: Some(instrument_type),
1466 inst_family: None,
1467 inst_id: None,
1468 };
1469 self.subscribe(vec![arg]).await
1470 }
1471
1472 pub async fn unsubscribe_fills(
1474 &self,
1475 instrument_type: OKXInstrumentType,
1476 ) -> Result<(), OKXWsError> {
1477 let arg = OKXSubscriptionArg {
1478 channel: OKXWsChannel::Fills,
1479 inst_type: Some(instrument_type),
1480 inst_family: None,
1481 inst_id: None,
1482 };
1483 self.unsubscribe(vec![arg]).await
1484 }
1485
1486 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1488 let arg = OKXSubscriptionArg {
1489 channel: OKXWsChannel::Account,
1490 inst_type: None,
1491 inst_family: None,
1492 inst_id: None,
1493 };
1494 self.subscribe(vec![arg]).await
1495 }
1496
1497 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1499 let arg = OKXSubscriptionArg {
1500 channel: OKXWsChannel::Account,
1501 inst_type: None,
1502 inst_family: None,
1503 inst_id: None,
1504 };
1505 self.unsubscribe(vec![arg]).await
1506 }
1507
1508 async fn ws_cancel_order(
1514 &self,
1515 params: WsCancelOrderParams,
1516 request_id: Option<String>,
1517 ) -> Result<(), OKXWsError> {
1518 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1519
1520 let req = OKXWsRequest {
1521 id: Some(request_id),
1522 op: OKXWsOperation::CancelOrder,
1523 args: vec![params],
1524 exp_time: None,
1525 };
1526
1527 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1528
1529 {
1530 let inner_guard = self.inner.read().await;
1531 if let Some(inner) = &*inner_guard {
1532 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1533 tracing::error!("Error sending message: {e:?}");
1534 }
1535 Ok(())
1536 } else {
1537 Err(OKXWsError::ClientError("Not connected".to_string()))
1538 }
1539 }
1540 }
1541
1542 #[allow(dead_code)] async fn ws_mass_cancel(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1549 let request_id = self
1551 .request_id_counter
1552 .fetch_add(1, Ordering::SeqCst)
1553 .to_string();
1554
1555 let req = OKXWsRequest {
1556 id: Some(request_id),
1557 op: OKXWsOperation::MassCancel,
1558 args,
1559 exp_time: None,
1560 };
1561
1562 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1563
1564 {
1565 let inner_guard = self.inner.read().await;
1566 if let Some(inner) = &*inner_guard {
1567 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1568 tracing::error!("Error sending message: {e:?}");
1569 }
1570 Ok(())
1571 } else {
1572 Err(OKXWsError::ClientError("Not connected".to_string()))
1573 }
1574 }
1575 }
1576
1577 async fn ws_amend_order(
1583 &self,
1584 params: WsAmendOrderParams,
1585 request_id: Option<String>,
1586 ) -> Result<(), OKXWsError> {
1587 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1588
1589 let req = OKXWsRequest {
1590 id: Some(request_id),
1591 op: OKXWsOperation::AmendOrder,
1592 args: vec![params],
1593 exp_time: None,
1594 };
1595
1596 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1597
1598 {
1599 let inner_guard = self.inner.read().await;
1600 if let Some(inner) = &*inner_guard {
1601 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1602 tracing::error!("Error sending message: {e:?}");
1603 }
1604 Ok(())
1605 } else {
1606 Err(OKXWsError::ClientError("Not connected".to_string()))
1607 }
1608 }
1609 }
1610
1611 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1617 let request_id = self.generate_unique_request_id();
1618
1619 let req = OKXWsRequest {
1620 id: Some(request_id),
1621 op: OKXWsOperation::BatchOrders,
1622 args,
1623 exp_time: None,
1624 };
1625
1626 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1627
1628 {
1629 let inner_guard = self.inner.read().await;
1630 if let Some(inner) = &*inner_guard {
1631 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
1632 tracing::error!("Error sending message: {e:?}");
1633 }
1634 Ok(())
1635 } else {
1636 Err(OKXWsError::ClientError("Not connected".to_string()))
1637 }
1638 }
1639 }
1640
1641 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1647 let request_id = self.generate_unique_request_id();
1648
1649 let req = OKXWsRequest {
1650 id: Some(request_id),
1651 op: OKXWsOperation::BatchCancelOrders,
1652 args,
1653 exp_time: None,
1654 };
1655
1656 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1657
1658 {
1659 let inner_guard = self.inner.read().await;
1660 if let Some(inner) = &*inner_guard {
1661 if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1662 tracing::error!("Error sending message: {e:?}");
1663 }
1664 Ok(())
1665 } else {
1666 Err(OKXWsError::ClientError("Not connected".to_string()))
1667 }
1668 }
1669 }
1670
1671 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1677 let request_id = self.generate_unique_request_id();
1678
1679 let req = OKXWsRequest {
1680 id: Some(request_id),
1681 op: OKXWsOperation::BatchAmendOrders,
1682 args,
1683 exp_time: None,
1684 };
1685
1686 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1687
1688 {
1689 let inner_guard = self.inner.read().await;
1690 if let Some(inner) = &*inner_guard {
1691 if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1692 tracing::error!("Error sending message: {e:?}");
1693 }
1694 Ok(())
1695 } else {
1696 Err(OKXWsError::ClientError("Not connected".to_string()))
1697 }
1698 }
1699 }
1700
1701 #[allow(clippy::too_many_arguments)]
1707 pub async fn submit_order(
1708 &self,
1709 trader_id: TraderId,
1710 strategy_id: StrategyId,
1711 instrument_id: InstrumentId,
1712 td_mode: OKXTradeMode,
1713 client_order_id: ClientOrderId,
1714 order_side: OrderSide,
1715 order_type: OrderType,
1716 quantity: Quantity,
1717 time_in_force: Option<TimeInForce>,
1718 price: Option<Price>,
1719 trigger_price: Option<Price>,
1720 post_only: Option<bool>,
1721 reduce_only: Option<bool>,
1722 quote_quantity: Option<bool>,
1723 position_side: Option<PositionSide>,
1724 ) -> Result<(), OKXWsError> {
1725 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1726 return Err(OKXWsError::ClientError(format!(
1727 "Unsupported order type: {order_type:?}",
1728 )));
1729 }
1730
1731 if let Some(tif) = time_in_force
1732 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1733 {
1734 return Err(OKXWsError::ClientError(format!(
1735 "Unsupported time in force: {tif:?}",
1736 )));
1737 }
1738
1739 let mut builder = WsPostOrderParamsBuilder::default();
1740
1741 builder.inst_id(instrument_id.symbol.as_str());
1742 builder.td_mode(td_mode);
1743 builder.cl_ord_id(client_order_id.as_str());
1744
1745 let instrument = self
1746 .instruments_cache
1747 .get(&instrument_id.symbol.inner())
1748 .ok_or_else(|| {
1749 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1750 })?;
1751
1752 let instrument_type =
1753 okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1754 let quote_currency = instrument.quote_currency();
1755
1756 match instrument_type {
1757 OKXInstrumentType::Spot => {
1758 }
1760 OKXInstrumentType::Margin => {
1761 builder.ccy(quote_currency.to_string());
1763
1764 if let Some(ro) = reduce_only
1766 && ro
1767 {
1768 builder.reduce_only(ro);
1769 }
1770 }
1771 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1772 builder.ccy(quote_currency.to_string());
1774 }
1775 _ => {
1776 builder.ccy(quote_currency.to_string());
1778 builder.tgt_ccy(quote_currency.to_string());
1779
1780 if let Some(ro) = reduce_only
1782 && ro
1783 {
1784 builder.reduce_only(ro);
1785 }
1786 }
1787 };
1788
1789 if let Some(is_quote_quantity) = quote_quantity
1790 && is_quote_quantity
1791 {
1792 builder.tgt_ccy(quote_currency.to_string());
1793 }
1794 builder.side(OKXSide::from(order_side));
1797
1798 if let Some(pos_side) = position_side {
1799 builder.pos_side(pos_side);
1800 };
1801
1802 let okx_ord_type = if post_only.unwrap_or(false) {
1804 OKXOrderType::PostOnly
1805 } else {
1806 OKXOrderType::from(order_type)
1807 };
1808
1809 log::debug!(
1810 "Order type mapping: order_type={:?}, time_in_force={:?}, post_only={:?} -> okx_ord_type={:?}",
1811 order_type,
1812 time_in_force,
1813 post_only,
1814 okx_ord_type
1815 );
1816
1817 builder.ord_type(okx_ord_type);
1818 builder.sz(quantity.to_string());
1819
1820 if let Some(tp) = trigger_price {
1821 builder.px(tp.to_string());
1822 } else if let Some(p) = price {
1823 builder.px(p.to_string());
1824 }
1825
1826 builder.tag(OKX_NAUTILUS_BROKER_ID);
1827
1828 let params = builder
1829 .build()
1830 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
1831
1832 log::debug!("Sending order params to OKX: {:?}", params);
1834
1835 let request_id = self.generate_unique_request_id();
1836
1837 self.pending_place_requests.insert(
1838 request_id.clone(),
1839 (client_order_id, trader_id, strategy_id, instrument_id),
1840 );
1841
1842 self.ws_place_order(params, Some(request_id)).await
1843 }
1844
1845 #[allow(clippy::too_many_arguments)]
1851 pub async fn cancel_order(
1852 &self,
1853 trader_id: TraderId,
1854 strategy_id: StrategyId,
1855 instrument_id: InstrumentId,
1856 client_order_id: Option<ClientOrderId>,
1857 venue_order_id: Option<VenueOrderId>,
1858 ) -> Result<(), OKXWsError> {
1859 let mut builder = WsCancelOrderParamsBuilder::default();
1860 builder.inst_id(instrument_id.symbol.as_str());
1863
1864 if let Some(venue_order_id) = venue_order_id {
1865 builder.ord_id(venue_order_id.as_str());
1866 }
1867
1868 let params = builder
1869 .build()
1870 .map_err(|e| OKXWsError::ClientError(format!("Build cancel params error: {e}")))?;
1871
1872 let request_id = self.generate_unique_request_id();
1873
1874 if let Some(client_order_id) = client_order_id {
1877 builder.cl_ord_id(client_order_id.as_str());
1878
1879 self.pending_cancel_requests.insert(
1880 request_id.clone(),
1881 (
1882 client_order_id,
1883 trader_id,
1884 strategy_id,
1885 instrument_id,
1886 venue_order_id,
1887 ),
1888 );
1889 }
1890
1891 self.ws_cancel_order(params, Some(request_id)).await
1892 }
1893
1894 async fn ws_place_order(
1900 &self,
1901 params: WsPostOrderParams,
1902 request_id: Option<String>,
1903 ) -> Result<(), OKXWsError> {
1904 let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1905
1906 let req = OKXWsRequest {
1907 id: Some(request_id),
1908 op: OKXWsOperation::Order,
1909 exp_time: None,
1910 args: vec![params],
1911 };
1912
1913 let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1914
1915 {
1916 let inner_guard = self.inner.read().await;
1917 if let Some(inner) = &*inner_guard {
1918 if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
1919 tracing::error!("Error sending message: {e:?}");
1920 }
1921 Ok(())
1922 } else {
1923 Err(OKXWsError::ClientError("Not connected".to_string()))
1924 }
1925 }
1926 }
1927
1928 #[allow(clippy::too_many_arguments)]
1934 pub async fn modify_order(
1935 &self,
1936 trader_id: TraderId,
1937 strategy_id: StrategyId,
1938 instrument_id: InstrumentId,
1939 client_order_id: Option<ClientOrderId>,
1940 price: Option<Price>,
1941 quantity: Option<Quantity>,
1942 venue_order_id: Option<VenueOrderId>,
1943 ) -> Result<(), OKXWsError> {
1944 let mut builder = WsAmendOrderParamsBuilder::default();
1945
1946 builder.inst_id(instrument_id.symbol.as_str());
1947
1948 if let Some(venue_order_id) = venue_order_id {
1949 builder.ord_id(venue_order_id.as_str());
1950 }
1951
1952 if let Some(client_order_id) = client_order_id {
1953 builder.cl_ord_id(client_order_id.as_str());
1954 }
1955
1956 if let Some(price) = price {
1957 builder.new_px(price.to_string());
1958 }
1959
1960 if let Some(quantity) = quantity {
1961 builder.new_sz(quantity.to_string());
1962 }
1963
1964 let params = builder
1965 .build()
1966 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
1967
1968 let request_id = self
1970 .request_id_counter
1971 .fetch_add(1, Ordering::SeqCst)
1972 .to_string();
1973
1974 if let Some(client_order_id) = client_order_id {
1977 self.pending_amend_requests.insert(
1978 request_id.clone(),
1979 (
1980 client_order_id,
1981 trader_id,
1982 strategy_id,
1983 instrument_id,
1984 venue_order_id,
1985 ),
1986 );
1987 }
1988
1989 self.ws_amend_order(params, Some(request_id)).await
1990 }
1991
1992 #[allow(clippy::type_complexity)]
1994 #[allow(clippy::too_many_arguments)]
1995 pub async fn batch_submit_orders(
1996 &self,
1997 orders: Vec<(
1998 OKXInstrumentType,
1999 InstrumentId,
2000 OKXTradeMode,
2001 ClientOrderId,
2002 OrderSide,
2003 Option<PositionSide>,
2004 OrderType,
2005 Quantity,
2006 Option<Price>,
2007 Option<Price>,
2008 Option<bool>,
2009 Option<bool>,
2010 )>,
2011 ) -> Result<(), OKXWsError> {
2012 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2013 for (
2014 inst_type,
2015 inst_id,
2016 td_mode,
2017 cl_ord_id,
2018 ord_side,
2019 pos_side,
2020 ord_type,
2021 qty,
2022 pr,
2023 tp,
2024 post_only,
2025 reduce_only,
2026 ) in orders
2027 {
2028 let mut builder = WsPostOrderParamsBuilder::default();
2029 builder.inst_type(inst_type);
2030 builder.inst_id(inst_id.symbol.inner());
2031 builder.td_mode(td_mode);
2032 builder.cl_ord_id(cl_ord_id.as_str());
2033 builder.side(OKXSide::from(ord_side));
2034
2035 if let Some(ps) = pos_side {
2036 builder.pos_side(OKXPositionSide::from(ps));
2037 }
2038
2039 let okx_ord_type = if post_only.unwrap_or(false) {
2040 OKXOrderType::PostOnly
2041 } else {
2042 OKXOrderType::from(ord_type)
2043 };
2044
2045 builder.ord_type(okx_ord_type);
2046 builder.sz(qty.to_string());
2047
2048 if let Some(p) = pr {
2049 builder.px(p.to_string());
2050 } else if let Some(p) = tp {
2051 builder.px(p.to_string());
2052 }
2053
2054 if let Some(ro) = reduce_only {
2055 builder.reduce_only(ro);
2056 }
2057
2058 builder.tag(OKX_NAUTILUS_BROKER_ID);
2059
2060 let params = builder
2061 .build()
2062 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2063 let val =
2064 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2065 args.push(val);
2066 }
2067
2068 self.ws_batch_place_orders(args).await
2069 }
2070
2071 #[allow(clippy::type_complexity)]
2073 pub async fn batch_cancel_orders(
2074 &self,
2075 orders: Vec<(
2076 OKXInstrumentType,
2077 InstrumentId,
2078 Option<ClientOrderId>,
2079 Option<String>,
2080 )>,
2081 ) -> Result<(), OKXWsError> {
2082 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2083 for (_inst_type, inst_id, cl_ord_id, ord_id) in orders {
2084 let mut builder = WsCancelOrderParamsBuilder::default();
2085 builder.inst_id(inst_id.symbol.inner());
2087
2088 if let Some(c) = cl_ord_id {
2089 builder.cl_ord_id(c.as_str());
2090 }
2091
2092 if let Some(o) = ord_id {
2093 builder.ord_id(o);
2094 }
2095
2096 let params = builder.build().map_err(|e| {
2097 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2098 })?;
2099 let val =
2100 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2101 args.push(val);
2102 }
2103
2104 self.ws_batch_cancel_orders(args).await
2105 }
2106
2107 #[allow(clippy::type_complexity)]
2109 #[allow(clippy::too_many_arguments)]
2110 pub async fn batch_modify_orders(
2111 &self,
2112 orders: Vec<(
2113 OKXInstrumentType,
2114 InstrumentId,
2115 ClientOrderId,
2116 ClientOrderId,
2117 Option<Price>,
2118 Option<Quantity>,
2119 )>,
2120 ) -> Result<(), OKXWsError> {
2121 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2122 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2123 let mut builder = WsAmendOrderParamsBuilder::default();
2124 builder.inst_id(inst_id.symbol.inner());
2126 builder.cl_ord_id(cl_ord_id.as_str());
2127 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2128
2129 if let Some(p) = pr {
2130 builder.new_px(p.to_string());
2131 }
2132
2133 if let Some(q) = sz {
2134 builder.new_sz(q.to_string());
2135 }
2136
2137 let params = builder.build().map_err(|e| {
2138 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2139 })?;
2140 let val =
2141 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2142 args.push(val);
2143 }
2144
2145 self.ws_batch_amend_orders(args).await
2146 }
2147}
2148
2149struct OKXFeedHandler {
2150 receiver: UnboundedReceiver<Message>,
2151 signal: Arc<AtomicBool>,
2152}
2153
2154impl OKXFeedHandler {
2155 pub fn new(receiver: UnboundedReceiver<Message>, signal: Arc<AtomicBool>) -> Self {
2157 Self { receiver, signal }
2158 }
2159
2160 async fn next(&mut self) -> Option<OKXWebSocketEvent> {
2162 loop {
2163 tokio::select! {
2164 msg = self.receiver.recv() => match msg {
2165 Some(msg) => match msg {
2166 Message::Text(text) => {
2167 if text == RECONNECTED {
2169 tracing::info!("Received WebSocket reconnection signal");
2170 return Some(OKXWebSocketEvent::Reconnected);
2171 }
2172 tracing::trace!("Received WebSocket message: {text}");
2173
2174 match serde_json::from_str(&text) {
2175 Ok(ws_event) => match &ws_event {
2176 OKXWebSocketEvent::Error { code, msg } => {
2177 tracing::error!("WebSocket error: {code} - {msg}");
2178 return Some(ws_event);
2179 }
2180 OKXWebSocketEvent::Login {
2181 event,
2182 code,
2183 msg,
2184 conn_id,
2185 } => {
2186 if code == "0" {
2187 tracing::info!(
2188 "Successfully authenticated with OKX WebSocket, conn_id={conn_id}"
2189 );
2190 } else {
2191 tracing::error!(
2192 "Authentication failed: {event} {code} - {msg}"
2193 );
2194 }
2195 return Some(ws_event);
2196 }
2197 OKXWebSocketEvent::Subscription {
2198 event,
2199 arg,
2200 conn_id,
2201 } => {
2202 let channel_str = serde_json::to_string(&arg.channel)
2203 .expect("Invalid OKX websocket channel")
2204 .trim_matches('"')
2205 .to_string();
2206 tracing::debug!(
2207 "{event}d: channel={channel_str}, conn_id={conn_id}"
2208 );
2209 continue;
2210 }
2211 OKXWebSocketEvent::ChannelConnCount {
2212 event: _,
2213 channel,
2214 conn_count,
2215 conn_id,
2216 } => {
2217 let channel_str = serde_json::to_string(&channel)
2218 .expect("Invalid OKX websocket channel")
2219 .trim_matches('"')
2220 .to_string();
2221 tracing::debug!(
2222 "Channel connection status: channel={channel_str}, connections={conn_count}, conn_id={conn_id}",
2223 );
2224 continue;
2225 }
2226 OKXWebSocketEvent::Data { .. } => return Some(ws_event),
2227 OKXWebSocketEvent::BookData { .. } => return Some(ws_event),
2228 OKXWebSocketEvent::OrderResponse {
2229 id,
2230 op,
2231 code,
2232 msg,
2233 data,
2234 } => {
2235 if code == "0" {
2236 tracing::debug!(
2237 "Order operation successful: id={:?}, op={op}, code={code}",
2238 id
2239 );
2240
2241 if let Some(order_data) = data.first() {
2243 let success_msg = order_data
2244 .get("sMsg")
2245 .and_then(|s| s.as_str())
2246 .unwrap_or("Order operation successful");
2247 tracing::debug!("Order success details: {success_msg}");
2248 }
2249 } else {
2250 let error_msg = data
2252 .first()
2253 .and_then(|d| d.get("sMsg"))
2254 .and_then(|s| s.as_str())
2255 .unwrap_or(msg.as_str());
2256 tracing::error!(
2257 "Order operation failed: id={id:?}, op={op}, code={code}, error={error_msg}",
2258 );
2259 }
2260 return Some(ws_event);
2261 }
2262 OKXWebSocketEvent::Reconnected => {
2263 tracing::warn!("Unexpected Reconnected event from deserialization");
2265 continue;
2266 }
2267 },
2268 Err(e) => {
2269 tracing::error!("Failed to parse message: {e}: {text}");
2270 return None;
2271 }
2272 }
2273 }
2274 Message::Binary(msg) => {
2275 tracing::debug!("Raw binary: {msg:?}");
2276 }
2277 Message::Close(_) => {
2278 tracing::debug!("Received close message");
2279 return None;
2280 }
2281 msg => {
2282 tracing::warn!("Unexpected message: {msg}");
2283 }
2284 }
2285 None => {
2286 tracing::info!("WebSocket stream closed");
2287 return None;
2288 }
2289 },
2290 _ = tokio::time::sleep(Duration::from_millis(1)) => {
2291 if self.signal.load(std::sync::atomic::Ordering::Relaxed) {
2292 tracing::debug!("Stop signal received");
2293 return None;
2294 }
2295 }
2296 }
2297 }
2298 }
2299}
2300
2301struct OKXWsMessageHandler {
2302 account_id: AccountId,
2303 handler: OKXFeedHandler,
2304 #[allow(dead_code)]
2305 tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
2306 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
2307 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
2308 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
2309 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
2310 last_account_state: Option<AccountState>,
2311 fee_cache: AHashMap<Ustr, Money>, funding_rate_cache: AHashMap<Ustr, (Ustr, u64)>, auth_state: Arc<tokio::sync::watch::Sender<bool>>,
2314}
2315
2316impl OKXWsMessageHandler {
2317 #[allow(clippy::too_many_arguments)]
2319 pub fn new(
2320 account_id: AccountId,
2321 instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
2322 reader: UnboundedReceiver<Message>,
2323 signal: Arc<AtomicBool>,
2324 tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
2325 pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
2326 pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
2327 pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
2328 auth_state: Arc<tokio::sync::watch::Sender<bool>>,
2329 ) -> Self {
2330 Self {
2331 account_id,
2332 handler: OKXFeedHandler::new(reader, signal),
2333 tx,
2334 pending_place_requests,
2335 pending_cancel_requests,
2336 pending_amend_requests,
2337 instruments_cache,
2338 last_account_state: None,
2339 fee_cache: AHashMap::new(),
2340 funding_rate_cache: AHashMap::new(),
2341 auth_state,
2342 }
2343 }
2344
2345 fn is_stopped(&self) -> bool {
2346 self.handler
2347 .signal
2348 .load(std::sync::atomic::Ordering::Relaxed)
2349 }
2350
2351 #[allow(dead_code)]
2352 async fn run(&mut self) {
2353 while let Some(data) = self.next().await {
2354 if let Err(e) = self.tx.send(data) {
2355 tracing::error!("Error sending data: {e}");
2356 break; }
2358 }
2359 }
2360
2361 async fn next(&mut self) -> Option<NautilusWsMessage> {
2362 let clock = get_atomic_clock_realtime();
2363
2364 while let Some(event) = self.handler.next().await {
2365 let ts_init = clock.get_time_ns();
2366
2367 if let OKXWebSocketEvent::Login { code, msg, .. } = event {
2368 if code == "0" {
2369 if self.auth_state.send(true).is_err() {
2370 tracing::error!(
2371 "Failed to send authentication success signal: receiver dropped"
2372 );
2373 }
2374 } else {
2375 tracing::error!("Authentication failed: {msg}");
2376 if self.auth_state.send(false).is_err() {
2377 tracing::error!(
2378 "Failed to send authentication failure signal: receiver dropped"
2379 );
2380 }
2381 }
2382 continue; }
2384
2385 if let OKXWebSocketEvent::BookData { arg, action, data } = event {
2386 let inst = match arg.inst_id {
2387 Some(inst_id) => match self.instruments_cache.get(&inst_id) {
2388 Some(inst_ref) => inst_ref.clone(),
2389 None => continue,
2390 },
2391 None => {
2392 tracing::error!("Instrument ID missing for book data event");
2393 continue;
2394 }
2395 };
2396
2397 let instrument_id = inst.id();
2398 let price_precision = inst.price_precision();
2399 let size_precision = inst.size_precision();
2400
2401 match parse_book_msg_vec(
2402 data,
2403 &instrument_id,
2404 price_precision,
2405 size_precision,
2406 action,
2407 ts_init,
2408 ) {
2409 Ok(data) => return Some(NautilusWsMessage::Data(data)),
2410 Err(e) => {
2411 tracing::error!("Failed to parse book message: {e}");
2412 continue;
2413 }
2414 }
2415 }
2416
2417 if let OKXWebSocketEvent::OrderResponse {
2418 id,
2419 op,
2420 code,
2421 msg,
2422 data,
2423 } = event
2424 {
2425 if code == "0" {
2426 tracing::debug!(
2427 "Order operation successful: id={:?} op={op} code={code}",
2428 id
2429 );
2430
2431 if let Some(data) = data.first() {
2432 let success_msg = data
2433 .get("sMsg")
2434 .and_then(|s| s.as_str())
2435 .unwrap_or("Order operation successful");
2436 tracing::debug!("Order details: {success_msg}");
2437
2438 }
2442 } else {
2443 let error_msg = data
2445 .first()
2446 .and_then(|d| d.get("sMsg"))
2447 .and_then(|s| s.as_str())
2448 .unwrap_or(&msg);
2449
2450 if let Some(data_obj) = data.first() {
2452 tracing::debug!(
2453 "Error data fields: {}",
2454 serde_json::to_string_pretty(data_obj)
2455 .unwrap_or_else(|_| "unable to serialize".to_string())
2456 );
2457 }
2458
2459 tracing::error!(
2460 "Order operation failed: id={:?} op={op} code={code} msg={msg}",
2461 id
2462 );
2463
2464 if let Some(id) = &id {
2466 match op {
2467 OKXWsOperation::Order => {
2468 if let Some((
2469 _,
2470 (client_order_id, trader_id, strategy_id, instrument_id),
2471 )) = self.pending_place_requests.remove(id)
2472 {
2473 let ts_event = clock.get_time_ns();
2474 let rejected = OrderRejected::new(
2475 trader_id,
2476 strategy_id,
2477 instrument_id,
2478 client_order_id,
2479 self.account_id,
2480 Ustr::from(error_msg), UUID4::new(),
2482 ts_event,
2483 ts_init,
2484 false, false, );
2487
2488 return Some(NautilusWsMessage::OrderRejected(rejected));
2489 }
2490 }
2491 OKXWsOperation::CancelOrder => {
2492 if let Some((
2493 _,
2494 (
2495 client_order_id,
2496 trader_id,
2497 strategy_id,
2498 instrument_id,
2499 venue_order_id,
2500 ),
2501 )) = self.pending_cancel_requests.remove(id)
2502 {
2503 let ts_event = clock.get_time_ns();
2504 let rejected = OrderCancelRejected::new(
2505 trader_id,
2506 strategy_id,
2507 instrument_id,
2508 client_order_id,
2509 Ustr::from(error_msg), UUID4::new(),
2511 ts_event,
2512 ts_init,
2513 false, venue_order_id,
2515 Some(self.account_id),
2516 );
2517
2518 return Some(NautilusWsMessage::OrderCancelRejected(rejected));
2519 }
2520 }
2521 OKXWsOperation::AmendOrder => {
2522 if let Some((
2523 _,
2524 (
2525 client_order_id,
2526 trader_id,
2527 strategy_id,
2528 instrument_id,
2529 venue_order_id,
2530 ),
2531 )) = self.pending_amend_requests.remove(id)
2532 {
2533 let ts_event = clock.get_time_ns();
2534 let rejected = OrderModifyRejected::new(
2535 trader_id,
2536 strategy_id,
2537 instrument_id,
2538 client_order_id,
2539 Ustr::from(error_msg), UUID4::new(),
2541 ts_event,
2542 ts_init,
2543 false, venue_order_id,
2545 Some(self.account_id),
2546 );
2547
2548 return Some(NautilusWsMessage::OrderModifyRejected(rejected));
2549 }
2550 }
2551 _ => {
2552 tracing::warn!("Unhandled operation type for rejection: {op}");
2553 }
2554 }
2555 }
2556
2557 let error = OKXWebSocketError {
2559 code: code.clone(),
2560 message: error_msg.to_string(),
2561 conn_id: None, timestamp: clock.get_time_ns().as_u64(),
2563 };
2564 return Some(NautilusWsMessage::Error(error));
2565 }
2566 continue;
2567 }
2568
2569 if let OKXWebSocketEvent::Data { ref arg, ref data } = event {
2570 if arg.channel == OKXWsChannel::Account {
2571 match serde_json::from_value::<Vec<OKXAccount>>(data.clone()) {
2572 Ok(accounts) => {
2573 if let Some(account) = accounts.first() {
2574 match parse_account_state(account, self.account_id, ts_init) {
2576 Ok(account_state) => {
2577 if let Some(last_account_state) = &self.last_account_state
2579 && account_state
2580 .has_same_balances_and_margins(last_account_state)
2581 {
2582 continue; }
2584 self.last_account_state = Some(account_state.clone());
2585 return Some(NautilusWsMessage::AccountUpdate(
2586 account_state,
2587 ));
2588 }
2589 Err(e) => {
2590 tracing::error!("Failed to parse account state: {e}");
2591 }
2592 }
2593 }
2594 }
2595 Err(e) => {
2596 tracing::error!(
2597 "Failed to parse account data: {e}, raw data: {}",
2598 data
2599 );
2600 }
2601 }
2602 continue;
2603 }
2604
2605 if arg.channel == OKXWsChannel::Orders {
2606 tracing::debug!("Received orders channel message: {data}");
2607
2608 let data: Vec<OKXOrderMsg> = serde_json::from_value(data.clone()).unwrap();
2609
2610 let mut exec_reports = Vec::with_capacity(data.len());
2611
2612 for msg in data {
2613 match parse_order_msg_vec(
2614 vec![msg],
2615 self.account_id,
2616 &self.instruments_cache,
2617 &self.fee_cache,
2618 ts_init,
2619 ) {
2620 Ok(mut reports) => {
2621 for report in &reports {
2623 match report {
2624 ExecutionReport::Fill(fill_report) => {
2625 let order_id = fill_report.venue_order_id.inner();
2626 let current_fee = self
2627 .fee_cache
2628 .get(&order_id)
2629 .copied()
2630 .unwrap_or_else(|| {
2631 Money::new(0.0, fill_report.commission.currency)
2632 });
2633 let total_fee = current_fee + fill_report.commission;
2634 self.fee_cache.insert(order_id, total_fee);
2635 }
2636 ExecutionReport::Order(status_report) => {
2637 if matches!(
2638 status_report.order_status,
2639 OrderStatus::Filled,
2640 ) {
2641 self.fee_cache
2642 .remove(&status_report.venue_order_id.inner());
2643 }
2644 }
2645 }
2646 }
2647 exec_reports.append(&mut reports);
2648 }
2649 Err(e) => {
2650 tracing::error!("Failed to parse order message: {e}");
2651 continue;
2652 }
2653 }
2654 }
2655
2656 if !exec_reports.is_empty() {
2657 return Some(NautilusWsMessage::ExecutionReports(exec_reports));
2658 }
2659 }
2660
2661 let inst = match arg.inst_id.and_then(|id| self.instruments_cache.get(&id)) {
2662 Some(inst) => inst,
2663 None => {
2664 tracing::error!(
2665 "No instrument for channel {:?}, inst_id {:?}",
2666 arg.channel,
2667 arg.inst_id
2668 );
2669 continue;
2670 }
2671 };
2672 let instrument_id = inst.id();
2673 let price_precision = inst.price_precision();
2674 let size_precision = inst.size_precision();
2675
2676 match parse_ws_message_data(
2677 &arg.channel,
2678 data.clone(),
2679 &instrument_id,
2680 price_precision,
2681 size_precision,
2682 ts_init,
2683 &mut self.funding_rate_cache,
2684 ) {
2685 Ok(Some(msg)) => return Some(msg),
2686 Ok(None) => {
2687 continue;
2689 }
2690 Err(e) => {
2691 tracing::error!("Error parsing message for channel {:?}: {e}", arg.channel)
2692 }
2693 }
2694 }
2695
2696 if let OKXWebSocketEvent::Login {
2698 code, msg, conn_id, ..
2699 } = &event
2700 && code != "0"
2701 {
2702 let error = OKXWebSocketError {
2703 code: code.clone(),
2704 message: msg.clone(),
2705 conn_id: Some(conn_id.clone()),
2706 timestamp: clock.get_time_ns().as_u64(),
2707 };
2708 return Some(NautilusWsMessage::Error(error));
2709 }
2710
2711 if let OKXWebSocketEvent::Error { code, msg } = &event {
2713 let error = OKXWebSocketError {
2714 code: code.clone(),
2715 message: msg.clone(),
2716 conn_id: None,
2717 timestamp: clock.get_time_ns().as_u64(),
2718 };
2719 return Some(NautilusWsMessage::Error(error));
2720 }
2721
2722 if matches!(&event, OKXWebSocketEvent::Reconnected) {
2724 return Some(NautilusWsMessage::Reconnected);
2725 }
2726 }
2727 None }
2729}
2730
#[cfg(test)]
mod tests {
    //! Unit tests for the OKX WebSocket client (construction, request-id counter,
    //! pending-request caches, error wrapping) and for `OKXFeedHandler`'s handling of
    //! reconnect sentinels, stop signals and close frames.

    use rstest::rstest;

    use super::*;

    // A seconds-since-epoch timestamp, as used for WebSocket auth, renders as a
    // 10-digit decimal string (this holds until the year 2286).
    #[rstest]
    fn test_timestamp_format_for_websocket_auth() {
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
            .to_string();

        assert!(timestamp.parse::<u64>().is_ok());
        assert_eq!(timestamp.len(), 10);
        assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
    }

    // A default client carries no credential and therefore exposes no API key.
    #[rstest]
    fn test_new_without_credentials() {
        let client = OKXWebSocketClient::default();
        assert!(client.credential.is_none());
        assert_eq!(client.api_key(), None);
    }

    // Supplying key, secret and passphrase yields a client with a credential.
    #[rstest]
    fn test_new_with_credentials() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            None,
            None,
        )
        .unwrap();
        assert!(client.credential.is_some());
        assert_eq!(client.api_key(), Some("test_key"));
    }

    // Key and passphrase without a secret must be rejected by the constructor.
    #[rstest]
    fn test_new_partial_credentials_fails() {
        let result = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            None,
            Some("test_passphrase".to_string()),
            None,
            None,
        );
        assert!(result.is_err());
    }

    // `fetch_add` on the request-id counter hands out consecutive values.
    #[rstest]
    fn test_request_id_generation() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);

        let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, initial_counter);
        assert_eq!(id2, initial_counter + 1);
        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 2
        );
    }

    // A default client starts closed and inactive; the heartbeat argument is kept as given.
    #[rstest]
    fn test_client_state_management() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
    }

    // Insert / lookup / remove round-trip on the pending place-request cache.
    #[rstest]
    fn test_request_cache_operations() {
        let client = OKXWebSocketClient::default();

        assert_eq!(client.pending_place_requests.len(), 0);
        assert_eq!(client.pending_cancel_requests.len(), 0);
        assert_eq!(client.pending_amend_requests.len(), 0);

        let client_order_id = ClientOrderId::from("test-order-123");
        let trader_id = TraderId::from("test-trader-001");
        let strategy_id = StrategyId::from("test-strategy-001");
        let instrument_id = InstrumentId::from("BTC-USDT.OKX");

        client.pending_place_requests.insert(
            "place-123".to_string(),
            (client_order_id, trader_id, strategy_id, instrument_id),
        );

        assert_eq!(client.pending_place_requests.len(), 1);
        assert!(client.pending_place_requests.contains_key("place-123"));

        let removed = client.pending_place_requests.remove("place-123");
        assert!(removed.is_some());
        assert_eq!(client.pending_place_requests.len(), 0);
    }

    // `OKXWebSocketError` fields survive wrapping in `NautilusWsMessage::Error`.
    #[rstest]
    fn test_websocket_error_handling() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error = OKXWebSocketError {
            code: "60012".to_string(),
            message: "Invalid request".to_string(),
            conn_id: None,
            timestamp: ts,
        };

        assert_eq!(error.code, "60012");
        assert_eq!(error.message, "Invalid request");
        assert_eq!(error.timestamp, ts);

        let nautilus_msg = NautilusWsMessage::Error(error);
        match nautilus_msg {
            NautilusWsMessage::Error(err) => {
                assert_eq!(err.code, "60012");
                assert_eq!(err.message, "Invalid request");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    // Ten sequential draws from the counter form a contiguous run starting at its
    // initial value.
    #[rstest]
    fn test_request_id_generation_sequence() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
        let ids: Vec<_> = (0..10)
            .map(|_| client.request_id_counter.fetch_add(1, Ordering::SeqCst))
            .collect();

        let expected: Vec<_> = (initial_counter..initial_counter + 10).collect();
        assert_eq!(ids, expected);
        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 10
        );
    }

    // Default state, heartbeat and account id are all reflected on the built client.
    #[rstest]
    fn test_client_state_transitions() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat =
            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);

        let account_id = AccountId::from("test-account-123");
        let client_with_account =
            OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();

        assert_eq!(client_with_account.account_id, account_id);
    }

    // Concurrent tasks sharing one client each draw a distinct request id and leave
    // the pending place-request cache empty once they have removed their entries.
    #[tokio::test]
    async fn test_concurrent_request_handling() {
        let client = Arc::new(OKXWebSocketClient::default());
        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let client_clone = Arc::clone(&client);
                tokio::spawn(async move {
                    let client_order_id = ClientOrderId::from(format!("order-{i}").as_str());
                    let trader_id = TraderId::from("trader-001");
                    let strategy_id = StrategyId::from("strategy-001");
                    let instrument_id = InstrumentId::from("BTC-USDT.OKX");

                    let request_id = client_clone
                        .request_id_counter
                        .fetch_add(1, Ordering::SeqCst);
                    let request_id_str = request_id.to_string();

                    client_clone.pending_place_requests.insert(
                        request_id_str.clone(),
                        (client_order_id, trader_id, strategy_id, instrument_id),
                    );

                    // Yield so the tasks genuinely interleave on the shared map.
                    tokio::time::sleep(Duration::from_millis(1)).await;

                    let removed = client_clone.pending_place_requests.remove(&request_id_str);
                    assert!(removed.is_some());

                    request_id
                })
            })
            .collect();

        let mut request_ids = Vec::with_capacity(handles.len());
        for result in futures_util::future::join_all(handles).await {
            request_ids.push(result.expect("request task should not panic"));
        }

        assert_eq!(request_ids.len(), 10);
        assert_eq!(client.pending_place_requests.len(), 0);

        // No two tasks may have received the same id: after sorting, the ids must be
        // exactly the contiguous range drawn from the counter.
        request_ids.sort_unstable();
        let expected: Vec<_> = (initial_counter..initial_counter + 10).collect();
        assert_eq!(request_ids, expected);

        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 10
        );
    }

    // Each scenario's fields, including an optional connection id, survive wrapping.
    #[rstest]
    fn test_websocket_error_scenarios() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error_scenarios = vec![
            ("60012", "Invalid request", None),
            ("60009", "Invalid API key", Some("conn-123".to_string())),
            ("60014", "Too many requests", None),
            ("50001", "Order not found", None),
        ];

        for (code, message, conn_id) in error_scenarios {
            let error = OKXWebSocketError {
                code: code.to_string(),
                message: message.to_string(),
                conn_id: conn_id.clone(),
                timestamp: ts,
            };

            assert_eq!(error.code, code);
            assert_eq!(error.message, message);
            assert_eq!(error.conn_id, conn_id);
            assert_eq!(error.timestamp, ts);

            let nautilus_msg = NautilusWsMessage::Error(error);
            match nautilus_msg {
                NautilusWsMessage::Error(err) => {
                    assert_eq!(err.code, code);
                    assert_eq!(err.message, message);
                    assert_eq!(err.conn_id, conn_id);
                }
                _ => panic!("Expected Error variant"),
            }
        }
    }

    // The RECONNECTED sentinel text frame surfaces as `OKXWebSocketEvent::Reconnected`.
    #[tokio::test]
    async fn test_feed_handler_reconnection_detection() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal);

        tx.send(Message::Text(RECONNECTED.to_string().into()))
            .unwrap();

        let result = handler.next().await;
        assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
    }

    // With messages already queued, setting the stop signal before polling makes
    // `next()` return `None` rather than yielding the queued messages.
    #[tokio::test]
    async fn test_feed_handler_normal_message_processing() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, Arc::clone(&signal));

        let ping_msg = "ping";
        tx.send(Message::Text(ping_msg.to_string().into())).unwrap();

        let sub_msg = r#"{
            "event": "subscribe",
            "arg": {
                "channel": "tickers",
                "instType": "SPOT"
            },
            "connId": "a4d3ae55"
        }"#;

        tx.send(Message::Text(sub_msg.to_string().into())).unwrap();

        signal.store(true, Ordering::Relaxed);

        let result = handler.next().await;
        assert!(result.is_none());
    }

    // A handler whose stop signal is already set yields nothing.
    #[tokio::test]
    async fn test_feed_handler_stop_signal() {
        let (_tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(true));
        let mut handler = OKXFeedHandler::new(rx, signal);

        let result = handler.next().await;
        assert!(result.is_none());
    }

    // A Close frame ends the handler's stream.
    #[tokio::test]
    async fn test_feed_handler_close_message() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal);

        tx.send(Message::Close(None)).unwrap();

        let result = handler.next().await;
        assert!(result.is_none());
    }

    // Pins the sentinel text the network layer uses to announce a reconnect.
    #[rstest]
    fn test_reconnection_message_constant() {
        assert_eq!(RECONNECTED, "__RECONNECTED__");
    }

    // Every RECONNECTED sentinel is reported, not only the first one.
    #[tokio::test]
    async fn test_multiple_reconnection_signals() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let signal = Arc::new(AtomicBool::new(false));
        let mut handler = OKXFeedHandler::new(rx, signal);

        for _ in 0..3 {
            tx.send(Message::Text(RECONNECTED.to_string().into()))
                .unwrap();

            let result = handler.next().await;
            assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
        }
    }

    // A client that was never connected fails to become active within a short timeout.
    #[tokio::test]
    async fn test_wait_until_active_timeout() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            Some(AccountId::from("test-account")),
            None,
        )
        .unwrap();

        let result = client.wait_until_active(0.1).await;

        assert!(result.is_err());
        assert!(!client.is_active());
    }
}