nautilus_okx/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Provides the WebSocket client integration for the [OKX](https://okx.com) WebSocket API.
17//!
18//! This module defines and implements a strongly-typed [`OKXWebSocketClient`] for
19//! connecting to OKX WebSocket streams. It handles authentication (when credentials
20//! are provided), manages subscriptions to market data and account update channels,
21//! and parses incoming messages into structured Nautilus domain objects.
22
23use std::{
24    fmt::Debug,
25    num::NonZeroU32,
26    sync::{
27        Arc, LazyLock,
28        atomic::{AtomicBool, AtomicU64, Ordering},
29    },
30    time::{Duration, SystemTime},
31};
32
33use ahash::{AHashMap, AHashSet};
34use dashmap::DashMap;
35use futures_util::Stream;
36use nautilus_common::runtime::get_runtime;
37use nautilus_core::{
38    UUID4, consts::NAUTILUS_USER_AGENT, env::get_env_var, time::get_atomic_clock_realtime,
39};
40use nautilus_model::{
41    data::BarType,
42    enums::{OrderSide, OrderStatus, OrderType, PositionSide, TimeInForce},
43    events::{AccountState, OrderCancelRejected, OrderModifyRejected, OrderRejected},
44    identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
45    instruments::{Instrument, InstrumentAny},
46    types::{Money, Price, Quantity},
47};
48use nautilus_network::{
49    RECONNECTED,
50    ratelimiter::quota::Quota,
51    websocket::{WebSocketClient, WebSocketConfig, channel_message_handler},
52};
53use reqwest::header::USER_AGENT;
54use serde_json::Value;
55use tokio::sync::mpsc::UnboundedReceiver;
56use tokio_tungstenite::tungstenite::{Error, Message};
57use ustr::Ustr;
58
59use super::{
60    enums::{OKXWsChannel, OKXWsOperation},
61    error::OKXWsError,
62    messages::{
63        ExecutionReport, NautilusWsMessage, OKXAuthentication, OKXAuthenticationArg,
64        OKXSubscription, OKXSubscriptionArg, OKXWebSocketError, OKXWebSocketEvent, OKXWsRequest,
65        WsAmendOrderParams, WsAmendOrderParamsBuilder, WsCancelOrderParams,
66        WsCancelOrderParamsBuilder, WsPostOrderParams, WsPostOrderParamsBuilder,
67    },
68    parse::{parse_book_msg_vec, parse_ws_message_data},
69};
70use crate::{
71    common::{
72        consts::{
73            OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
74            OKX_WS_PUBLIC_URL,
75        },
76        credential::Credential,
77        enums::{OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXSide, OKXTradeMode},
78        parse::{bar_spec_as_okx_channel, okx_instrument_type, parse_account_state},
79    },
80    http::models::OKXAccount,
81    websocket::{messages::OKXOrderMsg, parse::parse_order_msg_vec},
82};
83
84type PlaceRequestData = (ClientOrderId, TraderId, StrategyId, InstrumentId);
85type CancelRequestData = (
86    ClientOrderId,
87    TraderId,
88    StrategyId,
89    InstrumentId,
90    Option<VenueOrderId>,
91);
92type AmendRequestData = (
93    ClientOrderId,
94    TraderId,
95    StrategyId,
96    InstrumentId,
97    Option<VenueOrderId>,
98);
99
100/// Default OKX WebSocket rate limit: 3 requests per second.
101///
102/// - Connection limit: 3 requests per second (per IP).
103/// - Subscription requests: 480 'subscribe/unsubscribe/login' requests per connection per hour.
104/// - 30 WebSocket connections max per specific channel per sub-account.
105///
106/// We use 3 requests per second as the base limit to respect the connection rate limit.
107pub static OKX_WS_QUOTA: LazyLock<Quota> =
108    LazyLock::new(|| Quota::per_second(NonZeroU32::new(3).unwrap()));
109
110/// Rate limit for order-related WebSocket operations: 250 requests per second.
111///
112/// Based on OKX documentation for sub-account order limits (1000 per 2 seconds,
113/// so we use half for conservative rate limiting).
114pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> =
115    LazyLock::new(|| Quota::per_second(NonZeroU32::new(250).unwrap()));
116
117/// Provides a WebSocket client for connecting to [OKX](https://okx.com).
118#[derive(Clone)]
119#[cfg_attr(
120    feature = "python",
121    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
122)]
123pub struct OKXWebSocketClient {
124    url: String,
125    account_id: AccountId,
126    credential: Option<Credential>,
127    heartbeat: Option<u64>,
128    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
129    auth_state: Arc<tokio::sync::watch::Sender<bool>>,
130    auth_state_rx: tokio::sync::watch::Receiver<bool>,
131    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
132    signal: Arc<AtomicBool>,
133    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
134    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
135    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
136    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
137    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>, // For channels without inst params (e.g., Account)
138    request_id_counter: Arc<AtomicU64>,
139    pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
140    pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
141    pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
142    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
143}
144
145impl Default for OKXWebSocketClient {
146    fn default() -> Self {
147        Self::new(None, None, None, None, None, None).unwrap()
148    }
149}
150
151impl Debug for OKXWebSocketClient {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct(stringify!(OKXWebSocketClient))
154            .field("url", &self.url)
155            .field(
156                "credential",
157                &self.credential.as_ref().map(|_| "<redacted>"),
158            )
159            .field("heartbeat", &self.heartbeat)
160            .finish_non_exhaustive()
161    }
162}
163
164impl OKXWebSocketClient {
165    /// Creates a new [`OKXWebSocketClient`] instance.
166    pub fn new(
167        url: Option<String>,
168        api_key: Option<String>,
169        api_secret: Option<String>,
170        api_passphrase: Option<String>,
171        account_id: Option<AccountId>,
172        heartbeat: Option<u64>,
173    ) -> anyhow::Result<Self> {
174        let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
175        let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
176
177        let credential = match (api_key, api_secret, api_passphrase) {
178            (Some(key), Some(secret), Some(passphrase)) => {
179                Some(Credential::new(key, secret, passphrase))
180            }
181            (None, None, None) => None,
182            _ => anyhow::bail!(
183                "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
184            ),
185        };
186
187        let signal = Arc::new(AtomicBool::new(false));
188        let subscriptions_inst_type = Arc::new(DashMap::new());
189        let subscriptions_inst_family = Arc::new(DashMap::new());
190        let subscriptions_inst_id = Arc::new(DashMap::new());
191        let subscriptions_bare = Arc::new(DashMap::new());
192        let (auth_tx, auth_rx) = tokio::sync::watch::channel(false);
193
194        Ok(Self {
195            url,
196            account_id,
197            credential,
198            heartbeat,
199            inner: Arc::new(tokio::sync::RwLock::new(None)),
200            auth_state: Arc::new(auth_tx),
201            auth_state_rx: auth_rx,
202            rx: None,
203            signal,
204            task_handle: None,
205            subscriptions_inst_type,
206            subscriptions_inst_family,
207            subscriptions_inst_id,
208            subscriptions_bare,
209            request_id_counter: Arc::new(AtomicU64::new(1)),
210            pending_place_requests: Arc::new(DashMap::new()),
211            pending_cancel_requests: Arc::new(DashMap::new()),
212            pending_amend_requests: Arc::new(DashMap::new()),
213            instruments_cache: Arc::new(AHashMap::new()),
214        })
215    }
216
217    /// Creates a new [`OKXWebSocketClient`] instance.
218    pub fn with_credentials(
219        url: Option<String>,
220        api_key: Option<String>,
221        api_secret: Option<String>,
222        api_passphrase: Option<String>,
223        account_id: Option<AccountId>,
224        heartbeat: Option<u64>,
225    ) -> anyhow::Result<Self> {
226        let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
227        let api_key = api_key.unwrap_or(get_env_var("OKX_API_KEY")?);
228        let api_secret = api_secret.unwrap_or(get_env_var("OKX_API_SECRET")?);
229        let api_passphrase = api_passphrase.unwrap_or(get_env_var("OKX_API_PASSPHRASE")?);
230
231        Self::new(
232            Some(url),
233            Some(api_key),
234            Some(api_secret),
235            Some(api_passphrase),
236            account_id,
237            heartbeat,
238        )
239    }
240
241    /// Creates a new authenticated [`OKXWebSocketClient`] using environment variables.
242    pub fn from_env() -> anyhow::Result<Self> {
243        let url = get_env_var("OKX_WS_URL")?;
244        let api_key = get_env_var("OKX_API_KEY")?;
245        let api_secret = get_env_var("OKX_API_SECRET")?;
246        let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
247
248        Self::new(
249            Some(url),
250            Some(api_key),
251            Some(api_secret),
252            Some(api_passphrase),
253            None,
254            None,
255        )
256    }
257
258    /// Returns the websocket url being used by the client.
259    pub fn url(&self) -> &str {
260        self.url.as_str()
261    }
262
263    /// Returns the public API key being used by the client.
264    pub fn api_key(&self) -> Option<&str> {
265        self.credential.clone().map(|c| c.api_key.as_str())
266    }
267
268    /// Get a read lock on the inner client
269    /// Returns a value indicating whether the client is active.
270    pub fn is_active(&self) -> bool {
271        // Use try_read to avoid blocking
272        match self.inner.try_read() {
273            Ok(guard) => match &*guard {
274                Some(inner) => inner.is_active(),
275                None => false,
276            },
277            Err(_) => false, // If we can't get the lock, assume not active
278        }
279    }
280
281    /// Returns a value indicating whether the client is closed.
282    pub fn is_closed(&self) -> bool {
283        // Use try_read to avoid blocking
284        match self.inner.try_read() {
285            Ok(guard) => match &*guard {
286                Some(inner) => inner.is_closed(),
287                None => true,
288            },
289            Err(_) => true, // If we can't get the lock, assume closed
290        }
291    }
292
293    /// Initialize the instruments cache with the given `instruments`.
294    pub fn initialize_instruments_cache(&mut self, instruments: Vec<InstrumentAny>) {
295        let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
296        for inst in instruments {
297            instruments_cache.insert(inst.symbol().inner(), inst.clone());
298        }
299
300        self.instruments_cache = Arc::new(instruments_cache)
301    }
302
303    /// Connect to the OKX WebSocket server.
304    ///
305    /// # Panics
306    ///
307    /// Panics if subscription arguments fail to serialize to JSON.
308    pub async fn connect(&mut self) -> anyhow::Result<()> {
309        let (message_handler, reader) = channel_message_handler();
310
311        let config = WebSocketConfig {
312            url: self.url.clone(),
313            headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
314            heartbeat: self.heartbeat,
315            heartbeat_msg: None,
316            message_handler: Some(message_handler),
317            ping_handler: None,
318            reconnect_timeout_ms: Some(5_000),
319            reconnect_delay_initial_ms: None, // Use default
320            reconnect_delay_max_ms: None,     // Use default
321            reconnect_backoff_factor: None,   // Use default
322            reconnect_jitter_ms: None,        // Use default
323        };
324        // Configure rate limits for different operation types
325        let keyed_quotas = vec![
326            ("subscription".to_string(), *OKX_WS_QUOTA),
327            ("order".to_string(), *OKX_WS_ORDER_QUOTA),
328            ("cancel".to_string(), *OKX_WS_ORDER_QUOTA),
329            ("amend".to_string(), *OKX_WS_ORDER_QUOTA),
330        ];
331
332        let client = WebSocketClient::connect(
333            config,
334            None, // post_reconnection
335            keyed_quotas,
336            Some(*OKX_WS_QUOTA), // Default quota for general operations
337        )
338        .await?;
339
340        // Set the inner client with write lock
341        {
342            let mut inner_guard = self.inner.write().await;
343            *inner_guard = Some(client);
344        }
345
346        let account_id = self.account_id;
347        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
348
349        self.rx = Some(Arc::new(rx));
350        let signal = self.signal.clone();
351        let pending_place_requests = self.pending_place_requests.clone();
352        let pending_cancel_requests = self.pending_cancel_requests.clone();
353        let pending_amend_requests = self.pending_amend_requests.clone();
354        let auth_state = self.auth_state.clone();
355
356        let instruments_cache = self.instruments_cache.clone();
357        let inner_client = self.inner.clone();
358        let credential_clone = self.credential.clone();
359        let subscriptions_inst_type = self.subscriptions_inst_type.clone();
360        let subscriptions_inst_family = self.subscriptions_inst_family.clone();
361        let subscriptions_inst_id = self.subscriptions_inst_id.clone();
362        let subscriptions_bare = self.subscriptions_bare.clone();
363        let auth_state_clone = auth_state.clone();
364        let stream_handle = get_runtime().spawn(async move {
365            let mut handler = OKXWsMessageHandler::new(
366                account_id,
367                instruments_cache,
368                reader,
369                signal,
370                tx,
371                pending_place_requests,
372                pending_cancel_requests,
373                pending_amend_requests,
374                auth_state,
375            );
376
377            // Main message loop with explicit reconnection handling
378            loop {
379                match handler.next().await {
380                    Some(NautilusWsMessage::Reconnected) => {
381                        tracing::info!("Handling WebSocket reconnection");
382
383                        // Re-authenticate if we have credentials
384                        let inner_guard = inner_client.read().await;
385                        if let Some(cred) = &credential_clone
386                            && let Some(client) = &*inner_guard {
387                                let timestamp = SystemTime::now()
388                                    .duration_since(SystemTime::UNIX_EPOCH)
389                                    .expect("System time should be after UNIX epoch")
390                                    .as_secs()
391                                    .to_string();
392                                let signature = cred.sign(&timestamp, "GET", "/users/self/verify", "");
393
394                                let auth_message = OKXAuthentication {
395                                    op: "login",
396                                    args: vec![OKXAuthenticationArg {
397                                        api_key: cred.api_key.to_string(),
398                                        passphrase: cred.api_passphrase.clone(),
399                                        timestamp,
400                                        sign: signature,
401                                    }],
402                                };
403
404                                if let Err(e) = client.send_text(serde_json::to_string(&auth_message).unwrap(), None).await {
405                                    tracing::error!("Failed to send re-authentication request: {e}");
406                                    // Even if auth fails, try to resubscribe public channels
407                                } else {
408                                    tracing::info!("Sent re-authentication request, waiting for response before resubscribing");
409
410                                    // Wait for authentication to complete (with timeout)
411                                    let mut auth_rx = auth_state_clone.subscribe();
412                                    match tokio::time::timeout(Duration::from_secs(5), auth_rx.wait_for(|&auth| auth)).await {
413                                        Ok(Ok(_)) => {
414                                            tracing::info!("Authentication successful after reconnect, proceeding with resubscription");
415                                            // Now we resubscribe after successful auth
416                                            // Fall through to resubscription logic below
417                                        }
418                                        Ok(Err(e)) => {
419                                            tracing::error!("Auth watch channel error after reconnect: {e}");
420                                            // Fall through to resubscribe public channels anyway
421                                        }
422                                        Err(_) => {
423                                            tracing::error!("Timeout waiting for authentication after reconnect");
424                                            // Fall through to resubscribe public channels anyway
425                                        }
426                                    }
427                                }
428                        }
429
430                        // Re-subscribe to all channels
431                        // TODO: Extract common resubscription logic to avoid duplication with the auth success path
432                        let inner_guard = inner_client.read().await;
433                        if let Some(client) = &*inner_guard {
434                            // Batch subscribe by instrument type
435                            let mut inst_type_args = Vec::new();
436                            for entry in subscriptions_inst_type.iter() {
437                                let (channel, inst_types) = entry.pair();
438                                for inst_type in inst_types.iter() {
439                                    inst_type_args.push(OKXSubscriptionArg {
440                                        channel: channel.clone(),
441                                        inst_type: Some(*inst_type),
442                                        inst_family: None,
443                                        inst_id: None,
444                                    });
445                                }
446                            }
447                            if !inst_type_args.is_empty() {
448                                let sub_request = OKXSubscription {
449                                    op: OKXWsOperation::Subscribe,
450                                    args: inst_type_args,
451                                };
452                                if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
453                                    tracing::error!("Failed to re-subscribe inst_type channels: {e}");
454                                }
455                            }
456
457                            // Batch subscribe by instrument family
458                            let mut inst_family_args = Vec::new();
459                            for entry in subscriptions_inst_family.iter() {
460                                let (channel, inst_families) = entry.pair();
461                                for inst_family in inst_families.iter() {
462                                    inst_family_args.push(OKXSubscriptionArg {
463                                        channel: channel.clone(),
464                                        inst_type: None,
465                                        inst_family: Some(*inst_family),
466                                        inst_id: None,
467                                    });
468                                }
469                            }
470                            if !inst_family_args.is_empty() {
471                                let sub_request = OKXSubscription {
472                                    op: OKXWsOperation::Subscribe,
473                                    args: inst_family_args,
474                                };
475                                if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
476                                    tracing::error!("Failed to re-subscribe inst_family channels: {e}");
477                                }
478                            }
479
480                            // Batch subscribe by instrument ID
481                            let mut inst_id_args = Vec::new();
482                            for entry in subscriptions_inst_id.iter() {
483                                let (channel, inst_ids) = entry.pair();
484                                for inst_id in inst_ids.iter() {
485                                    inst_id_args.push(OKXSubscriptionArg {
486                                        channel: channel.clone(),
487                                        inst_type: None,
488                                        inst_family: None,
489                                        inst_id: Some(*inst_id),
490                                    });
491                                }
492                            }
493                            if !inst_id_args.is_empty() {
494                                let sub_request = OKXSubscription {
495                                    op: OKXWsOperation::Subscribe,
496                                    args: inst_id_args,
497                                };
498                                if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
499                                    tracing::error!("Failed to re-subscribe inst_id channels: {e}");
500                                }
501                            }
502
503                            // Batch subscribe bare channels
504                            let mut bare_args = Vec::new();
505                            for entry in subscriptions_bare.iter() {
506                                let channel = entry.key();
507                                bare_args.push(OKXSubscriptionArg {
508                                    channel: channel.clone(),
509                                    inst_type: None,
510                                    inst_family: None,
511                                    inst_id: None,
512                                });
513                            }
514                            if !bare_args.is_empty() {
515                                let sub_request = OKXSubscription {
516                                    op: OKXWsOperation::Subscribe,
517                                    args: bare_args,
518                                };
519                                if let Err(e) = client.send_text(serde_json::to_string(&sub_request).unwrap(), None).await {
520                                    tracing::error!("Failed to re-subscribe bare channels: {e}");
521                                }
522                            }
523
524                            tracing::info!("Completed re-subscription after reconnect");
525                        }
526                    }
527                    Some(msg) => {
528                        // Forward the message
529                        if handler.tx.send(msg).is_err() {
530                            tracing::error!("Failed to send message through channel: receiver dropped");
531                            break;
532                        }
533
534                    }
535                    None => {
536                        // Stream ended - check if it's a stop signal
537                        if handler.is_stopped() {
538                            tracing::debug!("Stop signal received, ending message processing");
539                            break;
540                        }
541                        // Otherwise it's an unexpected stream end
542                        tracing::warn!("WebSocket stream ended unexpectedly");
543                        break;
544                    }
545                }
546            }
547        });
548
549        self.task_handle = Some(Arc::new(stream_handle));
550
551        if self.credential.is_some() {
552            if self.auth_state.send(false).is_err() {
553                tracing::error!("Failed to reset auth state, receiver dropped.");
554            };
555            self.authenticate().await?;
556        }
557
558        Ok(())
559    }
560
561    /// Authenticates the WebSocket session with OKX.
562    async fn authenticate(&self) -> Result<(), Error> {
563        let credential = match &self.credential {
564            Some(credential) => credential,
565            None => {
566                panic!("API credentials not available to authenticate");
567            }
568        };
569
570        let timestamp = SystemTime::now()
571            .duration_since(SystemTime::UNIX_EPOCH)
572            .expect("System time should be after UNIX epoch")
573            .as_secs()
574            .to_string();
575        let signature = credential.sign(&timestamp, "GET", "/users/self/verify", "");
576
577        let auth_message = OKXAuthentication {
578            op: "login",
579            args: vec![OKXAuthenticationArg {
580                api_key: credential.api_key.to_string(),
581                passphrase: credential.api_passphrase.clone(),
582                timestamp,
583                sign: signature,
584            }],
585        };
586
587        {
588            let inner_guard = self.inner.read().await;
589            if let Some(inner) = &*inner_guard {
590                if let Err(e) = inner
591                    .send_text(serde_json::to_string(&auth_message).unwrap(), None)
592                    .await
593                {
594                    tracing::error!("Error sending auth message: {e:?}");
595                    return Err(Error::Io(std::io::Error::other(e.to_string())));
596                }
597            } else {
598                log::error!("Cannot authenticate: not connected");
599                return Err(Error::ConnectionClosed);
600            }
601        }
602
603        // Wait for authentication to complete
604        let mut rx = self.auth_state_rx.clone();
605        match tokio::time::timeout(Duration::from_secs(10), rx.wait_for(|&auth| auth)).await {
606            Ok(Ok(_)) => {
607                tracing::info!("Authentication confirmed by client");
608                Ok(())
609            }
610            Ok(Err(e)) => {
611                tracing::error!("Authentication watch channel closed unexpectedly: {e}");
612                Err(Error::Io(std::io::Error::other(
613                    "Authentication watch channel closed",
614                )))
615            }
616            Err(_) => {
617                tracing::error!("Timeout waiting for authentication response");
618                Err(Error::Io(std::io::Error::other(
619                    "Timeout waiting for authentication",
620                )))
621            }
622        }
623    }
624
625    /// Provides the internal data stream as a channel-based stream.
626    ///
627    /// # Panics
628    ///
629    /// This function panics if:
630    /// - The websocket is not connected.
631    /// - `stream_data` has already been called somewhere else (stream receiver is then taken).
632    pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
633        let rx = self
634            .rx
635            .take()
636            .expect("Data stream receiver already taken or not connected");
637        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
638        async_stream::stream! {
639            while let Some(data) = rx.recv().await {
640                yield data;
641            }
642        }
643    }
644
645    /// Wait until the WebSocket connection is active.
646    ///
647    /// # Errors
648    ///
649    /// Returns an error if the connection times out.
650    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
651        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
652
653        tokio::time::timeout(timeout, async {
654            while !self.is_active() {
655                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
656            }
657        })
658        .await
659        .map_err(|_| {
660            OKXWsError::ClientError(format!(
661                "WebSocket connection timeout after {timeout_secs} seconds"
662            ))
663        })?;
664
665        Ok(())
666    }
667
668    /// Closes the client.
669    pub async fn close(&mut self) -> Result<(), Error> {
670        log::debug!("Starting close process");
671
672        self.signal.store(true, Ordering::Relaxed);
673
674        {
675            let inner_guard = self.inner.read().await;
676            if let Some(inner) = &*inner_guard {
677                log::debug!("Disconnecting websocket");
678
679                match tokio::time::timeout(Duration::from_secs(3), inner.disconnect()).await {
680                    Ok(()) => log::debug!("Websocket disconnected successfully"),
681                    Err(_) => {
682                        log::warn!(
683                            "Timeout waiting for websocket disconnect, continuing with cleanup"
684                        )
685                    }
686                }
687            } else {
688                log::debug!("No active connection to disconnect");
689            }
690        }
691
692        // Clean up stream handle with timeout
693        if let Some(stream_handle) = self.task_handle.take() {
694            match Arc::try_unwrap(stream_handle) {
695                Ok(handle) => {
696                    log::debug!("Waiting for stream handle to complete");
697                    match tokio::time::timeout(Duration::from_secs(2), handle).await {
698                        Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
699                        Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
700                        Err(_) => {
701                            log::warn!(
702                                "Timeout waiting for stream handle, task may still be running"
703                            );
704                            // The task will be dropped and should clean up automatically
705                        }
706                    }
707                }
708                Err(arc_handle) => {
709                    log::debug!(
710                        "Cannot take ownership of stream handle - other references exist, aborting task"
711                    );
712                    arc_handle.abort();
713                }
714            }
715        } else {
716            log::debug!("No stream handle to await");
717        }
718
719        log::debug!("Close process completed");
720
721        Ok(())
722    }
723
724    /// Get active subscriptions for a specific instrument.
725    pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
726        let symbol = instrument_id.symbol.inner();
727        let mut channels = Vec::new();
728
729        for entry in self.subscriptions_inst_id.iter() {
730            let (channel, instruments) = entry.pair();
731            if instruments.contains(&symbol) {
732                channels.push(channel.clone());
733            }
734        }
735
736        channels
737    }
738
739    fn generate_unique_request_id(&self) -> String {
740        self.request_id_counter
741            .fetch_add(1, Ordering::SeqCst)
742            .to_string()
743    }
744
745    async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
746        for arg in &args {
747            // Check if this is a bare channel (no inst params)
748            if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
749                // Track bare channels like Account
750                self.subscriptions_bare.insert(arg.channel.clone(), true);
751            } else {
752                // Update instrument type subscriptions
753                if let Some(inst_type) = &arg.inst_type {
754                    self.subscriptions_inst_type
755                        .entry(arg.channel.clone())
756                        .or_default()
757                        .insert(*inst_type);
758                }
759
760                // Update instrument family subscriptions
761                if let Some(inst_family) = &arg.inst_family {
762                    self.subscriptions_inst_family
763                        .entry(arg.channel.clone())
764                        .or_default()
765                        .insert(*inst_family);
766                }
767
768                // Update instrument ID subscriptions
769                if let Some(inst_id) = &arg.inst_id {
770                    self.subscriptions_inst_id
771                        .entry(arg.channel.clone())
772                        .or_default()
773                        .insert(*inst_id);
774                }
775            }
776        }
777
778        let message = OKXSubscription {
779            op: OKXWsOperation::Subscribe,
780            args,
781        };
782
783        let json_txt =
784            serde_json::to_string(&message).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
785
786        {
787            let inner_guard = self.inner.read().await;
788            if let Some(inner) = &*inner_guard {
789                if let Err(e) = inner
790                    .send_text(json_txt, Some(vec!["subscription".to_string()]))
791                    .await
792                {
793                    tracing::error!("Error sending message: {e:?}")
794                }
795            } else {
796                return Err(OKXWsError::ClientError(
797                    "Cannot send message: not connected".to_string(),
798                ));
799            }
800        }
801
802        Ok(())
803    }
804
805    #[allow(clippy::collapsible_if)]
806    async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
807        for arg in &args {
808            // Check if this is a bare channel
809            if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
810                // Remove bare channel subscription
811                self.subscriptions_bare.remove(&arg.channel);
812            } else {
813                // Update instrument type subscriptions
814                if let Some(inst_type) = &arg.inst_type {
815                    if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
816                        entry.remove(inst_type);
817                        if entry.is_empty() {
818                            drop(entry);
819                            self.subscriptions_inst_type.remove(&arg.channel);
820                        }
821                    }
822                }
823
824                // Update instrument family subscriptions
825                if let Some(inst_family) = &arg.inst_family {
826                    if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
827                        entry.remove(inst_family);
828                        if entry.is_empty() {
829                            drop(entry);
830                            self.subscriptions_inst_family.remove(&arg.channel);
831                        }
832                    }
833                }
834
835                // Update instrument ID subscriptions
836                if let Some(inst_id) = &arg.inst_id {
837                    if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
838                        entry.remove(inst_id);
839                        if entry.is_empty() {
840                            drop(entry);
841                            self.subscriptions_inst_id.remove(&arg.channel);
842                        }
843                    }
844                }
845            }
846        }
847
848        let message = OKXSubscription {
849            op: OKXWsOperation::Unsubscribe,
850            args,
851        };
852
853        let json_txt = serde_json::to_string(&message).expect("Must be valid JSON");
854
855        {
856            let inner_guard = self.inner.read().await;
857            if let Some(inner) = &*inner_guard {
858                if let Err(e) = inner
859                    .send_text(json_txt, Some(vec!["subscription".to_string()]))
860                    .await
861                {
862                    tracing::error!("Error sending message: {e:?}")
863                }
864            } else {
865                log::error!("Cannot send message: not connected");
866            }
867        }
868
869        Ok(())
870    }
871
872    #[allow(dead_code)]
873    async fn resubscribe_all(&self) {
874        // Collect bare channel subscriptions (e.g., Account)
875        let mut subs_bare = Vec::new();
876        for entry in self.subscriptions_bare.iter() {
877            let channel = entry.key();
878            subs_bare.push(channel.clone());
879        }
880
881        let mut subs_inst_type = Vec::new();
882        for entry in self.subscriptions_inst_type.iter() {
883            let (channel, inst_types) = entry.pair();
884            if !inst_types.is_empty() {
885                subs_inst_type.push((channel.clone(), inst_types.clone()));
886            }
887        }
888
889        let mut subs_inst_family = Vec::new();
890        for entry in self.subscriptions_inst_family.iter() {
891            let (channel, inst_families) = entry.pair();
892            if !inst_families.is_empty() {
893                subs_inst_family.push((channel.clone(), inst_families.clone()));
894            }
895        }
896
897        let mut subs_inst_id = Vec::new();
898        for entry in self.subscriptions_inst_id.iter() {
899            let (channel, inst_ids) = entry.pair();
900            if !inst_ids.is_empty() {
901                subs_inst_id.push((channel.clone(), inst_ids.clone()));
902            }
903        }
904
905        // Process instrument type subscriptions
906        for (channel, inst_types) in subs_inst_type {
907            if inst_types.is_empty() {
908                continue;
909            }
910
911            tracing::debug!("Resubscribing: channel={channel}, instrument_types={inst_types:?}");
912
913            for inst_type in inst_types {
914                let arg = OKXSubscriptionArg {
915                    channel: channel.clone(),
916                    inst_type: Some(inst_type),
917                    inst_family: None,
918                    inst_id: None,
919                };
920
921                if let Err(e) = self.subscribe(vec![arg]).await {
922                    tracing::error!(
923                        "Failed to resubscribe to channel {channel} with instrument type: {e}"
924                    );
925                }
926            }
927        }
928
929        // Process instrument family subscriptions
930        for (channel, inst_families) in subs_inst_family {
931            if inst_families.is_empty() {
932                continue;
933            }
934
935            tracing::debug!(
936                "Resubscribing: channel={channel}, instrument_families={inst_families:?}"
937            );
938
939            for inst_family in inst_families {
940                let arg = OKXSubscriptionArg {
941                    channel: channel.clone(),
942                    inst_type: None,
943                    inst_family: Some(inst_family),
944                    inst_id: None,
945                };
946
947                if let Err(e) = self.subscribe(vec![arg]).await {
948                    tracing::error!(
949                        "Failed to resubscribe to channel {channel} with instrument family: {e}"
950                    );
951                }
952            }
953        }
954
955        // Process instrument ID subscriptions
956        for (channel, inst_ids) in subs_inst_id {
957            if inst_ids.is_empty() {
958                continue;
959            }
960
961            tracing::debug!("Resubscribing: channel={channel}, instrument_ids={inst_ids:?}");
962
963            for inst_id in inst_ids {
964                let arg = OKXSubscriptionArg {
965                    channel: channel.clone(),
966                    inst_type: None,
967                    inst_family: None,
968                    inst_id: Some(inst_id),
969                };
970
971                if let Err(e) = self.subscribe(vec![arg]).await {
972                    tracing::error!(
973                        "Failed to resubscribe to channel {channel} with instrument ID: {e}"
974                    );
975                }
976            }
977        }
978
979        // Process bare channel subscriptions (e.g., Account)
980        for channel in subs_bare {
981            tracing::debug!("Resubscribing to bare channel: {channel}");
982
983            let arg = OKXSubscriptionArg {
984                channel,
985                inst_type: None,
986                inst_family: None,
987                inst_id: None,
988            };
989
990            if let Err(e) = self.subscribe(vec![arg]).await {
991                tracing::error!("Failed to resubscribe to bare channel: {e}");
992            }
993        }
994    }
995
996    /// Subscribes to instrument updates for a specific instrument type.
997    ///
998    /// Provides updates when instrument specifications change.
999    ///
1000    /// # References
1001    ///
1002    /// <https://www.okx.com/docs-v5/en/#public-data-websocket-instruments-channel>.
1003    pub async fn subscribe_instruments(
1004        &self,
1005        instrument_type: OKXInstrumentType,
1006    ) -> Result<(), OKXWsError> {
1007        let arg = OKXSubscriptionArg {
1008            channel: OKXWsChannel::Instruments,
1009            inst_type: Some(instrument_type),
1010            inst_family: None,
1011            inst_id: None,
1012        };
1013        self.subscribe(vec![arg]).await
1014    }
1015
1016    /// Subscribes to instrument updates for a specific instrument.
1017    ///
1018    /// Provides updates when instrument specifications change.
1019    ///
1020    /// # References
1021    ///
1022    /// <https://www.okx.com/docs-v5/en/#public-data-websocket-instruments-channel>.
1023    pub async fn subscribe_instrument(
1024        &self,
1025        instrument_id: InstrumentId,
1026    ) -> Result<(), OKXWsError> {
1027        let arg = OKXSubscriptionArg {
1028            channel: OKXWsChannel::Instruments,
1029            inst_type: None,
1030            inst_family: None,
1031            inst_id: Some(instrument_id.symbol.inner()),
1032        };
1033        self.subscribe(vec![arg]).await
1034    }
1035
1036    /// Subscribes to full order book data (400 depth levels) for an instrument.
1037    ///
1038    /// # References
1039    ///
1040    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-order-book-channel>.
1041    pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1042        let arg = OKXSubscriptionArg {
1043            channel: OKXWsChannel::Books,
1044            inst_type: None,
1045            inst_family: None,
1046            inst_id: Some(instrument_id.symbol.inner()),
1047        };
1048        self.subscribe(vec![arg]).await
1049    }
1050
1051    /// Subscribes to 5-level order book snapshot data for an instrument.
1052    ///
1053    /// Updates every 100ms when there are changes.
1054    ///
1055    /// # References
1056    ///
1057    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-order-book-5-depth-channel>.
1058    pub async fn subscribe_book_depth5(
1059        &self,
1060        instrument_id: InstrumentId,
1061    ) -> Result<(), OKXWsError> {
1062        let arg = OKXSubscriptionArg {
1063            channel: OKXWsChannel::Books5,
1064            inst_type: None,
1065            inst_family: None,
1066            inst_id: Some(instrument_id.symbol.inner()),
1067        };
1068        self.subscribe(vec![arg]).await
1069    }
1070
1071    /// Subscribes to 50-level tick-by-tick order book data for an instrument.
1072    ///
1073    /// Provides real-time updates whenever order book changes.
1074    ///
1075    /// # References
1076    ///
1077    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-order-book-50-depth-tbt-channel>.
1078    pub async fn subscribe_books50_l2_tbt(
1079        &self,
1080        instrument_id: InstrumentId,
1081    ) -> Result<(), OKXWsError> {
1082        let arg = OKXSubscriptionArg {
1083            channel: OKXWsChannel::Books50Tbt,
1084            inst_type: None,
1085            inst_family: None,
1086            inst_id: Some(instrument_id.symbol.inner()),
1087        };
1088        self.subscribe(vec![arg]).await
1089    }
1090
1091    /// Subscribes to tick-by-tick full depth (400 levels) order book data for an instrument.
1092    ///
1093    /// Provides real-time updates with all depth levels whenever order book changes.
1094    ///
1095    /// # References
1096    ///
1097    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-order-book-400-depth-tbt-channel>.
1098    pub async fn subscribe_book_l2_tbt(
1099        &self,
1100        instrument_id: InstrumentId,
1101    ) -> Result<(), OKXWsError> {
1102        let arg = OKXSubscriptionArg {
1103            channel: OKXWsChannel::BooksTbt,
1104            inst_type: None,
1105            inst_family: None,
1106            inst_id: Some(instrument_id.symbol.inner()),
1107        };
1108        self.subscribe(vec![arg]).await
1109    }
1110
1111    /// Subscribes to best bid/ask quote data for an instrument.
1112    ///
1113    /// Provides tick-by-tick updates of the best bid and ask prices.
1114    ///
1115    /// # References
1116    ///
1117    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-bbo-tbt-channel>.
1118    pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1119        // let (_, inst_type) = extract_okx_symbol_and_inst_type(&instrument_id);
1120        let arg = OKXSubscriptionArg {
1121            channel: OKXWsChannel::BboTbt,
1122            inst_type: None,
1123            inst_family: None,
1124            inst_id: Some(instrument_id.symbol.inner()),
1125        };
1126        self.subscribe(vec![arg]).await
1127    }
1128
1129    /// Subscribes to trade data for an instrument.
1130    ///
1131    /// # References
1132    ///
1133    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-trades-channel>.
1134    pub async fn subscribe_trades(
1135        &self,
1136        instrument_id: InstrumentId,
1137        _aggregated: bool, // TODO: TBD?
1138    ) -> Result<(), OKXWsError> {
1139        // TODO: aggregated parameter is ignored, always uses 'trades' channel.
1140        // let (symbol, _) = extract_okx_symbol_and_inst_type(&instrument_id);
1141
1142        // Use trades channel for all instruments (trades-all not available?)
1143        let channel = OKXWsChannel::Trades;
1144
1145        let arg = OKXSubscriptionArg {
1146            channel,
1147            inst_type: None,
1148            inst_family: None,
1149            inst_id: Some(instrument_id.symbol.inner()),
1150        };
1151        self.subscribe(vec![arg]).await
1152    }
1153
1154    /// Subscribes to 24hr rolling ticker data for an instrument.
1155    ///
1156    /// Updates every 100ms with trading statistics.
1157    ///
1158    /// # References
1159    ///
1160    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-tickers-channel>.
1161    pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1162        let arg = OKXSubscriptionArg {
1163            channel: OKXWsChannel::Tickers,
1164            inst_type: None,
1165            inst_family: None,
1166            inst_id: Some(instrument_id.symbol.inner()),
1167        };
1168        self.subscribe(vec![arg]).await
1169    }
1170
1171    /// Subscribes to mark price data for derivatives instruments.
1172    ///
1173    /// Updates every 200ms for perpetual swaps, or at settlement for futures.
1174    ///
1175    /// # References
1176    ///
1177    /// <https://www.okx.com/docs-v5/en/#public-data-websocket-mark-price-channel>.
1178    pub async fn subscribe_mark_prices(
1179        &self,
1180        instrument_id: InstrumentId,
1181    ) -> Result<(), OKXWsError> {
1182        let arg = OKXSubscriptionArg {
1183            channel: OKXWsChannel::MarkPrice,
1184            inst_type: None,
1185            inst_family: None,
1186            inst_id: Some(instrument_id.symbol.inner()),
1187        };
1188        self.subscribe(vec![arg]).await
1189    }
1190
1191    /// Subscribes to index price data for an instrument.
1192    ///
1193    /// Updates every second with the underlying index price.
1194    ///
1195    /// # References
1196    ///
1197    /// <https://www.okx.com/docs-v5/en/#public-data-websocket-index-tickers-channel>.
1198    pub async fn subscribe_index_prices(
1199        &self,
1200        instrument_id: InstrumentId,
1201    ) -> Result<(), OKXWsError> {
1202        let arg = OKXSubscriptionArg {
1203            channel: OKXWsChannel::IndexTickers,
1204            inst_type: None,
1205            inst_family: None,
1206            inst_id: Some(instrument_id.symbol.inner()),
1207        };
1208        self.subscribe(vec![arg]).await
1209    }
1210
1211    /// Subscribes to funding rate data for perpetual swap instruments.
1212    ///
1213    /// Updates when funding rate changes or at funding intervals.
1214    ///
1215    /// # References
1216    ///
1217    /// <https://www.okx.com/docs-v5/en/#public-data-websocket-funding-rate-channel>.
1218    pub async fn subscribe_funding_rates(
1219        &self,
1220        instrument_id: InstrumentId,
1221    ) -> Result<(), OKXWsError> {
1222        let arg = OKXSubscriptionArg {
1223            channel: OKXWsChannel::FundingRate,
1224            inst_type: None,
1225            inst_family: None,
1226            inst_id: Some(instrument_id.symbol.inner()),
1227        };
1228        self.subscribe(vec![arg]).await
1229    }
1230
1231    /// Subscribes to candlestick/bar data for an instrument.
1232    ///
1233    /// Supports various time intervals from 1s to 3M.
1234    ///
1235    /// # References
1236    ///
1237    /// <https://www.okx.com/docs-v5/en/#order-book-trading-market-data-ws-candlesticks-channel>.
1238    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1239        // Use regular trade-price candlesticks which work for all instrument types
1240        let channel = bar_spec_as_okx_channel(bar_type.spec())
1241            .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1242
1243        let arg = OKXSubscriptionArg {
1244            channel,
1245            inst_type: None,
1246            inst_family: None,
1247            inst_id: Some(bar_type.instrument_id().symbol.inner()),
1248        };
1249        self.subscribe(vec![arg]).await
1250    }
1251
1252    /// Unsubscribes from instrument updates for a specific instrument type.
1253    pub async fn unsubscribe_instruments(
1254        &self,
1255        instrument_type: OKXInstrumentType,
1256    ) -> Result<(), OKXWsError> {
1257        let arg = OKXSubscriptionArg {
1258            channel: OKXWsChannel::Instruments,
1259            inst_type: Some(instrument_type),
1260            inst_family: None,
1261            inst_id: None,
1262        };
1263        self.unsubscribe(vec![arg]).await
1264    }
1265
1266    /// Unsubscribe from instrument updates for a specific instrument.
1267    pub async fn unsubscribe_instrument(
1268        &self,
1269        instrument_id: InstrumentId,
1270    ) -> Result<(), OKXWsError> {
1271        let arg = OKXSubscriptionArg {
1272            channel: OKXWsChannel::Instruments,
1273            inst_type: None,
1274            inst_family: None,
1275            inst_id: Some(instrument_id.symbol.inner()),
1276        };
1277        self.unsubscribe(vec![arg]).await
1278    }
1279
1280    /// Unsubscribe from full order book data for an instrument.
1281    pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1282        let arg = OKXSubscriptionArg {
1283            channel: OKXWsChannel::Books,
1284            inst_type: None,
1285            inst_family: None,
1286            inst_id: Some(instrument_id.symbol.inner()),
1287        };
1288        self.unsubscribe(vec![arg]).await
1289    }
1290
1291    /// Unsubscribe from 5-level order book snapshot data for an instrument.
1292    pub async fn unsubscribe_book_depth5(
1293        &self,
1294        instrument_id: InstrumentId,
1295    ) -> Result<(), OKXWsError> {
1296        let arg = OKXSubscriptionArg {
1297            channel: OKXWsChannel::Books5,
1298            inst_type: None,
1299            inst_family: None,
1300            inst_id: Some(instrument_id.symbol.inner()),
1301        };
1302        self.unsubscribe(vec![arg]).await
1303    }
1304
1305    /// Unsubscribe from 50-level tick-by-tick order book data for an instrument.
1306    pub async fn unsubscribe_book50_l2_tbt(
1307        &self,
1308        instrument_id: InstrumentId,
1309    ) -> Result<(), OKXWsError> {
1310        let arg = OKXSubscriptionArg {
1311            channel: OKXWsChannel::Books50Tbt,
1312            inst_type: None,
1313            inst_family: None,
1314            inst_id: Some(instrument_id.symbol.inner()),
1315        };
1316        self.unsubscribe(vec![arg]).await
1317    }
1318
1319    /// Unsubscribe from tick-by-tick full depth order book data for an instrument.
1320    pub async fn unsubscribe_book_l2_tbt(
1321        &self,
1322        instrument_id: InstrumentId,
1323    ) -> Result<(), OKXWsError> {
1324        let arg = OKXSubscriptionArg {
1325            channel: OKXWsChannel::BooksTbt,
1326            inst_type: None,
1327            inst_family: None,
1328            inst_id: Some(instrument_id.symbol.inner()),
1329        };
1330        self.unsubscribe(vec![arg]).await
1331    }
1332
1333    /// Unsubscribe from best bid/ask quote data for an instrument.
1334    pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1335        let arg = OKXSubscriptionArg {
1336            channel: OKXWsChannel::BboTbt,
1337            inst_type: None,
1338            inst_family: None,
1339            inst_id: Some(instrument_id.symbol.inner()),
1340        };
1341        self.unsubscribe(vec![arg]).await
1342    }
1343
1344    /// Unsubscribe from 24hr rolling ticker data for an instrument.
1345    pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1346        let arg = OKXSubscriptionArg {
1347            channel: OKXWsChannel::Tickers,
1348            inst_type: None,
1349            inst_family: None,
1350            inst_id: Some(instrument_id.symbol.inner()),
1351        };
1352        self.unsubscribe(vec![arg]).await
1353    }
1354
1355    /// Unsubscribe from mark price data for a derivatives instrument.
1356    pub async fn unsubscribe_mark_prices(
1357        &self,
1358        instrument_id: InstrumentId,
1359    ) -> Result<(), OKXWsError> {
1360        let arg = OKXSubscriptionArg {
1361            channel: OKXWsChannel::MarkPrice,
1362            inst_type: None,
1363            inst_family: None,
1364            inst_id: Some(instrument_id.symbol.inner()),
1365        };
1366        self.unsubscribe(vec![arg]).await
1367    }
1368
1369    /// Unsubscribe from index price data for an instrument.
1370    pub async fn unsubscribe_index_prices(
1371        &self,
1372        instrument_id: InstrumentId,
1373    ) -> Result<(), OKXWsError> {
1374        let arg = OKXSubscriptionArg {
1375            channel: OKXWsChannel::IndexTickers,
1376            inst_type: None,
1377            inst_family: None,
1378            inst_id: Some(instrument_id.symbol.inner()),
1379        };
1380        self.unsubscribe(vec![arg]).await
1381    }
1382
1383    /// Unsubscribe from funding rate data for a perpetual swap instrument.
1384    pub async fn unsubscribe_funding_rates(
1385        &self,
1386        instrument_id: InstrumentId,
1387    ) -> Result<(), OKXWsError> {
1388        let arg = OKXSubscriptionArg {
1389            channel: OKXWsChannel::FundingRate,
1390            inst_type: None,
1391            inst_family: None,
1392            inst_id: Some(instrument_id.symbol.inner()),
1393        };
1394        self.unsubscribe(vec![arg]).await
1395    }
1396
1397    /// Unsubscribe from trade data for an instrument.
1398    pub async fn unsubscribe_trades(
1399        &self,
1400        instrument_id: InstrumentId,
1401        _aggregated: bool,
1402    ) -> Result<(), OKXWsError> {
1403        // Use trades channel for all instruments (trades-all not available?)
1404        let channel = OKXWsChannel::Trades;
1405
1406        let arg = OKXSubscriptionArg {
1407            channel,
1408            inst_type: None,
1409            inst_family: None,
1410            inst_id: Some(instrument_id.symbol.inner()),
1411        };
1412        self.unsubscribe(vec![arg]).await
1413    }
1414
1415    /// Unsubscribe from candlestick/bar data for an instrument.
1416    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1417        // Use regular trade-price candlesticks which work for all instrument types
1418        let channel = bar_spec_as_okx_channel(bar_type.spec())
1419            .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1420
1421        let arg = OKXSubscriptionArg {
1422            channel,
1423            inst_type: None,
1424            inst_family: None,
1425            inst_id: Some(bar_type.instrument_id().symbol.inner()),
1426        };
1427        self.unsubscribe(vec![arg]).await
1428    }
1429
1430    /// Subscribes to order updates for the given instrument type.
1431    pub async fn subscribe_orders(
1432        &self,
1433        instrument_type: OKXInstrumentType,
1434    ) -> Result<(), OKXWsError> {
1435        let arg = OKXSubscriptionArg {
1436            channel: OKXWsChannel::Orders,
1437            inst_type: Some(instrument_type),
1438            inst_family: None,
1439            inst_id: None,
1440        };
1441        self.subscribe(vec![arg]).await
1442    }
1443
1444    /// Unsubscribes from order updates for the given instrument type.
1445    pub async fn unsubscribe_orders(
1446        &self,
1447        instrument_type: OKXInstrumentType,
1448    ) -> Result<(), OKXWsError> {
1449        let arg = OKXSubscriptionArg {
1450            channel: OKXWsChannel::Orders,
1451            inst_type: Some(instrument_type),
1452            inst_family: None,
1453            inst_id: None,
1454        };
1455        self.unsubscribe(vec![arg]).await
1456    }
1457
1458    /// Subscribes to fill updates for the given instrument type.
1459    pub async fn subscribe_fills(
1460        &self,
1461        instrument_type: OKXInstrumentType,
1462    ) -> Result<(), OKXWsError> {
1463        let arg = OKXSubscriptionArg {
1464            channel: OKXWsChannel::Fills,
1465            inst_type: Some(instrument_type),
1466            inst_family: None,
1467            inst_id: None,
1468        };
1469        self.subscribe(vec![arg]).await
1470    }
1471
1472    /// Unsubscribes from fill updates for the given instrument type.
1473    pub async fn unsubscribe_fills(
1474        &self,
1475        instrument_type: OKXInstrumentType,
1476    ) -> Result<(), OKXWsError> {
1477        let arg = OKXSubscriptionArg {
1478            channel: OKXWsChannel::Fills,
1479            inst_type: Some(instrument_type),
1480            inst_family: None,
1481            inst_id: None,
1482        };
1483        self.unsubscribe(vec![arg]).await
1484    }
1485
1486    /// Subscribes to account balance updates.
1487    pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1488        let arg = OKXSubscriptionArg {
1489            channel: OKXWsChannel::Account,
1490            inst_type: None,
1491            inst_family: None,
1492            inst_id: None,
1493        };
1494        self.subscribe(vec![arg]).await
1495    }
1496
1497    /// Unsubscribes from account balance updates.
1498    pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1499        let arg = OKXSubscriptionArg {
1500            channel: OKXWsChannel::Account,
1501            inst_type: None,
1502            inst_family: None,
1503            inst_id: None,
1504        };
1505        self.unsubscribe(vec![arg]).await
1506    }
1507
1508    /// Cancel an existing order via WebSocket.
1509    ///
1510    /// # References
1511    ///
1512    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-cancel-order>
1513    async fn ws_cancel_order(
1514        &self,
1515        params: WsCancelOrderParams,
1516        request_id: Option<String>,
1517    ) -> Result<(), OKXWsError> {
1518        let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1519
1520        let req = OKXWsRequest {
1521            id: Some(request_id),
1522            op: OKXWsOperation::CancelOrder,
1523            args: vec![params],
1524            exp_time: None,
1525        };
1526
1527        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1528
1529        {
1530            let inner_guard = self.inner.read().await;
1531            if let Some(inner) = &*inner_guard {
1532                if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1533                    tracing::error!("Error sending message: {e:?}");
1534                }
1535                Ok(())
1536            } else {
1537                Err(OKXWsError::ClientError("Not connected".to_string()))
1538            }
1539        }
1540    }
1541
1542    #[allow(dead_code)] // TODO: Implement for MM pending orders
1543    /// Cancel multiple orders at once via WebSocket.
1544    ///
1545    /// # References
1546    ///
1547    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-mass-cancel-order>
1548    async fn ws_mass_cancel(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1549        // Generate unique request ID for WebSocket message
1550        let request_id = self
1551            .request_id_counter
1552            .fetch_add(1, Ordering::SeqCst)
1553            .to_string();
1554
1555        let req = OKXWsRequest {
1556            id: Some(request_id),
1557            op: OKXWsOperation::MassCancel,
1558            args,
1559            exp_time: None,
1560        };
1561
1562        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1563
1564        {
1565            let inner_guard = self.inner.read().await;
1566            if let Some(inner) = &*inner_guard {
1567                if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1568                    tracing::error!("Error sending message: {e:?}");
1569                }
1570                Ok(())
1571            } else {
1572                Err(OKXWsError::ClientError("Not connected".to_string()))
1573            }
1574        }
1575    }
1576
1577    /// Amend an existing order via WebSocket.
1578    ///
1579    /// # References
1580    ///
1581    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-amend-order>
1582    async fn ws_amend_order(
1583        &self,
1584        params: WsAmendOrderParams,
1585        request_id: Option<String>,
1586    ) -> Result<(), OKXWsError> {
1587        let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1588
1589        let req = OKXWsRequest {
1590            id: Some(request_id),
1591            op: OKXWsOperation::AmendOrder,
1592            args: vec![params],
1593            exp_time: None,
1594        };
1595
1596        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1597
1598        {
1599            let inner_guard = self.inner.read().await;
1600            if let Some(inner) = &*inner_guard {
1601                if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1602                    tracing::error!("Error sending message: {e:?}");
1603                }
1604                Ok(())
1605            } else {
1606                Err(OKXWsError::ClientError("Not connected".to_string()))
1607            }
1608        }
1609    }
1610
1611    /// Place multiple orders in a single batch via WebSocket.
1612    ///
1613    /// # References
1614    ///
1615    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-batch-orders>
1616    async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1617        let request_id = self.generate_unique_request_id();
1618
1619        let req = OKXWsRequest {
1620            id: Some(request_id),
1621            op: OKXWsOperation::BatchOrders,
1622            args,
1623            exp_time: None,
1624        };
1625
1626        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1627
1628        {
1629            let inner_guard = self.inner.read().await;
1630            if let Some(inner) = &*inner_guard {
1631                if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
1632                    tracing::error!("Error sending message: {e:?}");
1633                }
1634                Ok(())
1635            } else {
1636                Err(OKXWsError::ClientError("Not connected".to_string()))
1637            }
1638        }
1639    }
1640
1641    /// Cancel multiple orders in a single batch via WebSocket.
1642    ///
1643    /// # References
1644    ///
1645    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-batch-cancel-orders>
1646    async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1647        let request_id = self.generate_unique_request_id();
1648
1649        let req = OKXWsRequest {
1650            id: Some(request_id),
1651            op: OKXWsOperation::BatchCancelOrders,
1652            args,
1653            exp_time: None,
1654        };
1655
1656        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1657
1658        {
1659            let inner_guard = self.inner.read().await;
1660            if let Some(inner) = &*inner_guard {
1661                if let Err(e) = inner.send_text(txt, Some(vec!["cancel".to_string()])).await {
1662                    tracing::error!("Error sending message: {e:?}");
1663                }
1664                Ok(())
1665            } else {
1666                Err(OKXWsError::ClientError("Not connected".to_string()))
1667            }
1668        }
1669    }
1670
1671    /// Amend multiple orders in a single batch via WebSocket.
1672    ///
1673    /// # References
1674    ///
1675    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-batch-amend-orders>
1676    async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
1677        let request_id = self.generate_unique_request_id();
1678
1679        let req = OKXWsRequest {
1680            id: Some(request_id),
1681            op: OKXWsOperation::BatchAmendOrders,
1682            args,
1683            exp_time: None,
1684        };
1685
1686        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1687
1688        {
1689            let inner_guard = self.inner.read().await;
1690            if let Some(inner) = &*inner_guard {
1691                if let Err(e) = inner.send_text(txt, Some(vec!["amend".to_string()])).await {
1692                    tracing::error!("Error sending message: {e:?}");
1693                }
1694                Ok(())
1695            } else {
1696                Err(OKXWsError::ClientError("Not connected".to_string()))
1697            }
1698        }
1699    }
1700
1701    /// Submits a new order.
1702    ///
1703    /// # References
1704    ///
1705    /// <https://www.okx.com/docs-v5/en/#order-book-trading-trade-ws-place-order>.
1706    #[allow(clippy::too_many_arguments)]
1707    pub async fn submit_order(
1708        &self,
1709        trader_id: TraderId,
1710        strategy_id: StrategyId,
1711        instrument_id: InstrumentId,
1712        td_mode: OKXTradeMode,
1713        client_order_id: ClientOrderId,
1714        order_side: OrderSide,
1715        order_type: OrderType,
1716        quantity: Quantity,
1717        time_in_force: Option<TimeInForce>,
1718        price: Option<Price>,
1719        trigger_price: Option<Price>,
1720        post_only: Option<bool>,
1721        reduce_only: Option<bool>,
1722        quote_quantity: Option<bool>,
1723        position_side: Option<PositionSide>,
1724    ) -> Result<(), OKXWsError> {
1725        if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
1726            return Err(OKXWsError::ClientError(format!(
1727                "Unsupported order type: {order_type:?}",
1728            )));
1729        }
1730
1731        if let Some(tif) = time_in_force
1732            && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
1733        {
1734            return Err(OKXWsError::ClientError(format!(
1735                "Unsupported time in force: {tif:?}",
1736            )));
1737        }
1738
1739        let mut builder = WsPostOrderParamsBuilder::default();
1740
1741        builder.inst_id(instrument_id.symbol.as_str());
1742        builder.td_mode(td_mode);
1743        builder.cl_ord_id(client_order_id.as_str());
1744
1745        let instrument = self
1746            .instruments_cache
1747            .get(&instrument_id.symbol.inner())
1748            .ok_or_else(|| {
1749                OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
1750            })?;
1751
1752        let instrument_type =
1753            okx_instrument_type(instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1754        let quote_currency = instrument.quote_currency();
1755
1756        match instrument_type {
1757            OKXInstrumentType::Spot => {
1758                // Defaults
1759            }
1760            OKXInstrumentType::Margin => {
1761                // MARGIN: use quote currency for margin
1762                builder.ccy(quote_currency.to_string());
1763
1764                // TODO: Consider position mode (only applicable for NET)
1765                if let Some(ro) = reduce_only
1766                    && ro
1767                {
1768                    builder.reduce_only(ro);
1769                }
1770            }
1771            OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
1772                // SWAP/FUTURES: use quote currency for margin (required by OKX)
1773                builder.ccy(quote_currency.to_string());
1774            }
1775            _ => {
1776                // For other instrument types (OPTIONS, etc.), use quote currency as fallback
1777                builder.ccy(quote_currency.to_string());
1778                builder.tgt_ccy(quote_currency.to_string());
1779
1780                // TODO: Consider position mode (only applicable for NET)
1781                if let Some(ro) = reduce_only
1782                    && ro
1783                {
1784                    builder.reduce_only(ro);
1785                }
1786            }
1787        };
1788
1789        if let Some(is_quote_quantity) = quote_quantity
1790            && is_quote_quantity
1791        {
1792            builder.tgt_ccy(quote_currency.to_string());
1793        }
1794        // If is_quote_quantity is false, we don't set tgtCcy (defaults to base currency)
1795
1796        builder.side(OKXSide::from(order_side));
1797
1798        if let Some(pos_side) = position_side {
1799            builder.pos_side(pos_side);
1800        };
1801
1802        // Determine OKX order type based on order type and post_only
1803        let okx_ord_type = if post_only.unwrap_or(false) {
1804            OKXOrderType::PostOnly
1805        } else {
1806            OKXOrderType::from(order_type)
1807        };
1808
1809        log::debug!(
1810            "Order type mapping: order_type={:?}, time_in_force={:?}, post_only={:?} -> okx_ord_type={:?}",
1811            order_type,
1812            time_in_force,
1813            post_only,
1814            okx_ord_type
1815        );
1816
1817        builder.ord_type(okx_ord_type);
1818        builder.sz(quantity.to_string());
1819
1820        if let Some(tp) = trigger_price {
1821            builder.px(tp.to_string());
1822        } else if let Some(p) = price {
1823            builder.px(p.to_string());
1824        }
1825
1826        builder.tag(OKX_NAUTILUS_BROKER_ID);
1827
1828        let params = builder
1829            .build()
1830            .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
1831
1832        // TODO: Log the full order parameters being sent (for development)
1833        log::debug!("Sending order params to OKX: {:?}", params);
1834
1835        let request_id = self.generate_unique_request_id();
1836
1837        self.pending_place_requests.insert(
1838            request_id.clone(),
1839            (client_order_id, trader_id, strategy_id, instrument_id),
1840        );
1841
1842        self.ws_place_order(params, Some(request_id)).await
1843    }
1844
1845    /// Cancels an existing order.
1846    ///
1847    /// # References
1848    ///
1849    /// <https://www.okx.com/docs-v5/en/#order-book-trading-trade-ws-cancel-order>.
1850    #[allow(clippy::too_many_arguments)]
1851    pub async fn cancel_order(
1852        &self,
1853        trader_id: TraderId,
1854        strategy_id: StrategyId,
1855        instrument_id: InstrumentId,
1856        client_order_id: Option<ClientOrderId>,
1857        venue_order_id: Option<VenueOrderId>,
1858    ) -> Result<(), OKXWsError> {
1859        let mut builder = WsCancelOrderParamsBuilder::default();
1860        // Note: instType should NOT be included in cancel order requests
1861        // For WebSocket orders, use the full symbol (including SWAP/FUTURES suffix if present)
1862        builder.inst_id(instrument_id.symbol.as_str());
1863
1864        if let Some(venue_order_id) = venue_order_id {
1865            builder.ord_id(venue_order_id.as_str());
1866        }
1867
1868        let params = builder
1869            .build()
1870            .map_err(|e| OKXWsError::ClientError(format!("Build cancel params error: {e}")))?;
1871
1872        let request_id = self.generate_unique_request_id();
1873
1874        // External orders may not have a client order ID,
1875        // for now we just track those with a client order ID as pending requests.
1876        if let Some(client_order_id) = client_order_id {
1877            builder.cl_ord_id(client_order_id.as_str());
1878
1879            self.pending_cancel_requests.insert(
1880                request_id.clone(),
1881                (
1882                    client_order_id,
1883                    trader_id,
1884                    strategy_id,
1885                    instrument_id,
1886                    venue_order_id,
1887                ),
1888            );
1889        }
1890
1891        self.ws_cancel_order(params, Some(request_id)).await
1892    }
1893
1894    /// Place a new order via WebSocket.
1895    ///
1896    /// # References
1897    ///
1898    /// <https://www.okx.com/docs-v5/en/#order-book-trading-websocket-place-order>
1899    async fn ws_place_order(
1900        &self,
1901        params: WsPostOrderParams,
1902        request_id: Option<String>,
1903    ) -> Result<(), OKXWsError> {
1904        let request_id = request_id.unwrap_or(self.generate_unique_request_id());
1905
1906        let req = OKXWsRequest {
1907            id: Some(request_id),
1908            op: OKXWsOperation::Order,
1909            exp_time: None,
1910            args: vec![params],
1911        };
1912
1913        let txt = serde_json::to_string(&req).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
1914
1915        {
1916            let inner_guard = self.inner.read().await;
1917            if let Some(inner) = &*inner_guard {
1918                if let Err(e) = inner.send_text(txt, Some(vec!["order".to_string()])).await {
1919                    tracing::error!("Error sending message: {e:?}");
1920                }
1921                Ok(())
1922            } else {
1923                Err(OKXWsError::ClientError("Not connected".to_string()))
1924            }
1925        }
1926    }
1927
1928    /// Modifies an existing order.
1929    ///
1930    /// # References
1931    ///
1932    /// <https://www.okx.com/docs-v5/en/#order-book-trading-trade-ws-amend-order>.
1933    #[allow(clippy::too_many_arguments)]
1934    pub async fn modify_order(
1935        &self,
1936        trader_id: TraderId,
1937        strategy_id: StrategyId,
1938        instrument_id: InstrumentId,
1939        client_order_id: Option<ClientOrderId>,
1940        price: Option<Price>,
1941        quantity: Option<Quantity>,
1942        venue_order_id: Option<VenueOrderId>,
1943    ) -> Result<(), OKXWsError> {
1944        let mut builder = WsAmendOrderParamsBuilder::default();
1945
1946        builder.inst_id(instrument_id.symbol.as_str());
1947
1948        if let Some(venue_order_id) = venue_order_id {
1949            builder.ord_id(venue_order_id.as_str());
1950        }
1951
1952        if let Some(client_order_id) = client_order_id {
1953            builder.cl_ord_id(client_order_id.as_str());
1954        }
1955
1956        if let Some(price) = price {
1957            builder.new_px(price.to_string());
1958        }
1959
1960        if let Some(quantity) = quantity {
1961            builder.new_sz(quantity.to_string());
1962        }
1963
1964        let params = builder
1965            .build()
1966            .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
1967
1968        // Generate unique request ID for WebSocket message
1969        let request_id = self
1970            .request_id_counter
1971            .fetch_add(1, Ordering::SeqCst)
1972            .to_string();
1973
1974        // External orders may not have a client order ID,
1975        // for now we just track those with a client order ID as pending requests.
1976        if let Some(client_order_id) = client_order_id {
1977            self.pending_amend_requests.insert(
1978                request_id.clone(),
1979                (
1980                    client_order_id,
1981                    trader_id,
1982                    strategy_id,
1983                    instrument_id,
1984                    venue_order_id,
1985                ),
1986            );
1987        }
1988
1989        self.ws_amend_order(params, Some(request_id)).await
1990    }
1991
1992    /// Submits multiple orders.
1993    #[allow(clippy::type_complexity)]
1994    #[allow(clippy::too_many_arguments)]
1995    pub async fn batch_submit_orders(
1996        &self,
1997        orders: Vec<(
1998            OKXInstrumentType,
1999            InstrumentId,
2000            OKXTradeMode,
2001            ClientOrderId,
2002            OrderSide,
2003            Option<PositionSide>,
2004            OrderType,
2005            Quantity,
2006            Option<Price>,
2007            Option<Price>,
2008            Option<bool>,
2009            Option<bool>,
2010        )>,
2011    ) -> Result<(), OKXWsError> {
2012        let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2013        for (
2014            inst_type,
2015            inst_id,
2016            td_mode,
2017            cl_ord_id,
2018            ord_side,
2019            pos_side,
2020            ord_type,
2021            qty,
2022            pr,
2023            tp,
2024            post_only,
2025            reduce_only,
2026        ) in orders
2027        {
2028            let mut builder = WsPostOrderParamsBuilder::default();
2029            builder.inst_type(inst_type);
2030            builder.inst_id(inst_id.symbol.inner());
2031            builder.td_mode(td_mode);
2032            builder.cl_ord_id(cl_ord_id.as_str());
2033            builder.side(OKXSide::from(ord_side));
2034
2035            if let Some(ps) = pos_side {
2036                builder.pos_side(OKXPositionSide::from(ps));
2037            }
2038
2039            let okx_ord_type = if post_only.unwrap_or(false) {
2040                OKXOrderType::PostOnly
2041            } else {
2042                OKXOrderType::from(ord_type)
2043            };
2044
2045            builder.ord_type(okx_ord_type);
2046            builder.sz(qty.to_string());
2047
2048            if let Some(p) = pr {
2049                builder.px(p.to_string());
2050            } else if let Some(p) = tp {
2051                builder.px(p.to_string());
2052            }
2053
2054            if let Some(ro) = reduce_only {
2055                builder.reduce_only(ro);
2056            }
2057
2058            builder.tag(OKX_NAUTILUS_BROKER_ID);
2059
2060            let params = builder
2061                .build()
2062                .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2063            let val =
2064                serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2065            args.push(val);
2066        }
2067
2068        self.ws_batch_place_orders(args).await
2069    }
2070
2071    /// Cancels multiple orders.
2072    #[allow(clippy::type_complexity)]
2073    pub async fn batch_cancel_orders(
2074        &self,
2075        orders: Vec<(
2076            OKXInstrumentType,
2077            InstrumentId,
2078            Option<ClientOrderId>,
2079            Option<String>,
2080        )>,
2081    ) -> Result<(), OKXWsError> {
2082        let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2083        for (_inst_type, inst_id, cl_ord_id, ord_id) in orders {
2084            let mut builder = WsCancelOrderParamsBuilder::default();
2085            // Note: instType should NOT be included in cancel order requests
2086            builder.inst_id(inst_id.symbol.inner());
2087
2088            if let Some(c) = cl_ord_id {
2089                builder.cl_ord_id(c.as_str());
2090            }
2091
2092            if let Some(o) = ord_id {
2093                builder.ord_id(o);
2094            }
2095
2096            let params = builder.build().map_err(|e| {
2097                OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2098            })?;
2099            let val =
2100                serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2101            args.push(val);
2102        }
2103
2104        self.ws_batch_cancel_orders(args).await
2105    }
2106
2107    /// Modifies multiple orders via WebSocket using Nautilus domain types.
2108    #[allow(clippy::type_complexity)]
2109    #[allow(clippy::too_many_arguments)]
2110    pub async fn batch_modify_orders(
2111        &self,
2112        orders: Vec<(
2113            OKXInstrumentType,
2114            InstrumentId,
2115            ClientOrderId,
2116            ClientOrderId,
2117            Option<Price>,
2118            Option<Quantity>,
2119        )>,
2120    ) -> Result<(), OKXWsError> {
2121        let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2122        for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2123            let mut builder = WsAmendOrderParamsBuilder::default();
2124            // Note: instType should NOT be included in amend order requests
2125            builder.inst_id(inst_id.symbol.inner());
2126            builder.cl_ord_id(cl_ord_id.as_str());
2127            builder.new_cl_ord_id(new_cl_ord_id.as_str());
2128
2129            if let Some(p) = pr {
2130                builder.new_px(p.to_string());
2131            }
2132
2133            if let Some(q) = sz {
2134                builder.new_sz(q.to_string());
2135            }
2136
2137            let params = builder.build().map_err(|e| {
2138                OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2139            })?;
2140            let val =
2141                serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2142            args.push(val);
2143        }
2144
2145        self.ws_batch_amend_orders(args).await
2146    }
2147}
2148
/// Low-level reader that decodes raw WebSocket messages into [`OKXWebSocketEvent`]s.
struct OKXFeedHandler {
    /// Source of raw WebSocket messages (text, binary, close, ...).
    receiver: UnboundedReceiver<Message>,
    /// Stop flag polled by `next`; once it reads `true`, `next` returns `None`.
    signal: Arc<AtomicBool>,
}
2153
impl OKXFeedHandler {
    /// Creates a new [`OKXFeedHandler`] instance.
    ///
    /// `receiver` yields the raw WebSocket messages to decode; `signal` is the stop flag
    /// polled by `next`.
    pub fn new(receiver: UnboundedReceiver<Message>, signal: Arc<AtomicBool>) -> Self {
        Self { receiver, signal }
    }

    /// Gets the next message from the WebSocket stream.
    ///
    /// Loops until something worth surfacing arrives:
    /// - the `RECONNECTED` sentinel text yields `OKXWebSocketEvent::Reconnected`;
    /// - `Error`, `Login`, `Data`, `BookData` and `OrderResponse` events are returned (the
    ///   error, login and order-response outcomes are logged first);
    /// - `Subscription` and `ChannelConnCount` events are logged at debug level and skipped;
    /// - binary frames and other non-text, non-close frames are logged and skipped.
    ///
    /// Returns `None` in any of these cases:
    /// - the channel is closed;
    /// - a close frame arrives;
    /// - the stop `signal` is observed set;
    /// - a text frame fails to deserialize.
    ///
    /// NOTE(review): a single text frame that fails to deserialize ends this call with
    /// `None`, which a caller looping on `Some` will treat as end of stream. Confirm that
    /// `continue` was not intended.
    /// NOTE(review): the stop flag is only checked when no message arrives within 1ms, and
    /// the 1ms sleep is recreated on every loop iteration.
    async fn next(&mut self) -> Option<OKXWebSocketEvent> {
        loop {
            tokio::select! {
                msg = self.receiver.recv() => match msg {
                    Some(msg) => match msg {
                        Message::Text(text) => {
                            // Check for reconnection signal
                            if text == RECONNECTED {
                                tracing::info!("Received WebSocket reconnection signal");
                                return Some(OKXWebSocketEvent::Reconnected);
                            }
                            tracing::trace!("Received WebSocket message: {text}");

                            match serde_json::from_str(&text) {
                                Ok(ws_event) => match &ws_event {
                                    OKXWebSocketEvent::Error { code, msg } => {
                                        tracing::error!("WebSocket error: {code} - {msg}");
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Login {
                                        event,
                                        code,
                                        msg,
                                        conn_id,
                                    } => {
                                        // code "0" denotes a successful login.
                                        if code == "0" {
                                            tracing::info!(
                                                "Successfully authenticated with OKX WebSocket, conn_id={conn_id}"
                                            );
                                        } else {
                                            tracing::error!(
                                                "Authentication failed: {event} {code} - {msg}"
                                            );
                                        }
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Subscription {
                                        event,
                                        arg,
                                        conn_id,
                                    } => {
                                        // Serialize the channel enum to get its wire name,
                                        // then strip the JSON string quotes for logging.
                                        let channel_str = serde_json::to_string(&arg.channel)
                                            .expect("Invalid OKX websocket channel")
                                            .trim_matches('"')
                                            .to_string();
                                        tracing::debug!(
                                            "{event}d: channel={channel_str}, conn_id={conn_id}"
                                        );
                                        // Acknowledgement only; keep reading.
                                        continue;
                                    }
                                    OKXWebSocketEvent::ChannelConnCount {
                                        event: _,
                                        channel,
                                        conn_count,
                                        conn_id,
                                    } => {
                                        let channel_str = serde_json::to_string(&channel)
                                            .expect("Invalid OKX websocket channel")
                                            .trim_matches('"')
                                            .to_string();
                                        tracing::debug!(
                                            "Channel connection status: channel={channel_str}, connections={conn_count}, conn_id={conn_id}",
                                        );
                                        // Informational only; keep reading.
                                        continue;
                                    }
                                    OKXWebSocketEvent::Data { .. } => return Some(ws_event),
                                    OKXWebSocketEvent::BookData { .. } => return Some(ws_event),
                                    OKXWebSocketEvent::OrderResponse {
                                        id,
                                        op,
                                        code,
                                        msg,
                                        data,
                                    } => {
                                        if code == "0" {
                                            tracing::debug!(
                                                "Order operation successful: id={:?}, op={op}, code={code}",
                                                id
                                            );

                                            // Extract success message
                                            if let Some(order_data) = data.first() {
                                                let success_msg = order_data
                                                    .get("sMsg")
                                                    .and_then(|s| s.as_str())
                                                    .unwrap_or("Order operation successful");
                                                tracing::debug!("Order success details: {success_msg}");
                                            }
                                        } else {
                                            // Extract error message: prefer the per-order
                                            // `sMsg`, falling back to the top-level `msg`.
                                            let error_msg = data
                                                .first()
                                                .and_then(|d| d.get("sMsg"))
                                                .and_then(|s| s.as_str())
                                                .unwrap_or(msg.as_str());
                                            tracing::error!(
                                                "Order operation failed: id={id:?}, op={op}, code={code}, error={error_msg}",
                                            );
                                        }
                                        return Some(ws_event);
                                    }
                                    OKXWebSocketEvent::Reconnected => {
                                        // This shouldn't happen as we handle RECONNECTED string directly
                                        tracing::warn!("Unexpected Reconnected event from deserialization");
                                        continue;
                                    }
                                },
                                Err(e) => {
                                    // NOTE(review): ends the stream on a single parse failure
                                    // (see the doc comment on this method).
                                    tracing::error!("Failed to parse message: {e}: {text}");
                                    return None;
                                }
                            }
                        }
                        Message::Binary(msg) => {
                            tracing::debug!("Raw binary: {msg:?}");
                        }
                        Message::Close(_) => {
                            tracing::debug!("Received close message");
                            return None;
                        }
                        msg => {
                            tracing::warn!("Unexpected message: {msg}");
                        }
                    }
                    None => {
                        // Sender side dropped: no more messages will arrive.
                        tracing::info!("WebSocket stream closed");
                        return None;
                    }
                },
                // Periodic wake-up so the stop signal is noticed while no messages arrive.
                _ = tokio::time::sleep(Duration::from_millis(1)) => {
                    if self.signal.load(std::sync::atomic::Ordering::Relaxed) {
                        tracing::debug!("Stop signal received");
                        return None;
                    }
                }
            }
        }
    }
}
2300
/// Message handler that reads OKX events through an [`OKXFeedHandler`] and forwards
/// [`NautilusWsMessage`]s on `tx` (see `run`).
///
/// Holds the state used while processing events: the pending request maps, the instruments
/// cache, and small caches for fees and funding rates.
struct OKXWsMessageHandler {
    /// Account ID for this handler. Presumably attached to account/order reports; verify.
    account_id: AccountId,
    /// Underlying reader of decoded OKX WebSocket events.
    handler: OKXFeedHandler,
    /// Channel on which `run` forwards messages.
    #[allow(dead_code)]
    tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
    /// In-flight place requests. Presumably the client's shared map, keyed by request ID.
    pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
    /// In-flight cancel requests. Presumably the client's shared map, keyed by request ID.
    pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
    /// In-flight amend requests. Presumably the client's shared map, keyed by request ID.
    pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
    /// Instrument definitions. Presumably keyed by symbol like the client's cache; confirm.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
    /// Most recently seen account state, if any (starts as `None`).
    last_account_state: Option<AccountState>,
    fee_cache: AHashMap<Ustr, Money>, // Key is order ID
    funding_rate_cache: AHashMap<Ustr, (Ustr, u64)>, // Cache (funding_rate, funding_time) by inst_id
    /// Watch sender for authentication state. Presumably `true` once logged in; confirm.
    auth_state: Arc<tokio::sync::watch::Sender<bool>>,
}
2315
2316impl OKXWsMessageHandler {
2317    /// Creates a new [`OKXFeedHandler`] instance.
2318    #[allow(clippy::too_many_arguments)]
2319    pub fn new(
2320        account_id: AccountId,
2321        instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
2322        reader: UnboundedReceiver<Message>,
2323        signal: Arc<AtomicBool>,
2324        tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
2325        pending_place_requests: Arc<DashMap<String, PlaceRequestData>>,
2326        pending_cancel_requests: Arc<DashMap<String, CancelRequestData>>,
2327        pending_amend_requests: Arc<DashMap<String, AmendRequestData>>,
2328        auth_state: Arc<tokio::sync::watch::Sender<bool>>,
2329    ) -> Self {
2330        Self {
2331            account_id,
2332            handler: OKXFeedHandler::new(reader, signal),
2333            tx,
2334            pending_place_requests,
2335            pending_cancel_requests,
2336            pending_amend_requests,
2337            instruments_cache,
2338            last_account_state: None,
2339            fee_cache: AHashMap::new(),
2340            funding_rate_cache: AHashMap::new(),
2341            auth_state,
2342        }
2343    }
2344
2345    fn is_stopped(&self) -> bool {
2346        self.handler
2347            .signal
2348            .load(std::sync::atomic::Ordering::Relaxed)
2349    }
2350
2351    #[allow(dead_code)]
2352    async fn run(&mut self) {
2353        while let Some(data) = self.next().await {
2354            if let Err(e) = self.tx.send(data) {
2355                tracing::error!("Error sending data: {e}");
2356                break; // Stop processing on channel error for now
2357            }
2358        }
2359    }
2360
2361    async fn next(&mut self) -> Option<NautilusWsMessage> {
2362        let clock = get_atomic_clock_realtime();
2363
2364        while let Some(event) = self.handler.next().await {
2365            let ts_init = clock.get_time_ns();
2366
2367            if let OKXWebSocketEvent::Login { code, msg, .. } = event {
2368                if code == "0" {
2369                    if self.auth_state.send(true).is_err() {
2370                        tracing::error!(
2371                            "Failed to send authentication success signal: receiver dropped"
2372                        );
2373                    }
2374                } else {
2375                    tracing::error!("Authentication failed: {msg}");
2376                    if self.auth_state.send(false).is_err() {
2377                        tracing::error!(
2378                            "Failed to send authentication failure signal: receiver dropped"
2379                        );
2380                    }
2381                }
2382                continue; // Don't forward login events as Nautilus messages
2383            }
2384
2385            if let OKXWebSocketEvent::BookData { arg, action, data } = event {
2386                let inst = match arg.inst_id {
2387                    Some(inst_id) => match self.instruments_cache.get(&inst_id) {
2388                        Some(inst_ref) => inst_ref.clone(),
2389                        None => continue,
2390                    },
2391                    None => {
2392                        tracing::error!("Instrument ID missing for book data event");
2393                        continue;
2394                    }
2395                };
2396
2397                let instrument_id = inst.id();
2398                let price_precision = inst.price_precision();
2399                let size_precision = inst.size_precision();
2400
2401                match parse_book_msg_vec(
2402                    data,
2403                    &instrument_id,
2404                    price_precision,
2405                    size_precision,
2406                    action,
2407                    ts_init,
2408                ) {
2409                    Ok(data) => return Some(NautilusWsMessage::Data(data)),
2410                    Err(e) => {
2411                        tracing::error!("Failed to parse book message: {e}");
2412                        continue;
2413                    }
2414                }
2415            }
2416
2417            if let OKXWebSocketEvent::OrderResponse {
2418                id,
2419                op,
2420                code,
2421                msg,
2422                data,
2423            } = event
2424            {
2425                if code == "0" {
2426                    tracing::debug!(
2427                        "Order operation successful: id={:?} op={op} code={code}",
2428                        id
2429                    );
2430
2431                    if let Some(data) = data.first() {
2432                        let success_msg = data
2433                            .get("sMsg")
2434                            .and_then(|s| s.as_str())
2435                            .unwrap_or("Order operation successful");
2436                        tracing::debug!("Order details: {success_msg}");
2437
2438                        // Note: We rely on the orders channel subscription to provide the proper
2439                        // OrderStatusReport with correct instrument ID and full order details.
2440                        // The placement response has limited information.
2441                    }
2442                } else {
2443                    // Extract actual error message from data array, same as in the handler
2444                    let error_msg = data
2445                        .first()
2446                        .and_then(|d| d.get("sMsg"))
2447                        .and_then(|s| s.as_str())
2448                        .unwrap_or(&msg);
2449
2450                    // Debug: Check what fields are available in error data
2451                    if let Some(data_obj) = data.first() {
2452                        tracing::debug!(
2453                            "Error data fields: {}",
2454                            serde_json::to_string_pretty(data_obj)
2455                                .unwrap_or_else(|_| "unable to serialize".to_string())
2456                        );
2457                    }
2458
2459                    tracing::error!(
2460                        "Order operation failed: id={:?} op={op} code={code} msg={msg}",
2461                        id
2462                    );
2463
2464                    // Fetch pending request mapping for rejection based on operation type
2465                    if let Some(id) = &id {
2466                        match op {
2467                            OKXWsOperation::Order => {
2468                                if let Some((
2469                                    _,
2470                                    (client_order_id, trader_id, strategy_id, instrument_id),
2471                                )) = self.pending_place_requests.remove(id)
2472                                {
2473                                    let ts_event = clock.get_time_ns();
2474                                    let rejected = OrderRejected::new(
2475                                        trader_id,
2476                                        strategy_id,
2477                                        instrument_id,
2478                                        client_order_id,
2479                                        self.account_id,
2480                                        Ustr::from(error_msg), // Rejection reason from OKX
2481                                        UUID4::new(),
2482                                        ts_event,
2483                                        ts_init,
2484                                        false, // Not from reconciliation
2485                                        false, // Not due to post-only (TODO: parse error_msg)
2486                                    );
2487
2488                                    return Some(NautilusWsMessage::OrderRejected(rejected));
2489                                }
2490                            }
2491                            OKXWsOperation::CancelOrder => {
2492                                if let Some((
2493                                    _,
2494                                    (
2495                                        client_order_id,
2496                                        trader_id,
2497                                        strategy_id,
2498                                        instrument_id,
2499                                        venue_order_id,
2500                                    ),
2501                                )) = self.pending_cancel_requests.remove(id)
2502                                {
2503                                    let ts_event = clock.get_time_ns();
2504                                    let rejected = OrderCancelRejected::new(
2505                                        trader_id,
2506                                        strategy_id,
2507                                        instrument_id,
2508                                        client_order_id,
2509                                        Ustr::from(error_msg), // Rejection reason from OKX
2510                                        UUID4::new(),
2511                                        ts_event,
2512                                        ts_init,
2513                                        false, // Not from reconciliation
2514                                        venue_order_id,
2515                                        Some(self.account_id),
2516                                    );
2517
2518                                    return Some(NautilusWsMessage::OrderCancelRejected(rejected));
2519                                }
2520                            }
2521                            OKXWsOperation::AmendOrder => {
2522                                if let Some((
2523                                    _,
2524                                    (
2525                                        client_order_id,
2526                                        trader_id,
2527                                        strategy_id,
2528                                        instrument_id,
2529                                        venue_order_id,
2530                                    ),
2531                                )) = self.pending_amend_requests.remove(id)
2532                                {
2533                                    let ts_event = clock.get_time_ns();
2534                                    let rejected = OrderModifyRejected::new(
2535                                        trader_id,
2536                                        strategy_id,
2537                                        instrument_id,
2538                                        client_order_id,
2539                                        Ustr::from(error_msg), // Rejection reason from OKX
2540                                        UUID4::new(),
2541                                        ts_event,
2542                                        ts_init,
2543                                        false, // Not from reconciliation
2544                                        venue_order_id,
2545                                        Some(self.account_id),
2546                                    );
2547
2548                                    return Some(NautilusWsMessage::OrderModifyRejected(rejected));
2549                                }
2550                            }
2551                            _ => {
2552                                tracing::warn!("Unhandled operation type for rejection: {op}");
2553                            }
2554                        }
2555                    }
2556
2557                    // Fallback to error if no mapping found
2558                    let error = OKXWebSocketError {
2559                        code: code.clone(),
2560                        message: error_msg.to_string(),
2561                        conn_id: None, // Order responses don't have connection IDs
2562                        timestamp: clock.get_time_ns().as_u64(),
2563                    };
2564                    return Some(NautilusWsMessage::Error(error));
2565                }
2566                continue;
2567            }
2568
2569            if let OKXWebSocketEvent::Data { ref arg, ref data } = event {
2570                if arg.channel == OKXWsChannel::Account {
2571                    match serde_json::from_value::<Vec<OKXAccount>>(data.clone()) {
2572                        Ok(accounts) => {
2573                            if let Some(account) = accounts.first() {
2574                                // Account ID is provided from client configuration
2575                                match parse_account_state(account, self.account_id, ts_init) {
2576                                    Ok(account_state) => {
2577                                        // TODO: Optimize this account state comparison
2578                                        if let Some(last_account_state) = &self.last_account_state
2579                                            && account_state
2580                                                .has_same_balances_and_margins(last_account_state)
2581                                        {
2582                                            continue; // Nothing to update
2583                                        }
2584                                        self.last_account_state = Some(account_state.clone());
2585                                        return Some(NautilusWsMessage::AccountUpdate(
2586                                            account_state,
2587                                        ));
2588                                    }
2589                                    Err(e) => {
2590                                        tracing::error!("Failed to parse account state: {e}");
2591                                    }
2592                                }
2593                            }
2594                        }
2595                        Err(e) => {
2596                            tracing::error!(
2597                                "Failed to parse account data: {e}, raw data: {}",
2598                                data
2599                            );
2600                        }
2601                    }
2602                    continue;
2603                }
2604
2605                if arg.channel == OKXWsChannel::Orders {
2606                    tracing::debug!("Received orders channel message: {data}");
2607
2608                    let data: Vec<OKXOrderMsg> = serde_json::from_value(data.clone()).unwrap();
2609
2610                    let mut exec_reports = Vec::with_capacity(data.len());
2611
2612                    for msg in data {
2613                        match parse_order_msg_vec(
2614                            vec![msg],
2615                            self.account_id,
2616                            &self.instruments_cache,
2617                            &self.fee_cache,
2618                            ts_init,
2619                        ) {
2620                            Ok(mut reports) => {
2621                                // Update fee cache based on the new reports
2622                                for report in &reports {
2623                                    match report {
2624                                        ExecutionReport::Fill(fill_report) => {
2625                                            let order_id = fill_report.venue_order_id.inner();
2626                                            let current_fee = self
2627                                                .fee_cache
2628                                                .get(&order_id)
2629                                                .copied()
2630                                                .unwrap_or_else(|| {
2631                                                    Money::new(0.0, fill_report.commission.currency)
2632                                                });
2633                                            let total_fee = current_fee + fill_report.commission;
2634                                            self.fee_cache.insert(order_id, total_fee);
2635                                        }
2636                                        ExecutionReport::Order(status_report) => {
2637                                            if matches!(
2638                                                status_report.order_status,
2639                                                OrderStatus::Filled,
2640                                            ) {
2641                                                self.fee_cache
2642                                                    .remove(&status_report.venue_order_id.inner());
2643                                            }
2644                                        }
2645                                    }
2646                                }
2647                                exec_reports.append(&mut reports);
2648                            }
2649                            Err(e) => {
2650                                tracing::error!("Failed to parse order message: {e}");
2651                                continue;
2652                            }
2653                        }
2654                    }
2655
2656                    if !exec_reports.is_empty() {
2657                        return Some(NautilusWsMessage::ExecutionReports(exec_reports));
2658                    }
2659                }
2660
2661                let inst = match arg.inst_id.and_then(|id| self.instruments_cache.get(&id)) {
2662                    Some(inst) => inst,
2663                    None => {
2664                        tracing::error!(
2665                            "No instrument for channel {:?}, inst_id {:?}",
2666                            arg.channel,
2667                            arg.inst_id
2668                        );
2669                        continue;
2670                    }
2671                };
2672                let instrument_id = inst.id();
2673                let price_precision = inst.price_precision();
2674                let size_precision = inst.size_precision();
2675
2676                match parse_ws_message_data(
2677                    &arg.channel,
2678                    data.clone(),
2679                    &instrument_id,
2680                    price_precision,
2681                    size_precision,
2682                    ts_init,
2683                    &mut self.funding_rate_cache,
2684                ) {
2685                    Ok(Some(msg)) => return Some(msg),
2686                    Ok(None) => {
2687                        // No message to return (e.g., empty instrument payload)
2688                        continue;
2689                    }
2690                    Err(e) => {
2691                        tracing::error!("Error parsing message for channel {:?}: {e}", arg.channel)
2692                    }
2693                }
2694            }
2695
2696            // Handle login events (authentication failures)
2697            if let OKXWebSocketEvent::Login {
2698                code, msg, conn_id, ..
2699            } = &event
2700                && code != "0"
2701            {
2702                let error = OKXWebSocketError {
2703                    code: code.clone(),
2704                    message: msg.clone(),
2705                    conn_id: Some(conn_id.clone()),
2706                    timestamp: clock.get_time_ns().as_u64(),
2707                };
2708                return Some(NautilusWsMessage::Error(error));
2709            }
2710
2711            // Handle general error events
2712            if let OKXWebSocketEvent::Error { code, msg } = &event {
2713                let error = OKXWebSocketError {
2714                    code: code.clone(),
2715                    message: msg.clone(),
2716                    conn_id: None,
2717                    timestamp: clock.get_time_ns().as_u64(),
2718                };
2719                return Some(NautilusWsMessage::Error(error));
2720            }
2721
2722            // Handle reconnection signal
2723            if matches!(&event, OKXWebSocketEvent::Reconnected) {
2724                return Some(NautilusWsMessage::Reconnected);
2725            }
2726        }
2727        None // Connection closed
2728    }
2729}
2730
2731////////////////////////////////////////////////////////////////////////////////
2732// Tests
2733////////////////////////////////////////////////////////////////////////////////
2734
2735#[cfg(test)]
2736mod tests {
2737    use futures_util;
2738    use rstest::rstest;
2739
2740    use super::*;
2741
2742    #[rstest]
2743    fn test_timestamp_format_for_websocket_auth() {
2744        let timestamp = SystemTime::now()
2745            .duration_since(SystemTime::UNIX_EPOCH)
2746            .expect("System time should be after UNIX epoch")
2747            .as_secs()
2748            .to_string();
2749
2750        assert!(timestamp.parse::<u64>().is_ok());
2751        assert_eq!(timestamp.len(), 10);
2752        assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
2753    }
2754
2755    #[rstest]
2756    fn test_new_without_credentials() {
2757        let client = OKXWebSocketClient::default();
2758        assert!(client.credential.is_none());
2759        assert_eq!(client.api_key(), None);
2760    }
2761
2762    #[rstest]
2763    fn test_new_with_credentials() {
2764        let client = OKXWebSocketClient::new(
2765            None,
2766            Some("test_key".to_string()),
2767            Some("test_secret".to_string()),
2768            Some("test_passphrase".to_string()),
2769            None,
2770            None,
2771        )
2772        .unwrap();
2773        assert!(client.credential.is_some());
2774        assert_eq!(client.api_key(), Some("test_key"));
2775    }
2776
2777    #[rstest]
2778    fn test_new_partial_credentials_fails() {
2779        let result = OKXWebSocketClient::new(
2780            None,
2781            Some("test_key".to_string()),
2782            None,
2783            Some("test_passphrase".to_string()),
2784            None,
2785            None,
2786        );
2787        assert!(result.is_err());
2788    }
2789
2790    #[rstest]
2791    fn test_request_id_generation() {
2792        let client = OKXWebSocketClient::default();
2793
2794        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);
2795
2796        let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2797        let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
2798
2799        assert_eq!(id1, initial_counter);
2800        assert_eq!(id2, initial_counter + 1);
2801        assert_eq!(
2802            client.request_id_counter.load(Ordering::SeqCst),
2803            initial_counter + 2
2804        );
2805    }
2806
2807    #[rstest]
2808    fn test_client_state_management() {
2809        let client = OKXWebSocketClient::default();
2810
2811        assert!(client.is_closed());
2812        assert!(!client.is_active());
2813
2814        let client_with_heartbeat =
2815            OKXWebSocketClient::new(None, None, None, None, None, Some(30)).unwrap();
2816
2817        assert!(client_with_heartbeat.heartbeat.is_some());
2818        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2819    }
2820
2821    #[rstest]
2822    fn test_request_cache_operations() {
2823        let client = OKXWebSocketClient::default();
2824
2825        assert_eq!(client.pending_place_requests.len(), 0);
2826        assert_eq!(client.pending_cancel_requests.len(), 0);
2827        assert_eq!(client.pending_amend_requests.len(), 0);
2828
2829        let client_order_id = ClientOrderId::from("test-order-123");
2830        let trader_id = TraderId::from("test-trader-001");
2831        let strategy_id = StrategyId::from("test-strategy-001");
2832        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2833
2834        client.pending_place_requests.insert(
2835            "place-123".to_string(),
2836            (client_order_id, trader_id, strategy_id, instrument_id),
2837        );
2838
2839        assert_eq!(client.pending_place_requests.len(), 1);
2840        assert!(client.pending_place_requests.contains_key("place-123"));
2841
2842        let removed = client.pending_place_requests.remove("place-123");
2843        assert!(removed.is_some());
2844        assert_eq!(client.pending_place_requests.len(), 0);
2845    }
2846
2847    #[rstest]
2848    fn test_websocket_error_handling() {
2849        let clock = get_atomic_clock_realtime();
2850        let ts = clock.get_time_ns().as_u64();
2851
2852        let error = OKXWebSocketError {
2853            code: "60012".to_string(),
2854            message: "Invalid request".to_string(),
2855            conn_id: None,
2856            timestamp: ts,
2857        };
2858
2859        assert_eq!(error.code, "60012");
2860        assert_eq!(error.message, "Invalid request");
2861        assert_eq!(error.timestamp, ts);
2862
2863        let nautilus_msg = NautilusWsMessage::Error(error);
2864        match nautilus_msg {
2865            NautilusWsMessage::Error(err) => {
2866                assert_eq!(err.code, "60012");
2867                assert_eq!(err.message, "Invalid request");
2868            }
2869            _ => panic!("Expected Error variant"),
2870        }
2871    }
2872
2873    #[rstest]
2874    fn test_request_id_generation_sequence() {
2875        let client = OKXWebSocketClient::default();
2876
2877        let initial_counter = client
2878            .request_id_counter
2879            .load(std::sync::atomic::Ordering::SeqCst);
2880        let mut ids = Vec::new();
2881        for _ in 0..10 {
2882            let id = client
2883                .request_id_counter
2884                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2885            ids.push(id);
2886        }
2887
2888        for (i, &id) in ids.iter().enumerate() {
2889            assert_eq!(id, initial_counter + i as u64);
2890        }
2891
2892        assert_eq!(
2893            client
2894                .request_id_counter
2895                .load(std::sync::atomic::Ordering::SeqCst),
2896            initial_counter + 10
2897        );
2898    }
2899
2900    #[rstest]
2901    fn test_client_state_transitions() {
2902        let client = OKXWebSocketClient::default();
2903
2904        assert!(client.is_closed());
2905        assert!(!client.is_active());
2906
2907        let client_with_heartbeat = OKXWebSocketClient::new(
2908            None,
2909            None,
2910            None,
2911            None,
2912            None,
2913            Some(30), // 30 second heartbeat
2914        )
2915        .unwrap();
2916
2917        assert!(client_with_heartbeat.heartbeat.is_some());
2918        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
2919
2920        let account_id = AccountId::from("test-account-123");
2921        let client_with_account =
2922            OKXWebSocketClient::new(None, None, None, None, Some(account_id), None).unwrap();
2923
2924        assert_eq!(client_with_account.account_id, account_id);
2925    }
2926
2927    #[tokio::test]
2928    async fn test_concurrent_request_handling() {
2929        let client = Arc::new(OKXWebSocketClient::default());
2930
2931        let initial_counter = client
2932            .request_id_counter
2933            .load(std::sync::atomic::Ordering::SeqCst);
2934        let mut handles = Vec::new();
2935
2936        for i in 0..10 {
2937            let client_clone = Arc::clone(&client);
2938            let handle = tokio::spawn(async move {
2939                let client_order_id = ClientOrderId::from(format!("order-{i}").as_str());
2940                let trader_id = TraderId::from("trader-001");
2941                let strategy_id = StrategyId::from("strategy-001");
2942                let instrument_id = InstrumentId::from("BTC-USDT.OKX");
2943
2944                let request_id = client_clone
2945                    .request_id_counter
2946                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2947                let request_id_str = request_id.to_string();
2948
2949                client_clone.pending_place_requests.insert(
2950                    request_id_str.clone(),
2951                    (client_order_id, trader_id, strategy_id, instrument_id),
2952                );
2953
2954                // Simulate processing delay
2955                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
2956
2957                // Remove from cache (simulating response processing)
2958                let removed = client_clone.pending_place_requests.remove(&request_id_str);
2959                assert!(removed.is_some());
2960
2961                request_id
2962            });
2963            handles.push(handle);
2964        }
2965
2966        // Wait for all operations to complete
2967        let results: Vec<_> = futures_util::future::join_all(handles).await;
2968
2969        assert_eq!(results.len(), 10);
2970        for result in results {
2971            assert!(result.is_ok());
2972        }
2973
2974        assert_eq!(client.pending_place_requests.len(), 0);
2975
2976        let final_counter = client
2977            .request_id_counter
2978            .load(std::sync::atomic::Ordering::SeqCst);
2979        assert_eq!(final_counter, initial_counter + 10);
2980    }
2981
2982    #[rstest]
2983    fn test_websocket_error_scenarios() {
2984        let clock = get_atomic_clock_realtime();
2985        let ts = clock.get_time_ns().as_u64();
2986
2987        let error_scenarios = vec![
2988            ("60012", "Invalid request", None),
2989            ("60009", "Invalid API key", Some("conn-123".to_string())),
2990            ("60014", "Too many requests", None),
2991            ("50001", "Order not found", None),
2992        ];
2993
2994        for (code, message, conn_id) in error_scenarios {
2995            let error = OKXWebSocketError {
2996                code: code.to_string(),
2997                message: message.to_string(),
2998                conn_id: conn_id.clone(),
2999                timestamp: ts,
3000            };
3001
3002            assert_eq!(error.code, code);
3003            assert_eq!(error.message, message);
3004            assert_eq!(error.conn_id, conn_id);
3005            assert_eq!(error.timestamp, ts);
3006
3007            let nautilus_msg = NautilusWsMessage::Error(error);
3008            match nautilus_msg {
3009                NautilusWsMessage::Error(err) => {
3010                    assert_eq!(err.code, code);
3011                    assert_eq!(err.message, message);
3012                    assert_eq!(err.conn_id, conn_id);
3013                }
3014                _ => panic!("Expected Error variant"),
3015            }
3016        }
3017    }
3018
3019    #[tokio::test]
3020    async fn test_feed_handler_reconnection_detection() {
3021        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
3022        let signal = Arc::new(AtomicBool::new(false));
3023        let mut handler = OKXFeedHandler::new(rx, signal.clone());
3024
3025        tx.send(Message::Text(RECONNECTED.to_string().into()))
3026            .unwrap();
3027
3028        let result = handler.next().await;
3029        assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
3030    }
3031
3032    #[tokio::test]
3033    async fn test_feed_handler_normal_message_processing() {
3034        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
3035        let signal = Arc::new(AtomicBool::new(false));
3036        let mut handler = OKXFeedHandler::new(rx, signal.clone());
3037
3038        // Send a ping message (OKX sends pings)
3039        let ping_msg = "ping";
3040        tx.send(Message::Text(ping_msg.to_string().into())).unwrap();
3041
3042        // Send a valid subscription response
3043        let sub_msg = r#"{
3044            "event": "subscribe",
3045            "arg": {
3046                "channel": "tickers",
3047                "instType": "SPOT"
3048            },
3049            "connId": "a4d3ae55"
3050        }"#;
3051
3052        tx.send(Message::Text(sub_msg.to_string().into())).unwrap();
3053
3054        // Set signal to stop the handler
3055        signal.store(true, std::sync::atomic::Ordering::Relaxed);
3056
3057        // Handler should process messages and then stop on signal
3058        let result = handler.next().await;
3059        assert!(result.is_none());
3060    }
3061
3062    #[tokio::test]
3063    async fn test_feed_handler_stop_signal() {
3064        let (_tx, rx) = tokio::sync::mpsc::unbounded_channel();
3065        let signal = Arc::new(AtomicBool::new(true)); // Signal already set
3066        let mut handler = OKXFeedHandler::new(rx, signal.clone());
3067
3068        let result = handler.next().await;
3069        assert!(result.is_none());
3070    }
3071
3072    #[tokio::test]
3073    async fn test_feed_handler_close_message() {
3074        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
3075        let signal = Arc::new(AtomicBool::new(false));
3076        let mut handler = OKXFeedHandler::new(rx, signal.clone());
3077
3078        // Send close message
3079        tx.send(Message::Close(None)).unwrap();
3080
3081        let result = handler.next().await;
3082        assert!(result.is_none());
3083    }
3084
3085    #[tokio::test]
3086    async fn test_reconnection_message_constant() {
3087        assert_eq!(RECONNECTED, "__RECONNECTED__");
3088    }
3089
3090    #[tokio::test]
3091    async fn test_multiple_reconnection_signals() {
3092        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
3093        let signal = Arc::new(AtomicBool::new(false));
3094        let mut handler = OKXFeedHandler::new(rx, signal.clone());
3095
3096        // Send multiple reconnection messages
3097        for _ in 0..3 {
3098            tx.send(Message::Text(RECONNECTED.to_string().into()))
3099                .unwrap();
3100
3101            let result = handler.next().await;
3102            assert!(matches!(result, Some(OKXWebSocketEvent::Reconnected)));
3103        }
3104    }
3105
3106    #[tokio::test]
3107    async fn test_wait_until_active_timeout() {
3108        let client = OKXWebSocketClient::new(
3109            None,
3110            Some("test_key".to_string()),
3111            Some("test_secret".to_string()),
3112            Some("test_passphrase".to_string()),
3113            Some(AccountId::from("test-account")),
3114            None,
3115        )
3116        .unwrap();
3117
3118        // Should timeout since client is not connected
3119        let result = client.wait_until_active(0.1).await;
3120
3121        assert!(result.is_err());
3122        assert!(!client.is_active());
3123    }
3124}