nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! High-performance WebSocket client with connection state management and automatic
//! reconnection using exponential backoff.
//!
//! **Key features**:
//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
//! - Synchronized reconnection with backoff.
//! - Split read/write architecture.
//! - Python callback integration.
//!
//! **Design**:
//! - Single reader, multiple writer model.
//! - Read half runs in a dedicated task.
//! - Write half runs in a dedicated task, fed through a channel.
//! - Controller task manages the lifecycle.
30
31use std::{
32    fmt::Debug,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_core::CleanDrop;
46use nautilus_cryptography::providers::install_cryptographic_provider;
47#[cfg(feature = "turmoil")]
48use tokio_tungstenite::client_async;
49#[cfg(not(feature = "turmoil"))]
50use tokio_tungstenite::connect_async;
51use tokio_tungstenite::{
52    MaybeTlsStream, WebSocketStream,
53    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
54};
55
56#[cfg(feature = "turmoil")]
57use crate::net::TcpConnector;
58use crate::{
59    RECONNECTED,
60    backoff::ExponentialBackoff,
61    error::SendError,
62    logging::{log_task_aborted, log_task_started, log_task_stopped},
63    mode::ConnectionMode,
64    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
65};
66
/// Text payload `"ping"`.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used for
/// text-based (application-level) ping messages - confirm against callers.
pub const TEXT_PING: &str = "ping";
/// Text payload `"pong"`.
///
/// NOTE(review): not referenced in the visible part of this file; presumably the reply to
/// [`TEXT_PING`] - confirm against callers.
pub const TEXT_PONG: &str = "pong";

// Connection timing constants

/// Interval (milliseconds) after which the read and write tasks stop waiting for the next
/// message/command and re-check the connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Delay (milliseconds) that `reconnect` waits after handing the new writer to the write
/// task, before it replaces the read task.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on waiting for a writer to close gracefully.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// NOTE(review): not referenced in the visible part of this file; presumably a poll interval
/// (milliseconds) used by the send operations - confirm against its usages.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
75
/// Write half (sink) of a split WebSocket stream over a tokio TCP stream.
#[cfg(not(feature = "turmoil"))]
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, Message>;
/// Read half (stream) of a split WebSocket stream over a tokio TCP stream.
#[cfg(not(feature = "turmoil"))]
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>;

/// Write half (sink) of a split WebSocket stream over the crate's own `TcpStream` type,
/// selected when the `turmoil` feature is enabled.
#[cfg(feature = "turmoil")]
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<crate::net::TcpStream>>, Message>;
/// Read half (stream) of a split WebSocket stream over the crate's own `TcpStream` type,
/// selected when the `turmoil` feature is enabled.
#[cfg(feature = "turmoil")]
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<crate::net::TcpStream>>>;
85
/// Function type for handling WebSocket messages.
///
/// When provided, the client will spawn an internal task to read messages and pass them
/// to this handler. This enables automatic reconnection where the client can replace the
/// reader internally.
///
/// When `None`, the client returns a `MessageReader` stream (via `connect_stream`) that
/// the caller owns and reads from directly. This disables automatic reconnection because
/// the reader cannot be replaced - the caller must manually reconnect.
///
/// The handler is shared behind an `Arc` and must be `Send + Sync` because it is invoked
/// from the spawned read task.
pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;

/// Function type for handling WebSocket ping messages.
///
/// The internal read task calls it with a copy of the payload bytes of each received
/// ping frame.
pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
99
100/// Creates a channel-based message handler.
101///
102/// Returns a tuple containing the message handler and a receiver for messages.
103#[must_use]
104pub fn channel_message_handler() -> (
105    MessageHandler,
106    tokio::sync::mpsc::UnboundedReceiver<Message>,
107) {
108    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
109    let handler = Arc::new(move |msg: Message| {
110        if let Err(e) = tx.send(msg) {
111            tracing::debug!("Failed to send message to channel: {e}");
112        }
113    });
114    (handler, rx)
115}
116
/// Configuration for WebSocket client connections.
///
/// # Connection Modes
///
/// The `message_handler` field determines the connection mode:
///
/// ## Handler Mode (`message_handler: Some(...)`)
/// - Use with [`WebSocketClient::connect`].
/// - Client spawns internal task to read messages and call handler.
/// - **Supports automatic reconnection** with exponential backoff.
/// - Reconnection config fields (`reconnect_*`) are active.
/// - Best for long-lived connections, Python bindings, callback-based APIs.
///
/// ## Stream Mode (`message_handler: None`)
/// - Use with [`WebSocketClient::connect_stream`].
/// - Returns a [`MessageReader`] stream for the caller to read from.
/// - **Does NOT support automatic reconnection** (reader owned by caller).
/// - Reconnection config fields are ignored.
/// - On disconnect, client transitions to CLOSED state and caller must manually reconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// The default headers, as `(name, value)` pairs added to the connection request.
    pub headers: Vec<(String, String)>,
    /// The function to handle incoming messages.
    ///
    /// - `Some(handler)`: Handler mode with automatic reconnection (use with `connect`).
    /// - `None`: Stream mode without automatic reconnection (use with `connect_stream`).
    ///
    /// See [`WebSocketConfig`] docs for detailed explanation of modes.
    pub message_handler: Option<MessageHandler>,
    /// The optional heartbeat interval (seconds).
    ///
    /// When `None`, no heartbeat task is spawned.
    pub heartbeat: Option<u64>,
    /// The optional heartbeat message.
    ///
    /// When set, it is sent as a text message on every heartbeat; when `None`, a ping frame
    /// with an empty payload is sent instead.
    pub heartbeat_msg: Option<String>,
    /// The handler for incoming pings.
    ///
    /// Called by the internal read task with the payload of each received ping frame.
    pub ping_handler: Option<PingHandler>,
    /// The timeout (milliseconds) for reconnection attempts (defaults to 10,000 when `None`).
    ///
    /// **Note**: Only applies to handler mode. Ignored in stream mode.
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (defaults to 2,000 when
    /// `None`).
    ///
    /// **Note**: Only applies to handler mode. Ignored in stream mode.
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (defaults to 30,000
    /// when `None`).
    ///
    /// **Note**: Only applies to handler mode. Ignored in stream mode.
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (defaults to 1.5 when `None`).
    ///
    /// **Note**: Only applies to handler mode. Ignored in stream mode.
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (defaults to 100 when
    /// `None`).
    ///
    /// **Note**: Only applies to handler mode. Ignored in stream mode.
    pub reconnect_jitter_ms: Option<u64>,
}
179
180impl Debug for WebSocketConfig {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct(stringify!(WebSocketConfig))
183            .field("url", &self.url)
184            .field("headers", &self.headers)
185            .field(
186                "message_handler",
187                &self.message_handler.as_ref().map(|_| "<function>"),
188            )
189            .field("heartbeat", &self.heartbeat)
190            .field("heartbeat_msg", &self.heartbeat_msg)
191            .field(
192                "ping_handler",
193                &self.ping_handler.as_ref().map(|_| "<function>"),
194            )
195            .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
196            .field(
197                "reconnect_delay_initial_ms",
198                &self.reconnect_delay_initial_ms,
199            )
200            .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
201            .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
202            .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
203            .finish()
204    }
205}
206
207impl Clone for WebSocketConfig {
208    fn clone(&self) -> Self {
209        Self {
210            url: self.url.clone(),
211            headers: self.headers.clone(),
212            message_handler: self.message_handler.clone(),
213            heartbeat: self.heartbeat,
214            heartbeat_msg: self.heartbeat_msg.clone(),
215            ping_handler: self.ping_handler.clone(),
216            reconnect_timeout_ms: self.reconnect_timeout_ms,
217            reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
218            reconnect_delay_max_ms: self.reconnect_delay_max_ms,
219            reconnect_backoff_factor: self.reconnect_backoff_factor,
220            reconnect_jitter_ms: self.reconnect_jitter_ms,
221        }
222    }
223}
224
/// Represents a command for the writer task.
///
/// Commands are delivered over an unbounded channel to the task that owns the write half
/// of the connection.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    ///
    /// The writer task closes its current writer before switching to the new one.
    Update(MessageWriter),
    /// Send message to the server.
    ///
    /// Dropped with a warning if the connection is in reconnect mode when it is processed.
    Send(Message),
}
233
/// `WebSocketClient` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. In handler
/// mode it moves the read half into a tokio task which keeps receiving
/// messages from the server and passes each one to the configured message
/// handler. The write half is moved into a separate tokio task; anything that
/// needs to write sends a `WriterCommand` through `writer_tx`, so data can be
/// written to the server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// Configuration the client was created with; its URL and headers are reused on reconnect.
    config: WebSocketConfig,
    /// Handle of the read task; `None` when no message handler is configured or the reader
    /// is handled externally.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender for commands (send message / replace writer) to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode` stored as a `u8`, shared with the spawned tasks.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration allowed for a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Exponential backoff state built from the `reconnect_*` config fields.
    // NOTE(review): only constructed in the visible part of this file; presumably consumed by
    // the controller between reconnection attempts - confirm
    backoff: ExponentialBackoff,
    /// True if this is a stream-based client (created via `connect_stream`).
    /// Stream-based clients disable auto-reconnect because the reader is
    /// owned by the caller and cannot be replaced during reconnection.
    is_stream_mode: bool,
}
263
264impl WebSocketClientInner {
265    /// Create an inner websocket client with an existing writer.
266    pub async fn new_with_writer(
267        config: WebSocketConfig,
268        writer: MessageWriter,
269    ) -> Result<Self, Error> {
270        install_cryptographic_provider();
271
272        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
273
274        // Note: We don't spawn a read task here since the reader is handled externally
275        let read_task = None;
276
277        let backoff = ExponentialBackoff::new(
278            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
279            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
280            config.reconnect_backoff_factor.unwrap_or(1.5),
281            config.reconnect_jitter_ms.unwrap_or(100),
282            true, // immediate-first
283        )
284        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
285
286        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
287        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
288
289        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
290            Some(Self::spawn_heartbeat_task(
291                connection_mode.clone(),
292                heartbeat_interval,
293                config.heartbeat_msg.clone(),
294                writer_tx.clone(),
295            ))
296        } else {
297            None
298        };
299
300        Ok(Self {
301            config: config.clone(),
302            writer_tx,
303            connection_mode,
304            reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
305            heartbeat_task,
306            read_task,
307            write_task,
308            backoff,
309            is_stream_mode: true,
310        })
311    }
312
313    /// Create an inner websocket client.
314    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
315        install_cryptographic_provider();
316
317        let WebSocketConfig {
318            url,
319            message_handler,
320            heartbeat,
321            headers,
322            heartbeat_msg,
323            ping_handler,
324            reconnect_timeout_ms,
325            reconnect_delay_initial_ms,
326            reconnect_delay_max_ms,
327            reconnect_backoff_factor,
328            reconnect_jitter_ms,
329        } = &config;
330        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
331
332        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
333
334        let read_task = if message_handler.is_some() {
335            Some(Self::spawn_message_handler_task(
336                connection_mode.clone(),
337                reader,
338                message_handler.as_ref(),
339                ping_handler.as_ref(),
340            ))
341        } else {
342            None
343        };
344
345        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
346        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
347
348        // Optionally spawn a heartbeat task to periodically ping server
349        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
350            Self::spawn_heartbeat_task(
351                connection_mode.clone(),
352                *heartbeat_secs,
353                heartbeat_msg.clone(),
354                writer_tx.clone(),
355            )
356        });
357
358        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
359        let backoff = ExponentialBackoff::new(
360            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
361            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
362            reconnect_backoff_factor.unwrap_or(1.5),
363            reconnect_jitter_ms.unwrap_or(100),
364            true, // immediate-first
365        )
366        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
367
368        Ok(Self {
369            config,
370            read_task,
371            write_task,
372            writer_tx,
373            heartbeat_task,
374            connection_mode,
375            reconnect_timeout,
376            backoff,
377            is_stream_mode: false,
378        })
379    }
380
381    /// Connects with the server creating a tokio-tungstenite websocket stream.
382    /// Production version that uses `connect_async` convenience helper.
383    #[inline]
384    #[cfg(not(feature = "turmoil"))]
385    pub async fn connect_with_server(
386        url: &str,
387        headers: Vec<(String, String)>,
388    ) -> Result<(MessageWriter, MessageReader), Error> {
389        let mut request = url.into_client_request()?;
390        let req_headers = request.headers_mut();
391
392        let mut header_names: Vec<HeaderName> = Vec::new();
393        for (key, val) in headers {
394            let header_value = HeaderValue::from_str(&val)?;
395            let header_name: HeaderName = key.parse()?;
396            header_names.push(header_name.clone());
397            req_headers.insert(header_name, header_value);
398        }
399
400        connect_async(request).await.map(|resp| resp.0.split())
401    }
402
403    /// Connects with the server creating a tokio-tungstenite websocket stream.
404    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
405    #[inline]
406    #[cfg(feature = "turmoil")]
407    pub async fn connect_with_server(
408        url: &str,
409        headers: Vec<(String, String)>,
410    ) -> Result<(MessageWriter, MessageReader), Error> {
411        use rustls::ClientConfig;
412        use tokio_rustls::TlsConnector;
413
414        let mut request = url.into_client_request()?;
415        let req_headers = request.headers_mut();
416
417        let mut header_names: Vec<HeaderName> = Vec::new();
418        for (key, val) in headers {
419            let header_value = HeaderValue::from_str(&val)?;
420            let header_name: HeaderName = key.parse()?;
421            header_names.push(header_name.clone());
422            req_headers.insert(header_name, header_value);
423        }
424
425        let uri = request.uri();
426        let scheme = uri.scheme_str().unwrap_or("ws");
427        let host = uri.host().ok_or_else(|| {
428            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
429        })?;
430
431        // Determine port: use explicit port if specified, otherwise default based on scheme
432        let port = uri
433            .port_u16()
434            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
435
436        let addr = format!("{host}:{port}");
437
438        // Use the connector to get a turmoil-compatible stream
439        let connector = crate::net::RealTcpConnector;
440        let tcp_stream = connector.connect(&addr).await?;
441
442        // Wrap stream appropriately based on scheme
443        let maybe_tls_stream = if scheme == "wss" {
444            // Build TLS config with webpki roots
445            let mut root_store = rustls::RootCertStore::empty();
446            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
447
448            let config = ClientConfig::builder()
449                .with_root_certificates(root_store)
450                .with_no_client_auth();
451
452            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
453            let domain =
454                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
455                    Error::Io(std::io::Error::new(
456                        std::io::ErrorKind::InvalidInput,
457                        format!("Invalid DNS name: {e}"),
458                    ))
459                })?;
460
461            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
462            MaybeTlsStream::Rustls(tls_stream)
463        } else {
464            MaybeTlsStream::Plain(tcp_stream)
465        };
466
467        // Use client_async with the stream (plain or TLS)
468        client_async(request, maybe_tls_stream)
469            .await
470            .map(|resp| resp.0.split())
471    }
472
473    /// Reconnect with server.
474    ///
475    /// Make a new connection with server. Use the new read and write halves
476    /// to update self writer and read and heartbeat tasks.
477    ///
478    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
479    /// because the reader is owned by the caller and cannot be replaced. Stream users
480    /// should handle disconnections by creating a new connection.
481    pub async fn reconnect(&mut self) -> Result<(), Error> {
482        tracing::debug!("Reconnecting");
483
484        if self.is_stream_mode {
485            tracing::warn!(
486                "Auto-reconnect disabled for stream-based WebSocket client; \
487                stream users must manually reconnect by creating a new connection"
488            );
489            // Transition to CLOSED state to stop reconnection attempts
490            self.connection_mode
491                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
492            return Ok(());
493        }
494
495        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
496            tracing::debug!("Reconnect aborted due to disconnect state");
497            return Ok(());
498        }
499
500        tokio::time::timeout(self.reconnect_timeout, async {
501            // Attempt to connect; abort early if a disconnect was requested
502            let (new_writer, reader) =
503                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
504
505            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
506                tracing::debug!("Reconnect aborted mid-flight (after connect)");
507                return Ok(());
508            }
509
510            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
511                tracing::error!("{e}");
512            }
513
514            // Delay before closing connection
515            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
516
517            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
518                tracing::debug!("Reconnect aborted mid-flight (after delay)");
519                return Ok(());
520            }
521
522            if let Some(ref read_task) = self.read_task.take()
523                && !read_task.is_finished()
524            {
525                read_task.abort();
526                log_task_aborted("read");
527            }
528
529            // Atomically transition from Reconnect to Active
530            // This prevents race condition where disconnect could be requested between check and store
531            if self
532                .connection_mode
533                .compare_exchange(
534                    ConnectionMode::Reconnect.as_u8(),
535                    ConnectionMode::Active.as_u8(),
536                    Ordering::SeqCst,
537                    Ordering::SeqCst,
538                )
539                .is_err()
540            {
541                tracing::debug!("Reconnect aborted (state changed during reconnect)");
542                return Ok(());
543            }
544
545            self.read_task = if self.config.message_handler.is_some() {
546                Some(Self::spawn_message_handler_task(
547                    self.connection_mode.clone(),
548                    reader,
549                    self.config.message_handler.as_ref(),
550                    self.config.ping_handler.as_ref(),
551                ))
552            } else {
553                None
554            };
555
556            tracing::debug!("Reconnect succeeded");
557            Ok(())
558        })
559        .await
560        .map_err(|_| {
561            Error::Io(std::io::Error::new(
562                std::io::ErrorKind::TimedOut,
563                format!(
564                    "reconnection timed out after {}s",
565                    self.reconnect_timeout.as_secs_f64()
566                ),
567            ))
568        })?
569    }
570
571    /// Check if the client is still connected.
572    ///
573    /// The client is connected if the read task has not finished. It is expected
574    /// that in case of any failure client or server side. The read task will be
575    /// shutdown or will receive a `Close` frame which will finish it. There
576    /// might be some delay between the connection being closed and the client
577    /// detecting.
578    #[inline]
579    #[must_use]
580    pub fn is_alive(&self) -> bool {
581        match &self.read_task {
582            Some(read_task) => !read_task.is_finished(),
583            None => true, // Stream is being used directly
584        }
585    }
586
587    fn spawn_message_handler_task(
588        connection_state: Arc<AtomicU8>,
589        mut reader: MessageReader,
590        message_handler: Option<&MessageHandler>,
591        ping_handler: Option<&PingHandler>,
592    ) -> tokio::task::JoinHandle<()> {
593        tracing::debug!("Started message handler task 'read'");
594
595        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
596
597        // Clone Arc handlers for the async task
598        let message_handler = message_handler.cloned();
599        let ping_handler = ping_handler.cloned();
600
601        tokio::task::spawn(async move {
602            loop {
603                if !ConnectionMode::from_atomic(&connection_state).is_active() {
604                    break;
605                }
606
607                match tokio::time::timeout(check_interval, reader.next()).await {
608                    Ok(Some(Ok(Message::Binary(data)))) => {
609                        tracing::trace!("Received message <binary> {} bytes", data.len());
610                        if let Some(ref handler) = message_handler {
611                            handler(Message::Binary(data));
612                        }
613                    }
614                    Ok(Some(Ok(Message::Text(data)))) => {
615                        tracing::trace!("Received message: {data}");
616                        if let Some(ref handler) = message_handler {
617                            handler(Message::Text(data));
618                        }
619                    }
620                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
621                        tracing::trace!("Received ping: {ping_data:?}");
622                        if let Some(ref handler) = ping_handler {
623                            handler(ping_data.to_vec());
624                        }
625                    }
626                    Ok(Some(Ok(Message::Pong(_)))) => {
627                        tracing::trace!("Received pong");
628                    }
629                    Ok(Some(Ok(Message::Close(_)))) => {
630                        tracing::debug!("Received close message - terminating");
631                        break;
632                    }
633                    Ok(Some(Ok(_))) => (),
634                    Ok(Some(Err(e))) => {
635                        tracing::error!("Received error message - terminating: {e}");
636                        break;
637                    }
638                    Ok(None) => {
639                        tracing::debug!("No message received - terminating");
640                        break;
641                    }
642                    Err(_) => {
643                        // Timeout - continue loop and check connection mode
644                        continue;
645                    }
646                }
647            }
648        })
649    }
650
651    fn spawn_write_task(
652        connection_state: Arc<AtomicU8>,
653        writer: MessageWriter,
654        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
655    ) -> tokio::task::JoinHandle<()> {
656        log_task_started("write");
657
658        // Interval between checking the connection mode
659        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
660
661        tokio::task::spawn(async move {
662            let mut active_writer = writer;
663
664            loop {
665                match ConnectionMode::from_atomic(&connection_state) {
666                    ConnectionMode::Disconnect => {
667                        // Attempt to close the writer gracefully before exiting,
668                        // we ignore any error as the writer may already be closed.
669                        _ = tokio::time::timeout(
670                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
671                            active_writer.close(),
672                        )
673                        .await;
674                        break;
675                    }
676                    ConnectionMode::Closed => break,
677                    _ => {}
678                }
679
680                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
681                    Ok(Some(msg)) => {
682                        // Re-check connection mode after receiving a message
683                        let mode = ConnectionMode::from_atomic(&connection_state);
684                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
685                            break;
686                        }
687
688                        match msg {
689                            WriterCommand::Update(new_writer) => {
690                                tracing::debug!("Received new writer");
691
692                                // Delay before closing connection
693                                tokio::time::sleep(Duration::from_millis(100)).await;
694
695                                // Attempt to close the writer gracefully on update,
696                                // we ignore any error as the writer may already be closed.
697                                _ = tokio::time::timeout(
698                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
699                                    active_writer.close(),
700                                )
701                                .await;
702
703                                active_writer = new_writer;
704                                tracing::debug!("Updated writer");
705                            }
706                            _ if mode.is_reconnect() => {
707                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
708                                continue;
709                            }
710                            WriterCommand::Send(msg) => {
711                                if let Err(e) = active_writer.send(msg).await {
712                                    tracing::error!("Failed to send message: {e}");
713                                    // Mode is active so trigger reconnection
714                                    tracing::warn!("Writer triggering reconnect");
715                                    connection_state
716                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
717                                }
718                            }
719                        }
720                    }
721                    Ok(None) => {
722                        // Channel closed - writer task should terminate
723                        tracing::debug!("Writer channel closed, terminating writer task");
724                        break;
725                    }
726                    Err(_) => {
727                        // Timeout - just continue the loop
728                        continue;
729                    }
730                }
731            }
732
733            // Attempt to close the writer gracefully before exiting,
734            // we ignore any error as the writer may already be closed.
735            _ = tokio::time::timeout(
736                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
737                active_writer.close(),
738            )
739            .await;
740
741            log_task_stopped("write");
742        })
743    }
744
745    fn spawn_heartbeat_task(
746        connection_state: Arc<AtomicU8>,
747        heartbeat_secs: u64,
748        message: Option<String>,
749        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
750    ) -> tokio::task::JoinHandle<()> {
751        log_task_started("heartbeat");
752
753        tokio::task::spawn(async move {
754            let interval = Duration::from_secs(heartbeat_secs);
755
756            loop {
757                tokio::time::sleep(interval).await;
758
759                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
760                    ConnectionMode::Active => {
761                        let msg = match &message {
762                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
763                            None => WriterCommand::Send(Message::Ping(vec![].into())),
764                        };
765
766                        match writer_tx.send(msg) {
767                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
768                            Err(e) => {
769                                tracing::error!("Failed to send heartbeat to writer task: {e}");
770                            }
771                        }
772                    }
773                    ConnectionMode::Reconnect => continue,
774                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
775                }
776            }
777
778            log_task_stopped("heartbeat");
779        })
780    }
781}
782
783impl Drop for WebSocketClientInner {
784    fn drop(&mut self) {
785        // Delegate to explicit cleanup handler
786        self.clean_drop();
787    }
788}
789
790/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
791impl CleanDrop for WebSocketClientInner {
792    fn clean_drop(&mut self) {
793        if let Some(ref read_task) = self.read_task.take()
794            && !read_task.is_finished()
795        {
796            read_task.abort();
797            log_task_aborted("read");
798        }
799
800        if !self.write_task.is_finished() {
801            self.write_task.abort();
802            log_task_aborted("write");
803        }
804
805        if let Some(ref handle) = self.heartbeat_task.take()
806            && !handle.is_finished()
807        {
808            handle.abort();
809            log_task_aborted("heartbeat");
810        }
811
812        // Clear handlers to break potential reference cycles
813        self.config.message_handler = None;
814        self.config.ping_handler = None;
815    }
816}
817
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
///
/// This is the public handle: it owns the controller task and talks to the
/// writer task through a channel, while the connection itself is owned by the
/// controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task; the client counts as disconnected once it has finished.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller task, stored as the `u8` form of `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Keyed rate limiter awaited before text and binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Channel used to hand outgoing `WriterCommand`s to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
833
834impl Debug for WebSocketClient {
835    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836        f.debug_struct(stringify!(WebSocketClient)).finish()
837    }
838}
839
840impl WebSocketClient {
841    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
842    ///
843    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
844    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
845    /// client transitions to CLOSED state and the caller must manually reconnect by calling
846    /// `connect_stream` again.
847    ///
848    /// Use stream mode when you need custom reconnection logic, direct control over message
849    /// reading, or fine-grained backpressure handling.
850    ///
851    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
852    ///
853    /// # Errors
854    ///
855    /// Returns an error if the connection cannot be established.
856    #[allow(clippy::too_many_arguments)]
857    pub async fn connect_stream(
858        config: WebSocketConfig,
859        keyed_quotas: Vec<(String, Quota)>,
860        default_quota: Option<Quota>,
861        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
862    ) -> Result<(MessageReader, Self), Error> {
863        install_cryptographic_provider();
864
865        // Create a single connection and split it, respecting configured headers
866        let (writer, reader) =
867            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
868
869        // Create inner without connecting (we'll provide the writer)
870        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
871
872        let connection_mode = inner.connection_mode.clone();
873        let reconnect_timeout = inner.reconnect_timeout;
874        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
875        let writer_tx = inner.writer_tx.clone();
876
877        let controller_task =
878            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
879
880        Ok((
881            reader,
882            Self {
883                controller_task,
884                connection_mode,
885                reconnect_timeout,
886                rate_limiter,
887                writer_tx,
888            },
889        ))
890    }
891
892    /// Creates a websocket client in **handler mode** with automatic reconnection.
893    ///
894    /// Requires `config.message_handler` to be set. The handler is called for each incoming
895    /// message on an internal task. Automatic reconnection is **enabled** with exponential
896    /// backoff. On disconnection, the client automatically attempts to reconnect and replaces
897    /// the internal reader (the handler continues working seamlessly).
898    ///
899    /// Use handler mode for simplified connection management, automatic reconnection, Python
900    /// bindings, or callback-based message handling.
901    ///
902    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
903    ///
904    /// # Errors
905    ///
906    /// Returns an error if the connection cannot be established or if
907    /// `config.message_handler` is `None` (use `connect_stream` instead).
908    pub async fn connect(
909        config: WebSocketConfig,
910        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
911        keyed_quotas: Vec<(String, Quota)>,
912        default_quota: Option<Quota>,
913    ) -> Result<Self, Error> {
914        tracing::debug!("Connecting");
915        let inner = WebSocketClientInner::connect_url(config).await?;
916        let connection_mode = inner.connection_mode.clone();
917        let writer_tx = inner.writer_tx.clone();
918        let reconnect_timeout = inner.reconnect_timeout;
919
920        let controller_task =
921            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
922
923        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
924
925        Ok(Self {
926            controller_task,
927            connection_mode,
928            reconnect_timeout,
929            rate_limiter,
930            writer_tx,
931        })
932    }
933
934    /// Returns the current connection mode.
935    #[must_use]
936    pub fn connection_mode(&self) -> ConnectionMode {
937        ConnectionMode::from_atomic(&self.connection_mode)
938    }
939
940    /// Check if the client connection is active.
941    ///
942    /// Returns `true` if the client is connected and has not been signalled to disconnect.
943    /// The client will automatically retry connection based on its configuration.
944    #[inline]
945    #[must_use]
946    pub fn is_active(&self) -> bool {
947        self.connection_mode().is_active()
948    }
949
950    /// Check if the client is disconnected.
951    #[must_use]
952    pub fn is_disconnected(&self) -> bool {
953        self.controller_task.is_finished()
954    }
955
956    /// Check if the client is reconnecting.
957    ///
958    /// Returns `true` if the client lost connection and is attempting to reestablish it.
959    /// The client will automatically retry connection based on its configuration.
960    #[inline]
961    #[must_use]
962    pub fn is_reconnecting(&self) -> bool {
963        self.connection_mode().is_reconnect()
964    }
965
966    /// Check if the client is disconnecting.
967    ///
968    /// Returns `true` if the client is in disconnect mode.
969    #[inline]
970    #[must_use]
971    pub fn is_disconnecting(&self) -> bool {
972        self.connection_mode().is_disconnect()
973    }
974
975    /// Check if the client is closed.
976    ///
977    /// Returns `true` if the client has been explicitly disconnected or reached
978    /// maximum reconnection attempts. In this state, the client cannot be reused
979    /// and a new client must be created for further connections.
980    #[inline]
981    #[must_use]
982    pub fn is_closed(&self) -> bool {
983        self.connection_mode().is_closed()
984    }
985
986    /// Wait for the client to become active before sending.
987    ///
988    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
989    async fn wait_for_active(&self) -> Result<(), SendError> {
990        if self.is_closed() {
991            return Err(SendError::Closed);
992        }
993
994        let timeout = self.reconnect_timeout;
995        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
996
997        if !self.is_active() {
998            tracing::debug!("Waiting for client to become ACTIVE before sending...");
999
1000            let inner = tokio::time::timeout(timeout, async {
1001                loop {
1002                    if self.is_active() {
1003                        return Ok(());
1004                    }
1005                    if matches!(
1006                        self.connection_mode(),
1007                        ConnectionMode::Disconnect | ConnectionMode::Closed
1008                    ) {
1009                        return Err(());
1010                    }
1011                    tokio::time::sleep(check_interval).await;
1012                }
1013            })
1014            .await
1015            .map_err(|_| SendError::Timeout)?;
1016            inner.map_err(|()| SendError::Closed)?;
1017        }
1018
1019        Ok(())
1020    }
1021
1022    /// Set disconnect mode to true.
1023    ///
1024    /// Controller task will periodically check the disconnect mode
1025    /// and shutdown the client if it is alive
1026    pub async fn disconnect(&self) {
1027        tracing::debug!("Disconnecting");
1028        self.connection_mode
1029            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1030
1031        if let Ok(()) =
1032            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1033                while !self.is_disconnected() {
1034                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1035                        .await;
1036                }
1037
1038                if !self.controller_task.is_finished() {
1039                    self.controller_task.abort();
1040                    log_task_aborted("controller");
1041                }
1042            })
1043            .await
1044        {
1045            tracing::debug!("Controller task finished");
1046        } else {
1047            tracing::error!("Timeout waiting for controller task to finish");
1048            if !self.controller_task.is_finished() {
1049                self.controller_task.abort();
1050                log_task_aborted("controller");
1051            }
1052        }
1053    }
1054
1055    /// Sends the given text `data` to the server.
1056    ///
1057    /// # Errors
1058    ///
1059    /// Returns a websocket error if unable to send.
1060    #[allow(unused_variables)]
1061    pub async fn send_text(
1062        &self,
1063        data: String,
1064        keys: Option<Vec<String>>,
1065    ) -> Result<(), SendError> {
1066        self.rate_limiter.await_keys_ready(keys).await;
1067        self.wait_for_active().await?;
1068
1069        tracing::trace!("Sending text: {data:?}");
1070
1071        let msg = Message::Text(data.into());
1072        self.writer_tx
1073            .send(WriterCommand::Send(msg))
1074            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1075    }
1076
1077    /// Sends a pong frame back to the server.
1078    ///
1079    /// # Errors
1080    ///
1081    /// Returns a websocket error if unable to send.
1082    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1083        self.wait_for_active().await?;
1084
1085        tracing::trace!("Sending pong frame ({} bytes)", data.len());
1086
1087        let msg = Message::Pong(data.into());
1088        self.writer_tx
1089            .send(WriterCommand::Send(msg))
1090            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1091    }
1092
1093    /// Sends the given bytes `data` to the server.
1094    ///
1095    /// # Errors
1096    ///
1097    /// Returns a websocket error if unable to send.
1098    #[allow(unused_variables)]
1099    pub async fn send_bytes(
1100        &self,
1101        data: Vec<u8>,
1102        keys: Option<Vec<String>>,
1103    ) -> Result<(), SendError> {
1104        self.rate_limiter.await_keys_ready(keys).await;
1105        self.wait_for_active().await?;
1106
1107        tracing::trace!("Sending bytes: {data:?}");
1108
1109        let msg = Message::Binary(data.into());
1110        self.writer_tx
1111            .send(WriterCommand::Send(msg))
1112            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1113    }
1114
1115    /// Sends a close message to the server.
1116    ///
1117    /// # Errors
1118    ///
1119    /// Returns a websocket error if unable to send.
1120    pub async fn send_close_message(&self) -> Result<(), SendError> {
1121        self.wait_for_active().await?;
1122
1123        let msg = Message::Close(None);
1124        self.writer_tx
1125            .send(WriterCommand::Send(msg))
1126            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1127    }
1128
1129    fn spawn_controller_task(
1130        mut inner: WebSocketClientInner,
1131        connection_mode: Arc<AtomicU8>,
1132        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1133    ) -> tokio::task::JoinHandle<()> {
1134        tokio::task::spawn(async move {
1135            log_task_started("controller");
1136
1137            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1138
1139            loop {
1140                tokio::time::sleep(check_interval).await;
1141                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1142
1143                if mode.is_disconnect() {
1144                    tracing::debug!("Disconnecting");
1145
1146                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1147                    if tokio::time::timeout(timeout, async {
1148                        // Delay awaiting graceful shutdown
1149                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1150
1151                        if let Some(task) = &inner.read_task
1152                            && !task.is_finished()
1153                        {
1154                            task.abort();
1155                            log_task_aborted("read");
1156                        }
1157
1158                        if let Some(task) = &inner.heartbeat_task
1159                            && !task.is_finished()
1160                        {
1161                            task.abort();
1162                            log_task_aborted("heartbeat");
1163                        }
1164                    })
1165                    .await
1166                    .is_err()
1167                    {
1168                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1169                    }
1170
1171                    tracing::debug!("Closed");
1172                    break; // Controller finished
1173                }
1174
1175                if mode.is_active() && !inner.is_alive() {
1176                    if connection_mode
1177                        .compare_exchange(
1178                            ConnectionMode::Active.as_u8(),
1179                            ConnectionMode::Reconnect.as_u8(),
1180                            Ordering::SeqCst,
1181                            Ordering::SeqCst,
1182                        )
1183                        .is_ok()
1184                    {
1185                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1186                    }
1187                    mode = ConnectionMode::from_atomic(&connection_mode);
1188                }
1189
1190                if mode.is_reconnect() {
1191                    match inner.reconnect().await {
1192                        Ok(()) => {
1193                            inner.backoff.reset();
1194
1195                            // Only invoke callbacks if not in disconnect state
1196                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1197                                if let Some(ref handler) = inner.config.message_handler {
1198                                    let reconnected_msg =
1199                                        Message::Text(RECONNECTED.to_string().into());
1200                                    handler(reconnected_msg);
1201                                    tracing::debug!("Sent reconnected message to handler");
1202                                }
1203
1204                                // TODO: Retain this legacy callback for use from Python
1205                                if let Some(ref callback) = post_reconnection {
1206                                    callback();
1207                                    tracing::debug!("Called `post_reconnection` handler");
1208                                }
1209
1210                                tracing::debug!("Reconnected successfully");
1211                            } else {
1212                                tracing::debug!(
1213                                    "Skipping post_reconnection handlers due to disconnect state"
1214                                );
1215                            }
1216                        }
1217                        Err(e) => {
1218                            let duration = inner.backoff.next_duration();
1219                            tracing::warn!("Reconnect attempt failed: {e}");
1220                            if !duration.is_zero() {
1221                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1222                            }
1223                            tokio::time::sleep(duration).await;
1224                        }
1225                    }
1226                }
1227            }
1228            inner
1229                .connection_mode
1230                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1231
1232            log_task_stopped("controller");
1233        })
1234    }
1235}
1236
1237// Abort controller task on drop to clean up background tasks
1238impl Drop for WebSocketClient {
1239    fn drop(&mut self) {
1240        if !self.controller_task.is_finished() {
1241            self.controller_task.abort();
1242            log_task_aborted("controller");
1243        }
1244    }
1245}
1246
1247////////////////////////////////////////////////////////////////////////////////
1248// Tests
1249////////////////////////////////////////////////////////////////////////////////
1250
1251#[cfg(test)]
1252#[cfg(not(feature = "turmoil"))]
1253#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1254mod tests {
1255    use std::{num::NonZeroU32, sync::Arc};
1256
1257    use futures_util::{SinkExt, StreamExt};
1258    use tokio::{
1259        net::TcpListener,
1260        task::{self, JoinHandle},
1261    };
1262    use tokio_tungstenite::{
1263        accept_hdr_async,
1264        tungstenite::{
1265            handshake::server::{self, Callback},
1266            http::HeaderValue,
1267        },
1268    };
1269
1270    use crate::{
1271        ratelimiter::quota::Quota,
1272        websocket::{WebSocketClient, WebSocketConfig},
1273    };
1274
    /// Local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Handle of the accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
    }
1279
    /// Handshake callback asserting that the upgrade request carries an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header must be equal to.
        value: HeaderValue,
    }
1285
1286    impl Callback for TestCallback {
1287        fn on_request(
1288            self,
1289            request: &server::Request,
1290            response: server::Response,
1291        ) -> Result<server::Response, server::ErrorResponse> {
1292            let _ = response;
1293            let value = request.headers().get(&self.key);
1294            assert!(value.is_some());
1295
1296            if let Some(value) = request.headers().get(&self.key) {
1297                assert_eq!(value, self.value);
1298            }
1299
1300            Ok(response)
1301        }
1302    }
1303
1304    impl TestServer {
1305        async fn setup() -> Self {
1306            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1307            let port = TcpListener::local_addr(&server).unwrap().port();
1308
1309            let header_key = "test".to_string();
1310            let header_value = "test".to_string();
1311
1312            let test_call_back = TestCallback {
1313                key: header_key,
1314                value: HeaderValue::from_str(&header_value).unwrap(),
1315            };
1316
1317            let task = task::spawn(async move {
1318                // Keep accepting connections
1319                loop {
1320                    let (conn, _) = server.accept().await.unwrap();
1321                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1322                        .await
1323                        .unwrap();
1324
1325                    task::spawn(async move {
1326                        while let Some(Ok(msg)) = websocket.next().await {
1327                            match msg {
1328                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1329                                    if txt == "close-now" =>
1330                                {
1331                                    tracing::debug!("Forcibly closing from server side");
1332                                    // This sends a close frame, then stops reading
1333                                    let _ = websocket.close(None).await;
1334                                    break;
1335                                }
1336                                // Echo text/binary frames
1337                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1338                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1339                                    if websocket.send(msg).await.is_err() {
1340                                        break;
1341                                    }
1342                                }
1343                                // If the client closes, we also break
1344                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1345                                    _frame,
1346                                ) => {
1347                                    let _ = websocket.close(None).await;
1348                                    break;
1349                                }
1350                                // Ignore pings/pongs
1351                                _ => {}
1352                            }
1353                        }
1354                    });
1355                }
1356            });
1357
1358            Self { task, port }
1359        }
1360    }
1361
1362    impl Drop for TestServer {
1363        fn drop(&mut self) {
1364            self.task.abort();
1365        }
1366    }
1367
1368    async fn setup_test_client(port: u16) -> WebSocketClient {
1369        let config = WebSocketConfig {
1370            url: format!("ws://127.0.0.1:{port}"),
1371            headers: vec![("test".into(), "test".into())],
1372            message_handler: None,
1373            heartbeat: None,
1374            heartbeat_msg: None,
1375            ping_handler: None,
1376            reconnect_timeout_ms: None,
1377            reconnect_delay_initial_ms: None,
1378            reconnect_backoff_factor: None,
1379            reconnect_delay_max_ms: None,
1380            reconnect_jitter_ms: None,
1381        };
1382        WebSocketClient::connect(config, None, vec![], None)
1383            .await
1384            .expect("Failed to connect")
1385    }
1386
1387    #[tokio::test]
1388    async fn test_websocket_basic() {
1389        let server = TestServer::setup().await;
1390        let client = setup_test_client(server.port).await;
1391
1392        assert!(!client.is_disconnected());
1393
1394        client.disconnect().await;
1395        assert!(client.is_disconnected());
1396    }
1397
1398    #[tokio::test]
1399    async fn test_websocket_heartbeat() {
1400        let server = TestServer::setup().await;
1401        let client = setup_test_client(server.port).await;
1402
1403        // Wait ~3s => server should see multiple "ping"
1404        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1405
1406        // Cleanup
1407        client.disconnect().await;
1408        assert!(client.is_disconnected());
1409    }
1410
1411    #[tokio::test]
1412    async fn test_websocket_reconnect_exhausted() {
1413        let config = WebSocketConfig {
1414            url: "ws://127.0.0.1:9997".into(), // <-- No server
1415            headers: vec![],
1416            message_handler: None,
1417            heartbeat: None,
1418            heartbeat_msg: None,
1419            ping_handler: None,
1420            reconnect_timeout_ms: None,
1421            reconnect_delay_initial_ms: None,
1422            reconnect_backoff_factor: None,
1423            reconnect_delay_max_ms: None,
1424            reconnect_jitter_ms: None,
1425        };
1426        let res = WebSocketClient::connect(config, None, vec![], None).await;
1427        assert!(res.is_err(), "Should fail quickly with no server");
1428    }
1429
1430    #[tokio::test]
1431    async fn test_websocket_forced_close_reconnect() {
1432        let server = TestServer::setup().await;
1433        let client = setup_test_client(server.port).await;
1434
1435        // 1) Send normal message
1436        client.send_text("Hello".into(), None).await.unwrap();
1437
1438        // 2) Trigger forced close from server
1439        client.send_text("close-now".into(), None).await.unwrap();
1440
1441        // 3) Wait a bit => read loop sees close => reconnect
1442        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1443
1444        // Confirm not disconnected
1445        assert!(!client.is_disconnected());
1446
1447        // Cleanup
1448        client.disconnect().await;
1449        assert!(client.is_disconnected());
1450    }
1451
1452    #[tokio::test]
1453    async fn test_rate_limiter() {
1454        let server = TestServer::setup().await;
1455        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1456
1457        let config = WebSocketConfig {
1458            url: format!("ws://127.0.0.1:{}", server.port),
1459            headers: vec![("test".into(), "test".into())],
1460            message_handler: None,
1461            heartbeat: None,
1462            heartbeat_msg: None,
1463            ping_handler: None,
1464            reconnect_timeout_ms: None,
1465            reconnect_delay_initial_ms: None,
1466            reconnect_backoff_factor: None,
1467            reconnect_delay_max_ms: None,
1468            reconnect_jitter_ms: None,
1469        };
1470
1471        let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1472            .await
1473            .unwrap();
1474
1475        // First 2 should succeed
1476        client.send_text("test1".into(), None).await.unwrap();
1477        client.send_text("test2".into(), None).await.unwrap();
1478
1479        // Third should error
1480        client.send_text("test3".into(), None).await.unwrap();
1481
1482        // Cleanup
1483        client.disconnect().await;
1484        assert!(client.is_disconnected());
1485    }
1486
1487    #[tokio::test]
1488    async fn test_concurrent_writers() {
1489        let server = TestServer::setup().await;
1490        let client = Arc::new(setup_test_client(server.port).await);
1491
1492        let mut handles = vec![];
1493        for i in 0..10 {
1494            let client = client.clone();
1495            handles.push(task::spawn(async move {
1496                client.send_text(format!("test{i}"), None).await.unwrap();
1497            }));
1498        }
1499
1500        for handle in handles {
1501            handle.await.unwrap();
1502        }
1503
1504        // Cleanup
1505        client.disconnect().await;
1506        assert!(client.is_disconnected());
1507    }
1508}
1509
1510////////////////////////////////////////////////////////////////////////////////
1511// Tests
1512////////////////////////////////////////////////////////////////////////////////
1513
1514#[cfg(test)]
1515#[cfg(not(feature = "turmoil"))]
1516mod rust_tests {
1517    use futures_util::StreamExt;
1518    use rstest::rstest;
1519    use tokio::{
1520        net::TcpListener,
1521        task,
1522        time::{Duration, sleep},
1523    };
1524    use tokio_tungstenite::accept_async;
1525
1526    use super::*;
1527
1528    #[rstest]
1529    #[tokio::test]
1530    async fn test_reconnect_then_disconnect() {
1531        // Bind an ephemeral port
1532        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1533        let port = listener.local_addr().unwrap().port();
1534
1535        // Server task: accept one ws connection then close it
1536        let server = task::spawn(async move {
1537            let (stream, _) = listener.accept().await.unwrap();
1538            let ws = accept_async(stream).await.unwrap();
1539            drop(ws);
1540            // Keep alive briefly
1541            sleep(Duration::from_secs(1)).await;
1542        });
1543
1544        // Build a channel-based message handler for incoming messages (unused here)
1545        let (handler, _rx) = channel_message_handler();
1546
1547        // Configure client with short reconnect backoff
1548        let config = WebSocketConfig {
1549            url: format!("ws://127.0.0.1:{port}"),
1550            headers: vec![],
1551            message_handler: Some(handler),
1552            heartbeat: None,
1553            heartbeat_msg: None,
1554            ping_handler: None,
1555            reconnect_timeout_ms: Some(1_000),
1556            reconnect_delay_initial_ms: Some(50),
1557            reconnect_delay_max_ms: Some(100),
1558            reconnect_backoff_factor: Some(1.0),
1559            reconnect_jitter_ms: Some(0),
1560        };
1561
1562        // Connect the client
1563        let client = WebSocketClient::connect(config, None, vec![], None)
1564            .await
1565            .unwrap();
1566
1567        // Allow server to drop connection and client to detect
1568        sleep(Duration::from_millis(100)).await;
1569        // Now immediately disconnect the client
1570        client.disconnect().await;
1571        assert!(client.is_disconnected());
1572        server.abort();
1573    }
1574
1575    #[rstest]
1576    #[tokio::test]
1577    async fn test_reconnect_state_flips_when_reader_stops() {
1578        // Bind an ephemeral port and accept a single websocket connection which we drop.
1579        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1580        let port = listener.local_addr().unwrap().port();
1581
1582        let server = task::spawn(async move {
1583            if let Ok((stream, _)) = listener.accept().await
1584                && let Ok(ws) = accept_async(stream).await
1585            {
1586                drop(ws);
1587            }
1588            sleep(Duration::from_millis(50)).await;
1589        });
1590
1591        let (handler, _rx) = channel_message_handler();
1592
1593        let config = WebSocketConfig {
1594            url: format!("ws://127.0.0.1:{port}"),
1595            headers: vec![],
1596            message_handler: Some(handler),
1597            heartbeat: None,
1598            heartbeat_msg: None,
1599            ping_handler: None,
1600            reconnect_timeout_ms: Some(1_000),
1601            reconnect_delay_initial_ms: Some(50),
1602            reconnect_delay_max_ms: Some(100),
1603            reconnect_backoff_factor: Some(1.0),
1604            reconnect_jitter_ms: Some(0),
1605        };
1606
1607        let client = WebSocketClient::connect(config, None, vec![], None)
1608            .await
1609            .unwrap();
1610
1611        tokio::time::timeout(Duration::from_secs(2), async {
1612            loop {
1613                if client.is_reconnecting() {
1614                    break;
1615                }
1616                tokio::time::sleep(Duration::from_millis(10)).await;
1617            }
1618        })
1619        .await
1620        .expect("client did not enter RECONNECT state");
1621
1622        client.disconnect().await;
1623        server.abort();
1624    }
1625
1626    #[rstest]
1627    #[tokio::test]
1628    async fn test_stream_mode_disables_auto_reconnect() {
1629        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1630        // and that reconnect() transitions to CLOSED state for stream mode
1631        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1632        let port = listener.local_addr().unwrap().port();
1633
1634        let server = task::spawn(async move {
1635            if let Ok((stream, _)) = listener.accept().await
1636                && let Ok(_ws) = accept_async(stream).await
1637            {
1638                // Keep connection alive briefly
1639                sleep(Duration::from_millis(100)).await;
1640            }
1641        });
1642
1643        let config = WebSocketConfig {
1644            url: format!("ws://127.0.0.1:{port}"),
1645            headers: vec![],
1646            message_handler: None, // Stream mode - no handler
1647            heartbeat: None,
1648            heartbeat_msg: None,
1649            ping_handler: None,
1650            reconnect_timeout_ms: Some(1_000),
1651            reconnect_delay_initial_ms: Some(50),
1652            reconnect_delay_max_ms: Some(100),
1653            reconnect_backoff_factor: Some(1.0),
1654            reconnect_jitter_ms: Some(0),
1655        };
1656
1657        // Create stream-based client
1658        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1659            .await
1660            .unwrap();
1661
1662        // Note: We can't easily test the reconnect behavior from the outside since
1663        // the inner client is private. The key fix is that WebSocketClientInner
1664        // now has is_stream_mode=true for connect_stream, and reconnect() will
1665        // transition to CLOSED state instead of creating a new reader that gets dropped.
1666        // This is tested implicitly by the fact that stream users won't get stuck
1667        // in an infinite reconnect loop.
1668
1669        server.abort();
1670    }
1671
1672    #[rstest]
1673    #[tokio::test]
1674    async fn test_message_handler_mode_allows_auto_reconnect() {
1675        // Test that regular clients (with message handler) can auto-reconnect
1676        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1677        let port = listener.local_addr().unwrap().port();
1678
1679        let server = task::spawn(async move {
1680            // Accept first connection and close it
1681            if let Ok((stream, _)) = listener.accept().await
1682                && let Ok(ws) = accept_async(stream).await
1683            {
1684                drop(ws);
1685            }
1686            sleep(Duration::from_millis(50)).await;
1687        });
1688
1689        let (handler, _rx) = channel_message_handler();
1690
1691        let config = WebSocketConfig {
1692            url: format!("ws://127.0.0.1:{port}"),
1693            headers: vec![],
1694            message_handler: Some(handler), // Has message handler
1695            heartbeat: None,
1696            heartbeat_msg: None,
1697            ping_handler: None,
1698            reconnect_timeout_ms: Some(1_000),
1699            reconnect_delay_initial_ms: Some(50),
1700            reconnect_delay_max_ms: Some(100),
1701            reconnect_backoff_factor: Some(1.0),
1702            reconnect_jitter_ms: Some(0),
1703        };
1704
1705        let client = WebSocketClient::connect(config, None, vec![], None)
1706            .await
1707            .unwrap();
1708
1709        // Wait for the connection to be dropped and reconnection to be attempted
1710        tokio::time::timeout(Duration::from_secs(2), async {
1711            loop {
1712                if client.is_reconnecting() || client.is_closed() {
1713                    break;
1714                }
1715                tokio::time::sleep(Duration::from_millis(10)).await;
1716            }
1717        })
1718        .await
1719        .expect("client should attempt reconnection or close");
1720
1721        // Should either be reconnecting or closed (depending on timing)
1722        // The important thing is it's not staying active forever
1723        assert!(
1724            client.is_reconnecting() || client.is_closed(),
1725            "Client with message handler should attempt reconnection"
1726        );
1727
1728        client.disconnect().await;
1729        server.abort();
1730    }
1731
1732    #[rstest]
1733    #[tokio::test]
1734    async fn test_handler_mode_reconnect_with_new_connection() {
1735        // Test that handler mode successfully reconnects and messages continue flowing
1736        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1737        let port = listener.local_addr().unwrap().port();
1738
1739        let server = task::spawn(async move {
1740            // First connection - accept and immediately close
1741            if let Ok((stream, _)) = listener.accept().await
1742                && let Ok(ws) = accept_async(stream).await
1743            {
1744                drop(ws);
1745            }
1746
1747            // Small delay to let client detect disconnection
1748            sleep(Duration::from_millis(100)).await;
1749
1750            // Second connection - accept, send a message, then keep alive
1751            if let Ok((stream, _)) = listener.accept().await
1752                && let Ok(mut ws) = accept_async(stream).await
1753            {
1754                use futures_util::SinkExt;
1755                let _ = ws
1756                    .send(Message::Text("reconnected".to_string().into()))
1757                    .await;
1758                sleep(Duration::from_secs(1)).await;
1759            }
1760        });
1761
1762        let (handler, mut rx) = channel_message_handler();
1763
1764        let config = WebSocketConfig {
1765            url: format!("ws://127.0.0.1:{port}"),
1766            headers: vec![],
1767            message_handler: Some(handler),
1768            heartbeat: None,
1769            heartbeat_msg: None,
1770            ping_handler: None,
1771            reconnect_timeout_ms: Some(2_000),
1772            reconnect_delay_initial_ms: Some(50),
1773            reconnect_delay_max_ms: Some(200),
1774            reconnect_backoff_factor: Some(1.5),
1775            reconnect_jitter_ms: Some(10),
1776        };
1777
1778        let client = WebSocketClient::connect(config, None, vec![], None)
1779            .await
1780            .unwrap();
1781
1782        // Wait for reconnection to happen and message to arrive
1783        let result = tokio::time::timeout(Duration::from_secs(5), async {
1784            loop {
1785                if let Ok(msg) = rx.try_recv()
1786                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1787                {
1788                    return true;
1789                }
1790                tokio::time::sleep(Duration::from_millis(10)).await;
1791            }
1792        })
1793        .await;
1794
1795        assert!(
1796            result.is_ok(),
1797            "Should receive message after reconnection within timeout"
1798        );
1799
1800        client.disconnect().await;
1801        server.abort();
1802    }
1803
1804    #[rstest]
1805    #[tokio::test]
1806    async fn test_stream_mode_no_auto_reconnect() {
1807        // Test that stream mode does not automatically reconnect when connection is lost
1808        // The caller owns the reader and is responsible for detecting disconnection
1809        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1810        let port = listener.local_addr().unwrap().port();
1811
1812        let server = task::spawn(async move {
1813            // Accept connection and send one message, then close
1814            if let Ok((stream, _)) = listener.accept().await
1815                && let Ok(mut ws) = accept_async(stream).await
1816            {
1817                use futures_util::SinkExt;
1818                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1819                sleep(Duration::from_millis(50)).await;
1820                // Connection closes when ws is dropped
1821            }
1822        });
1823
1824        let config = WebSocketConfig {
1825            url: format!("ws://127.0.0.1:{port}"),
1826            headers: vec![],
1827            message_handler: None, // Stream mode
1828            heartbeat: None,
1829            heartbeat_msg: None,
1830            ping_handler: None,
1831            reconnect_timeout_ms: Some(1_000),
1832            reconnect_delay_initial_ms: Some(50),
1833            reconnect_delay_max_ms: Some(100),
1834            reconnect_backoff_factor: Some(1.0),
1835            reconnect_jitter_ms: Some(0),
1836        };
1837
1838        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1839            .await
1840            .unwrap();
1841
1842        // Initially active
1843        assert!(client.is_active(), "Client should start as active");
1844
1845        // Read the hello message
1846        let msg = reader.next().await;
1847        assert!(
1848            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1849            "Should receive initial message"
1850        );
1851
1852        // Read until connection closes (reader will return None or error)
1853        while let Some(msg) = reader.next().await {
1854            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1855                break;
1856            }
1857        }
1858
1859        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1860        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1861        sleep(Duration::from_millis(200)).await;
1862
1863        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1864        // This is correct behavior - stream mode doesn't auto-detect disconnection
1865        assert!(
1866            client.is_active() || client.is_closed(),
1867            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1868        );
1869        assert!(
1870            !client.is_reconnecting(),
1871            "Stream mode client should never attempt reconnection"
1872        );
1873
1874        client.disconnect().await;
1875        server.abort();
1876    }
1877
1878    #[rstest]
1879    #[tokio::test]
1880    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1881        // Test that send operations respect the configured reconnect_timeout.
1882        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1883        use nautilus_common::testing::wait_until_async;
1884
1885        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1886        let port = listener.local_addr().unwrap().port();
1887
1888        let server = task::spawn(async move {
1889            // Accept first connection and immediately close it
1890            if let Ok((stream, _)) = listener.accept().await
1891                && let Ok(ws) = accept_async(stream).await
1892            {
1893                drop(ws);
1894            }
1895            // Don't accept second connection - client will be stuck in RECONNECT
1896            sleep(Duration::from_secs(60)).await;
1897        });
1898
1899        let (handler, _rx) = channel_message_handler();
1900
1901        // Configure with SHORT 2s reconnect timeout
1902        let config = WebSocketConfig {
1903            url: format!("ws://127.0.0.1:{port}"),
1904            headers: vec![],
1905            message_handler: Some(handler),
1906            heartbeat: None,
1907            heartbeat_msg: None,
1908            ping_handler: None,
1909            reconnect_timeout_ms: Some(2_000), // 2s timeout
1910            reconnect_delay_initial_ms: Some(50),
1911            reconnect_delay_max_ms: Some(100),
1912            reconnect_backoff_factor: Some(1.0),
1913            reconnect_jitter_ms: Some(0),
1914        };
1915
1916        let client = WebSocketClient::connect(config, None, vec![], None)
1917            .await
1918            .unwrap();
1919
1920        // Wait for client to enter RECONNECT state
1921        wait_until_async(
1922            || async { client.is_reconnecting() },
1923            Duration::from_secs(3),
1924        )
1925        .await;
1926
1927        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
1928        let start = std::time::Instant::now();
1929        let send_result = client.send_text("test".to_string(), None).await;
1930        let elapsed = start.elapsed();
1931
1932        assert!(
1933            send_result.is_err(),
1934            "Send should fail when client stuck in RECONNECT"
1935        );
1936        assert!(
1937            matches!(send_result, Err(crate::error::SendError::Timeout)),
1938            "Send should return Timeout error, got: {:?}",
1939            send_result
1940        );
1941        // Verify timeout respects configured value (2s), but don't check upper bound
1942        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1943        assert!(
1944            elapsed >= Duration::from_millis(1800),
1945            "Send should timeout after at least 2s (configured timeout), took {:?}",
1946            elapsed
1947        );
1948
1949        client.disconnect().await;
1950        server.abort();
1951    }
1952
1953    #[rstest]
1954    #[tokio::test]
1955    async fn test_send_waits_during_reconnection() {
1956        // Test that send operations wait for reconnection to complete (up to timeout)
1957        use nautilus_common::testing::wait_until_async;
1958
1959        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1960        let port = listener.local_addr().unwrap().port();
1961
1962        let server = task::spawn(async move {
1963            // First connection - accept and immediately close
1964            if let Ok((stream, _)) = listener.accept().await
1965                && let Ok(ws) = accept_async(stream).await
1966            {
1967                drop(ws);
1968            }
1969
1970            // Wait a bit before accepting second connection
1971            sleep(Duration::from_millis(500)).await;
1972
1973            // Second connection - accept and keep alive
1974            if let Ok((stream, _)) = listener.accept().await
1975                && let Ok(mut ws) = accept_async(stream).await
1976            {
1977                // Echo messages
1978                while let Some(Ok(msg)) = ws.next().await {
1979                    if ws.send(msg).await.is_err() {
1980                        break;
1981                    }
1982                }
1983            }
1984        });
1985
1986        let (handler, _rx) = channel_message_handler();
1987
1988        let config = WebSocketConfig {
1989            url: format!("ws://127.0.0.1:{port}"),
1990            headers: vec![],
1991            message_handler: Some(handler),
1992            heartbeat: None,
1993            heartbeat_msg: None,
1994            ping_handler: None,
1995            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1996            reconnect_delay_initial_ms: Some(100),
1997            reconnect_delay_max_ms: Some(200),
1998            reconnect_backoff_factor: Some(1.0),
1999            reconnect_jitter_ms: Some(0),
2000        };
2001
2002        let client = WebSocketClient::connect(config, None, vec![], None)
2003            .await
2004            .unwrap();
2005
2006        // Wait for reconnection to trigger
2007        wait_until_async(
2008            || async { client.is_reconnecting() },
2009            Duration::from_secs(2),
2010        )
2011        .await;
2012
2013        // Try to send while reconnecting - should wait and succeed after reconnect
2014        let send_result = tokio::time::timeout(
2015            Duration::from_secs(3),
2016            client.send_text("test_message".to_string(), None),
2017        )
2018        .await;
2019
2020        assert!(
2021            send_result.is_ok() && send_result.unwrap().is_ok(),
2022            "Send should succeed after waiting for reconnection"
2023        );
2024
2025        client.disconnect().await;
2026        server.abort();
2027    }
2028
2029    #[rstest]
2030    #[tokio::test]
2031    async fn test_rate_limiter_before_active_wait() {
2032        // Test that rate limiting happens BEFORE active state check.
2033        // This prevents race conditions where connection state changes during rate limit wait.
2034        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2035        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2036        use std::{num::NonZeroU32, sync::Arc};
2037
2038        use nautilus_common::testing::wait_until_async;
2039
2040        use crate::ratelimiter::quota::Quota;
2041
2042        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2043        let port = listener.local_addr().unwrap().port();
2044
2045        let server = task::spawn(async move {
2046            // First connection - accept and close after receiving one message
2047            if let Ok((stream, _)) = listener.accept().await
2048                && let Ok(mut ws) = accept_async(stream).await
2049            {
2050                // Receive first message then close
2051                if let Some(Ok(_)) = ws.next().await {
2052                    drop(ws);
2053                }
2054            }
2055
2056            // Wait before accepting reconnection
2057            sleep(Duration::from_millis(500)).await;
2058
2059            // Second connection - accept and keep alive
2060            if let Ok((stream, _)) = listener.accept().await
2061                && let Ok(mut ws) = accept_async(stream).await
2062            {
2063                while let Some(Ok(msg)) = ws.next().await {
2064                    if ws.send(msg).await.is_err() {
2065                        break;
2066                    }
2067                }
2068            }
2069        });
2070
2071        let (handler, _rx) = channel_message_handler();
2072
2073        let config = WebSocketConfig {
2074            url: format!("ws://127.0.0.1:{port}"),
2075            headers: vec![],
2076            message_handler: Some(handler),
2077            heartbeat: None,
2078            heartbeat_msg: None,
2079            ping_handler: None,
2080            reconnect_timeout_ms: Some(5_000),
2081            reconnect_delay_initial_ms: Some(50),
2082            reconnect_delay_max_ms: Some(100),
2083            reconnect_backoff_factor: Some(1.0),
2084            reconnect_jitter_ms: Some(0),
2085        };
2086
2087        // Very restrictive rate limit: 1 request per second, burst of 1
2088        let quota =
2089            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2090
2091        let client = Arc::new(
2092            WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2093                .await
2094                .unwrap(),
2095        );
2096
2097        // First send exhausts burst capacity and triggers connection close
2098        client
2099            .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2100            .await
2101            .unwrap();
2102
2103        // Wait for client to enter RECONNECT state
2104        wait_until_async(
2105            || async { client.is_reconnecting() },
2106            Duration::from_secs(2),
2107        )
2108        .await;
2109
2110        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2111        let start = std::time::Instant::now();
2112        let send_result = client
2113            .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2114            .await;
2115        let elapsed = start.elapsed();
2116
2117        // Should succeed after both rate limit AND reconnection
2118        assert!(
2119            send_result.is_ok(),
2120            "Send should succeed after rate limit + reconnection, got: {:?}",
2121            send_result
2122        );
2123        // Total wait should be at least rate limit time (~1s)
2124        // The reconnection completes while rate limiting or after
2125        // Use 850ms threshold to account for timing jitter in CI
2126        assert!(
2127            elapsed >= Duration::from_millis(850),
2128            "Should wait for rate limit (~1s), waited {:?}",
2129            elapsed
2130        );
2131
2132        client.disconnect().await;
2133        server.abort();
2134    }
2135
2136    #[rstest]
2137    #[tokio::test]
2138    async fn test_disconnect_during_reconnect_exits_cleanly() {
2139        // Test CAS race condition: disconnect called during reconnection
2140        // Should exit cleanly without spawning new tasks
2141        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2142        let port = listener.local_addr().unwrap().port();
2143
2144        let server = task::spawn(async move {
2145            // Accept first connection and immediately close
2146            if let Ok((stream, _)) = listener.accept().await
2147                && let Ok(ws) = accept_async(stream).await
2148            {
2149                drop(ws);
2150            }
2151            // Don't accept second connection - let reconnect hang
2152            sleep(Duration::from_secs(60)).await;
2153        });
2154
2155        let (handler, _rx) = channel_message_handler();
2156
2157        let config = WebSocketConfig {
2158            url: format!("ws://127.0.0.1:{port}"),
2159            headers: vec![],
2160            message_handler: Some(handler),
2161            heartbeat: None,
2162            heartbeat_msg: None,
2163            ping_handler: None,
2164            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2165            reconnect_delay_initial_ms: Some(100),
2166            reconnect_delay_max_ms: Some(200),
2167            reconnect_backoff_factor: Some(1.0),
2168            reconnect_jitter_ms: Some(0),
2169        };
2170
2171        let client = WebSocketClient::connect(config, None, vec![], None)
2172            .await
2173            .unwrap();
2174
2175        // Wait for reconnection to start
2176        tokio::time::timeout(Duration::from_secs(2), async {
2177            while !client.is_reconnecting() {
2178                sleep(Duration::from_millis(10)).await;
2179            }
2180        })
2181        .await
2182        .expect("Client should enter RECONNECT state");
2183
2184        // Disconnect while reconnecting
2185        client.disconnect().await;
2186
2187        // Should be cleanly closed
2188        assert!(
2189            client.is_disconnected(),
2190            "Client should be cleanly disconnected"
2191        );
2192
2193        server.abort();
2194    }
2195}