nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50    config::WebSocketConfig,
51    consts::{
52        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54    },
55    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60    RECONNECTED,
61    backoff::ExponentialBackoff,
62    error::SendError,
63    logging::{log_task_aborted, log_task_started, log_task_stopped},
64    mode::ConnectionMode,
65    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
68/// `WebSocketClient` connects to a websocket server to read and send messages.
69///
70/// The client is opinionated about how messages are read and written. It
71/// assumes that data can only have one reader but multiple writers.
72///
73/// The client splits the connection into read and write halves. It moves
74/// the read half into a tokio task which keeps receiving messages from the
75/// server and calls a handler - a Python function that takes the data
76/// as its parameter. It stores the write half in the struct wrapped
77/// with an Arc Mutex. This way the client struct can be used to write
78/// data to the server from multiple scopes/tasks.
79///
80/// The client also maintains a heartbeat if given a duration in seconds.
81/// It's preferable to set the duration slightly lower - heartbeat more
82/// frequently - than the required amount.
83pub struct WebSocketClientInner {
84    config: WebSocketConfig,
85    /// The function to handle incoming messages (stored separately from config).
86    message_handler: Option<MessageHandler>,
87    /// The handler for incoming pings (stored separately from config).
88    ping_handler: Option<PingHandler>,
89    read_task: Option<tokio::task::JoinHandle<()>>,
90    write_task: tokio::task::JoinHandle<()>,
91    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
92    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
93    connection_mode: Arc<AtomicU8>,
94    reconnect_timeout: Duration,
95    backoff: ExponentialBackoff,
96    /// True if this is a stream-based client (created via `connect_stream`).
97    /// Stream-based clients disable auto-reconnect because the reader is
98    /// owned by the caller and cannot be replaced during reconnection.
99    is_stream_mode: bool,
100    /// Maximum number of reconnection attempts before giving up (None = unlimited).
101    reconnect_max_attempts: Option<u32>,
102    /// Current count of consecutive reconnection attempts.
103    reconnection_attempt_count: u32,
104}
105
106impl WebSocketClientInner {
107    /// Create an inner websocket client with an existing writer.
108    ///
109    /// This is used for stream mode where the reader is owned by the caller.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if the exponential backoff configuration is invalid.
114    pub async fn new_with_writer(
115        config: WebSocketConfig,
116        writer: MessageWriter,
117    ) -> Result<Self, Error> {
118        install_cryptographic_provider();
119
120        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
121
122        // Note: We don't spawn a read task here since the reader is handled externally
123        let read_task = None;
124
125        let backoff = ExponentialBackoff::new(
126            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
127            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
128            config.reconnect_backoff_factor.unwrap_or(1.5),
129            config.reconnect_jitter_ms.unwrap_or(100),
130            true, // immediate-first
131        )
132        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
133
134        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
135        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
136
137        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
138            Some(Self::spawn_heartbeat_task(
139                connection_mode.clone(),
140                heartbeat_interval,
141                config.heartbeat_msg.clone(),
142                writer_tx.clone(),
143            ))
144        } else {
145            None
146        };
147
148        let reconnect_max_attempts = config.reconnect_max_attempts;
149        let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
150
151        Ok(Self {
152            config,
153            message_handler: None, // Stream mode has no handler
154            ping_handler: None,
155            writer_tx,
156            connection_mode,
157            reconnect_timeout,
158            heartbeat_task,
159            read_task,
160            write_task,
161            backoff,
162            is_stream_mode: true,
163            reconnect_max_attempts,
164            reconnection_attempt_count: 0,
165        })
166    }
167
168    /// Create an inner websocket client.
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if:
173    /// - The connection to the server fails.
174    /// - The exponential backoff configuration is invalid.
175    pub async fn connect_url(
176        config: WebSocketConfig,
177        message_handler: Option<MessageHandler>,
178        ping_handler: Option<PingHandler>,
179    ) -> Result<Self, Error> {
180        install_cryptographic_provider();
181
182        // Capture whether we're in stream mode before moving config
183        let is_stream_mode = message_handler.is_none();
184        let reconnect_max_attempts = config.reconnect_max_attempts;
185
186        let (writer, reader) =
187            Self::connect_with_server(&config.url, config.headers.clone()).await?;
188
189        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
190
191        let read_task = if message_handler.is_some() {
192            Some(Self::spawn_message_handler_task(
193                connection_mode.clone(),
194                reader,
195                message_handler.as_ref(),
196                ping_handler.as_ref(),
197            ))
198        } else {
199            None
200        };
201
202        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
203        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
204
205        // Optionally spawn a heartbeat task to periodically ping server
206        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
207            Self::spawn_heartbeat_task(
208                connection_mode.clone(),
209                heartbeat_secs,
210                config.heartbeat_msg.clone(),
211                writer_tx.clone(),
212            )
213        });
214
215        let reconnect_timeout =
216            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
217        let backoff = ExponentialBackoff::new(
218            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
219            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
220            config.reconnect_backoff_factor.unwrap_or(1.5),
221            config.reconnect_jitter_ms.unwrap_or(100),
222            true, // immediate-first
223        )
224        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
225
226        Ok(Self {
227            config,
228            message_handler,
229            ping_handler,
230            read_task,
231            write_task,
232            writer_tx,
233            heartbeat_task,
234            connection_mode,
235            reconnect_timeout,
236            backoff,
237            // Set stream mode when no message handler (reader not managed by client)
238            is_stream_mode,
239            reconnect_max_attempts,
240            reconnection_attempt_count: 0,
241        })
242    }
243
244    /// Connects with the server creating a tokio-tungstenite websocket stream.
245    /// Production version that uses `connect_async_with_config` convenience helper.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if:
250    /// - The URL cannot be parsed into a valid client request.
251    /// - Header values are invalid.
252    /// - The WebSocket connection fails.
253    #[inline]
254    #[cfg(not(feature = "turmoil"))]
255    pub async fn connect_with_server(
256        url: &str,
257        headers: Vec<(String, String)>,
258    ) -> Result<(MessageWriter, MessageReader), Error> {
259        let mut request = url.into_client_request()?;
260        let req_headers = request.headers_mut();
261
262        let mut header_names: Vec<HeaderName> = Vec::new();
263        for (key, val) in headers {
264            let header_value = HeaderValue::from_str(&val)?;
265            let header_name: HeaderName = key.parse()?;
266            header_names.push(header_name.clone());
267            req_headers.insert(header_name, header_value);
268        }
269
270        connect_async_with_config(request, None, true)
271            .await
272            .map(|resp| resp.0.split())
273    }
274
/// Connects with the server creating a tokio-tungstenite websocket stream.
/// Turmoil version that uses the lower-level `client_async` API with injected stream.
///
/// The TCP stream is obtained through the crate's connector, optionally
/// wrapped in TLS for `wss` URLs, and then upgraded with the WebSocket
/// handshake. The established stream is returned split into its write and
/// read halves.
///
/// # Errors
///
/// Returns an error if:
/// - The URL cannot be parsed into a valid client request.
/// - The URL is missing a hostname.
/// - Header values are invalid.
/// - The TCP connection fails.
/// - TLS setup fails (for wss:// URLs).
/// - The WebSocket handshake fails.
#[inline]
#[cfg(feature = "turmoil")]
pub async fn connect_with_server(
    url: &str,
    headers: Vec<(String, String)>,
) -> Result<(MessageWriter, MessageReader), Error> {
    use rustls::ClientConfig;
    use tokio_rustls::TlsConnector;

    let mut request = url.into_client_request()?;
    let req_headers = request.headers_mut();

    for (key, val) in headers {
        // The value is validated before the name, so a pair with both
        // parts invalid reports the value error.
        let header_value = HeaderValue::from_str(&val)?;
        let header_name: HeaderName = key.parse()?;
        req_headers.insert(header_name, header_value);
    }

    let uri = request.uri();
    let scheme = uri.scheme_str().unwrap_or("ws");
    let host = uri.host().ok_or_else(|| {
        Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
    })?;

    // Determine port: use explicit port if specified, otherwise default based on scheme
    let port = uri
        .port_u16()
        .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });

    let addr = format!("{host}:{port}");

    // Use the connector to get a turmoil-compatible stream
    let connector = crate::net::RealTcpConnector;
    let tcp_stream = connector.connect(&addr).await?;
    // Failing to set TCP_NODELAY only costs latency, so it is logged and
    // the connection attempt continues.
    if let Err(e) = tcp_stream.set_nodelay(true) {
        tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
    }

    // Wrap stream appropriately based on scheme
    let maybe_tls_stream = if scheme == "wss" {
        // Build TLS config with webpki roots
        let mut root_store = rustls::RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

        let config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();

        let tls_connector = TlsConnector::from(Arc::new(config));
        let domain =
            rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
                Error::Io(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid DNS name: {e}"),
                ))
            })?;

        let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
        MaybeTlsStream::Rustls(tls_stream)
    } else {
        MaybeTlsStream::Plain(tcp_stream)
    };

    // Use client_async with the stream (plain or TLS)
    client_async(request, maybe_tls_stream)
        .await
        .map(|resp| resp.0.split())
}
357
358    /// Reconnect with server.
359    ///
360    /// Make a new connection with server. Use the new read and write halves
361    /// to update self writer and read and heartbeat tasks.
362    ///
363    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
364    /// because the reader is owned by the caller and cannot be replaced. Stream users
365    /// should handle disconnections by creating a new connection.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if:
370    /// - The reconnection attempt times out.
371    /// - The connection to the server fails.
372    pub async fn reconnect(&mut self) -> Result<(), Error> {
373        tracing::debug!("Reconnecting");
374
375        if self.is_stream_mode {
376            tracing::warn!(
377                "Auto-reconnect disabled for stream-based WebSocket client; \
378                stream users must manually reconnect by creating a new connection"
379            );
380            // Transition to CLOSED state to stop reconnection attempts
381            self.connection_mode
382                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
383            return Ok(());
384        }
385
386        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
387            tracing::debug!("Reconnect aborted due to disconnect state");
388            return Ok(());
389        }
390
391        tokio::time::timeout(self.reconnect_timeout, async {
392            // Attempt to connect; abort early if a disconnect was requested
393            let (new_writer, reader) =
394                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
395
396            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
397                tracing::debug!("Reconnect aborted mid-flight (after connect)");
398                return Ok(());
399            }
400
401            // Use a oneshot channel to synchronize with the writer task.
402            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
403            // to prevent silent message loss if the new connection drops immediately.
404            let (tx, rx) = tokio::sync::oneshot::channel();
405            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
406                tracing::error!("{e}");
407                return Err(Error::Io(std::io::Error::new(
408                    std::io::ErrorKind::BrokenPipe,
409                    format!("Failed to send update command: {e}"),
410                )));
411            }
412
413            // Wait for writer to confirm it has drained the buffer
414            match rx.await {
415                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
416                Ok(false) => {
417                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
418                    // Return error to trigger retry logic in controller
419                    return Err(Error::Io(std::io::Error::other(
420                        "Failed to drain reconnection buffer",
421                    )));
422                }
423                Err(e) => {
424                    tracing::error!("Writer dropped update channel: {e}");
425                    return Err(Error::Io(std::io::Error::new(
426                        std::io::ErrorKind::BrokenPipe,
427                        "Writer task dropped response channel",
428                    )));
429                }
430            }
431
432            // Delay before closing connection
433            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
434
435            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
436                tracing::debug!("Reconnect aborted mid-flight (after delay)");
437                return Ok(());
438            }
439
440            if let Some(ref read_task) = self.read_task.take()
441                && !read_task.is_finished()
442            {
443                read_task.abort();
444                log_task_aborted("read");
445            }
446
447            // Atomically transition from Reconnect to Active
448            // This prevents race condition where disconnect could be requested between check and store
449            if self
450                .connection_mode
451                .compare_exchange(
452                    ConnectionMode::Reconnect.as_u8(),
453                    ConnectionMode::Active.as_u8(),
454                    Ordering::SeqCst,
455                    Ordering::SeqCst,
456                )
457                .is_err()
458            {
459                tracing::debug!("Reconnect aborted (state changed during reconnect)");
460                return Ok(());
461            }
462
463            self.read_task = if self.message_handler.is_some() {
464                Some(Self::spawn_message_handler_task(
465                    self.connection_mode.clone(),
466                    reader,
467                    self.message_handler.as_ref(),
468                    self.ping_handler.as_ref(),
469                ))
470            } else {
471                None
472            };
473
474            tracing::debug!("Reconnect succeeded");
475            Ok(())
476        })
477        .await
478        .map_err(|_| {
479            Error::Io(std::io::Error::new(
480                std::io::ErrorKind::TimedOut,
481                format!(
482                    "reconnection timed out after {}s",
483                    self.reconnect_timeout.as_secs_f64()
484                ),
485            ))
486        })?
487    }
488
489    /// Check if the client is still connected.
490    ///
491    /// The client is connected if the read task has not finished. It is expected
492    /// that in case of any failure client or server side. The read task will be
493    /// shutdown or will receive a `Close` frame which will finish it. There
494    /// might be some delay between the connection being closed and the client
495    /// detecting.
496    #[inline]
497    #[must_use]
498    pub fn is_alive(&self) -> bool {
499        match &self.read_task {
500            Some(read_task) => !read_task.is_finished(),
501            None => true, // Stream is being used directly
502        }
503    }
504
505    fn spawn_message_handler_task(
506        connection_state: Arc<AtomicU8>,
507        mut reader: MessageReader,
508        message_handler: Option<&MessageHandler>,
509        ping_handler: Option<&PingHandler>,
510    ) -> tokio::task::JoinHandle<()> {
511        tracing::debug!("Started message handler task 'read'");
512
513        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
514
515        // Clone Arc handlers for the async task
516        let message_handler = message_handler.cloned();
517        let ping_handler = ping_handler.cloned();
518
519        tokio::task::spawn(async move {
520            loop {
521                if !ConnectionMode::from_atomic(&connection_state).is_active() {
522                    break;
523                }
524
525                match tokio::time::timeout(check_interval, reader.next()).await {
526                    Ok(Some(Ok(Message::Binary(data)))) => {
527                        tracing::trace!("Received message <binary> {} bytes", data.len());
528                        if let Some(ref handler) = message_handler {
529                            handler(Message::Binary(data));
530                        }
531                    }
532                    Ok(Some(Ok(Message::Text(data)))) => {
533                        tracing::trace!("Received message: {data}");
534                        if let Some(ref handler) = message_handler {
535                            handler(Message::Text(data));
536                        }
537                    }
538                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
539                        tracing::trace!("Received ping: {ping_data:?}");
540                        if let Some(ref handler) = ping_handler {
541                            handler(ping_data.to_vec());
542                        }
543                    }
544                    Ok(Some(Ok(Message::Pong(_)))) => {
545                        tracing::trace!("Received pong");
546                    }
547                    Ok(Some(Ok(Message::Close(_)))) => {
548                        tracing::debug!("Received close message - terminating");
549                        break;
550                    }
551                    Ok(Some(Ok(_))) => (),
552                    Ok(Some(Err(e))) => {
553                        tracing::error!("Received error message - terminating: {e}");
554                        break;
555                    }
556                    Ok(None) => {
557                        tracing::debug!("No message received - terminating");
558                        break;
559                    }
560                    Err(_) => {
561                        // Timeout - continue loop and check connection mode
562                        continue;
563                    }
564                }
565            }
566        })
567    }
568
569    /// Attempts to send all buffered messages after reconnection.
570    ///
571    /// Returns `true` if a send error occurred (caller should trigger reconnection).
572    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
573    async fn drain_reconnect_buffer(
574        buffer: &mut VecDeque<Message>,
575        writer: &mut MessageWriter,
576    ) -> bool {
577        if buffer.is_empty() {
578            return false;
579        }
580
581        let initial_buffer_len = buffer.len();
582        tracing::info!(
583            "Sending {} buffered messages after reconnection",
584            initial_buffer_len
585        );
586
587        let mut send_error_occurred = false;
588
589        while let Some(buffered_msg) = buffer.front() {
590            // Clone message before attempting send (to keep in buffer if send fails)
591            let msg_to_send = buffered_msg.clone();
592
593            if let Err(e) = writer.send(msg_to_send).await {
594                tracing::error!(
595                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
596                    buffer.len()
597                );
598                send_error_occurred = true;
599                break; // Stop processing buffer, remaining messages preserved for next reconnection
600            }
601
602            // Only remove from buffer after successful send
603            buffer.pop_front();
604        }
605
606        if buffer.is_empty() {
607            tracing::info!(
608                "Successfully sent all {} buffered messages",
609                initial_buffer_len
610            );
611        }
612
613        send_error_occurred
614    }
615
616    fn spawn_write_task(
617        connection_state: Arc<AtomicU8>,
618        writer: MessageWriter,
619        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
620    ) -> tokio::task::JoinHandle<()> {
621        log_task_started("write");
622
623        // Interval between checking the connection mode
624        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
625
626        tokio::task::spawn(async move {
627            let mut active_writer = writer;
628            // Buffer for messages received during reconnection
629            // VecDeque for efficient pop_front() operations
630            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
631
632            loop {
633                match ConnectionMode::from_atomic(&connection_state) {
634                    ConnectionMode::Disconnect => {
635                        // Log any buffered messages that will be lost
636                        if !reconnect_buffer.is_empty() {
637                            tracing::warn!(
638                                "Discarding {} buffered messages due to disconnect",
639                                reconnect_buffer.len()
640                            );
641                            reconnect_buffer.clear();
642                        }
643
644                        // Attempt to close the writer gracefully before exiting,
645                        // we ignore any error as the writer may already be closed.
646                        _ = tokio::time::timeout(
647                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
648                            active_writer.close(),
649                        )
650                        .await;
651                        break;
652                    }
653                    ConnectionMode::Closed => {
654                        // Log any buffered messages that will be lost
655                        if !reconnect_buffer.is_empty() {
656                            tracing::warn!(
657                                "Discarding {} buffered messages due to closed connection",
658                                reconnect_buffer.len()
659                            );
660                            reconnect_buffer.clear();
661                        }
662                        break;
663                    }
664                    _ => {}
665                }
666
667                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
668                    Ok(Some(msg)) => {
669                        // Re-check connection mode after receiving a message
670                        let mode = ConnectionMode::from_atomic(&connection_state);
671                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
672                            break;
673                        }
674
675                        match msg {
676                            WriterCommand::Update(new_writer, tx) => {
677                                tracing::debug!("Received new writer");
678
679                                // Delay before closing connection
680                                tokio::time::sleep(Duration::from_millis(100)).await;
681
682                                // Attempt to close the writer gracefully on update,
683                                // we ignore any error as the writer may already be closed.
684                                _ = tokio::time::timeout(
685                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
686                                    active_writer.close(),
687                                )
688                                .await;
689
690                                active_writer = new_writer;
691                                tracing::debug!("Updated writer");
692
693                                let send_error = Self::drain_reconnect_buffer(
694                                    &mut reconnect_buffer,
695                                    &mut active_writer,
696                                )
697                                .await;
698
699                                if let Err(e) = tx.send(!send_error) {
700                                    tracing::error!(
701                                        "Failed to report drain status to controller: {e:?}"
702                                    );
703                                }
704                            }
705                            WriterCommand::Send(msg) if mode.is_reconnect() => {
706                                // Buffer messages during reconnection instead of dropping them
707                                tracing::debug!(
708                                    "Buffering message during reconnection (buffer size: {})",
709                                    reconnect_buffer.len() + 1
710                                );
711                                reconnect_buffer.push_back(msg);
712                            }
713                            WriterCommand::Send(msg) => {
714                                if let Err(e) = active_writer.send(msg.clone()).await {
715                                    tracing::error!("Failed to send message: {e}");
716                                    tracing::warn!("Writer triggering reconnect");
717                                    reconnect_buffer.push_back(msg);
718                                    connection_state
719                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
720                                }
721                            }
722                        }
723                    }
724                    Ok(None) => {
725                        // Channel closed - writer task should terminate
726                        tracing::debug!("Writer channel closed, terminating writer task");
727                        break;
728                    }
729                    Err(_) => {
730                        // Timeout - just continue the loop
731                        continue;
732                    }
733                }
734            }
735
736            // Attempt to close the writer gracefully before exiting,
737            // we ignore any error as the writer may already be closed.
738            _ = tokio::time::timeout(
739                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
740                active_writer.close(),
741            )
742            .await;
743
744            log_task_stopped("write");
745        })
746    }
747
748    fn spawn_heartbeat_task(
749        connection_state: Arc<AtomicU8>,
750        heartbeat_secs: u64,
751        message: Option<String>,
752        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
753    ) -> tokio::task::JoinHandle<()> {
754        log_task_started("heartbeat");
755
756        tokio::task::spawn(async move {
757            let interval = Duration::from_secs(heartbeat_secs);
758
759            loop {
760                tokio::time::sleep(interval).await;
761
762                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
763                    ConnectionMode::Active => {
764                        let msg = match &message {
765                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
766                            None => WriterCommand::Send(Message::Ping(vec![].into())),
767                        };
768
769                        match writer_tx.send(msg) {
770                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
771                            Err(e) => {
772                                tracing::error!("Failed to send heartbeat to writer task: {e}");
773                            }
774                        }
775                    }
776                    ConnectionMode::Reconnect => continue,
777                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
778                }
779            }
780
781            log_task_stopped("heartbeat");
782        })
783    }
784}
785
786impl Drop for WebSocketClientInner {
787    fn drop(&mut self) {
788        // Delegate to explicit cleanup handler
789        self.clean_drop();
790    }
791}
792
793/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
794impl CleanDrop for WebSocketClientInner {
795    fn clean_drop(&mut self) {
796        if let Some(ref read_task) = self.read_task.take()
797            && !read_task.is_finished()
798        {
799            read_task.abort();
800            log_task_aborted("read");
801        }
802
803        if !self.write_task.is_finished() {
804            self.write_task.abort();
805            log_task_aborted("write");
806        }
807
808        if let Some(ref handle) = self.heartbeat_task.take()
809            && !handle.is_finished()
810        {
811            handle.abort();
812            log_task_aborted("heartbeat");
813        }
814
815        // Clear handlers to break potential reference cycles
816        self.message_handler = None;
817        self.ping_handler = None;
818    }
819}
820
821impl Debug for WebSocketClientInner {
822    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823        f.debug_struct("WebSocketClientInner")
824            .field("config", &self.config)
825            .field(
826                "connection_mode",
827                &ConnectionMode::from_atomic(&self.connection_mode),
828            )
829            .field("reconnect_timeout", &self.reconnect_timeout)
830            .field("is_stream_mode", &self.is_stream_mode)
831            .finish()
832    }
833}
834
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task; `is_disconnected` reports whether it has finished.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode atomic shared with the controller task.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text` and `send_bytes` before queueing a message.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Channel used to hand outgoing messages to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
850
851impl Debug for WebSocketClient {
852    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
853        f.debug_struct(stringify!(WebSocketClient)).finish()
854    }
855}
856
857impl WebSocketClient {
858    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
859    ///
860    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
861    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
862    /// client transitions to CLOSED state and the caller must manually reconnect by calling
863    /// `connect_stream` again.
864    ///
865    /// Use stream mode when you need custom reconnection logic, direct control over message
866    /// reading, or fine-grained backpressure handling.
867    ///
868    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
869    ///
870    /// # Errors
871    ///
872    /// Returns an error if the connection cannot be established.
873    #[allow(clippy::too_many_arguments)]
874    pub async fn connect_stream(
875        config: WebSocketConfig,
876        keyed_quotas: Vec<(String, Quota)>,
877        default_quota: Option<Quota>,
878        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
879    ) -> Result<(MessageReader, Self), Error> {
880        install_cryptographic_provider();
881
882        // Create a single connection and split it, respecting configured headers
883        let (writer, reader) =
884            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
885
886        // Create inner without connecting (we'll provide the writer)
887        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
888
889        let connection_mode = inner.connection_mode.clone();
890        let reconnect_timeout = inner.reconnect_timeout;
891        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
892        let writer_tx = inner.writer_tx.clone();
893
894        let controller_task =
895            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
896
897        Ok((
898            reader,
899            Self {
900                controller_task,
901                connection_mode,
902                reconnect_timeout,
903                rate_limiter,
904                writer_tx,
905            },
906        ))
907    }
908
909    /// Creates a websocket client in **handler mode** with automatic reconnection.
910    ///
911    /// The handler is called for each incoming message on an internal task.
912    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
913    /// the client automatically attempts to reconnect and replaces the internal reader
914    /// (the handler continues working seamlessly).
915    ///
916    /// Use handler mode for simplified connection management, automatic reconnection, Python
917    /// bindings, or callback-based message handling.
918    ///
919    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
920    ///
921    /// # Errors
922    ///
923    /// Returns an error if:
924    /// - The connection cannot be established.
925    /// - `message_handler` is `None` (use `connect_stream` instead).
926    pub async fn connect(
927        config: WebSocketConfig,
928        message_handler: Option<MessageHandler>,
929        ping_handler: Option<PingHandler>,
930        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
931        keyed_quotas: Vec<(String, Quota)>,
932        default_quota: Option<Quota>,
933    ) -> Result<Self, Error> {
934        // Validate that handler mode has a message handler
935        if message_handler.is_none() {
936            return Err(Error::Io(std::io::Error::new(
937                std::io::ErrorKind::InvalidInput,
938                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
939            )));
940        }
941
942        tracing::debug!("Connecting");
943        let inner =
944            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
945        let connection_mode = inner.connection_mode.clone();
946        let writer_tx = inner.writer_tx.clone();
947        let reconnect_timeout = inner.reconnect_timeout;
948
949        let controller_task =
950            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
951
952        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
953
954        Ok(Self {
955            controller_task,
956            connection_mode,
957            reconnect_timeout,
958            rate_limiter,
959            writer_tx,
960        })
961    }
962
    /// Returns the current connection mode, decoded from the shared atomic.
    #[must_use]
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
968
969    /// Returns a clone of the connection mode atomic for external state tracking.
970    ///
971    /// This allows adapter clients to track connection state across reconnections
972    /// without message-passing delays.
973    #[must_use]
974    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
975        Arc::clone(&self.connection_mode)
976    }
977
    /// Check if the client connection is active.
    ///
    /// Returns `true` if the current connection mode reports active, i.e. the client is
    /// connected and has not been signalled to disconnect.
    #[inline]
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.connection_mode().is_active()
    }
987
    /// Check if the client is disconnected.
    ///
    /// Returns `true` once the controller task has finished. This is derived from the
    /// task handle, not from the connection mode atomic.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.controller_task.is_finished()
    }
993
    /// Check if the client is reconnecting.
    ///
    /// Returns `true` if the client lost connection and is attempting to reestablish it.
    /// The client will automatically retry connection based on its configuration.
    /// The value is a snapshot of the connection mode at the time of the call.
    #[inline]
    #[must_use]
    pub fn is_reconnecting(&self) -> bool {
        self.connection_mode().is_reconnect()
    }
1003
    /// Check if the client is disconnecting.
    ///
    /// Returns `true` if the client is in disconnect mode, which is the mode
    /// `disconnect` requests before waiting for the controller task to finish.
    #[inline]
    #[must_use]
    pub fn is_disconnecting(&self) -> bool {
        self.connection_mode().is_disconnect()
    }
1012
    /// Check if the client is closed.
    ///
    /// Returns `true` if the client has been explicitly disconnected or reached
    /// maximum reconnection attempts. In this state, the client cannot be reused
    /// and a new client must be created for further connections.
    /// Send operations check this first and fail with `SendError::Closed`.
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.connection_mode().is_closed()
    }
1023
1024    /// Wait for the client to become active before sending.
1025    ///
1026    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1027    async fn wait_for_active(&self) -> Result<(), SendError> {
1028        if self.is_closed() {
1029            return Err(SendError::Closed);
1030        }
1031
1032        let timeout = self.reconnect_timeout;
1033        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1034
1035        if !self.is_active() {
1036            tracing::debug!("Waiting for client to become ACTIVE before sending...");
1037
1038            let inner = tokio::time::timeout(timeout, async {
1039                loop {
1040                    if self.is_active() {
1041                        return Ok(());
1042                    }
1043                    if matches!(
1044                        self.connection_mode(),
1045                        ConnectionMode::Disconnect | ConnectionMode::Closed
1046                    ) {
1047                        return Err(());
1048                    }
1049                    tokio::time::sleep(check_interval).await;
1050                }
1051            })
1052            .await
1053            .map_err(|_| SendError::Timeout)?;
1054            inner.map_err(|()| SendError::Closed)?;
1055        }
1056
1057        Ok(())
1058    }
1059
1060    /// Set disconnect mode to true.
1061    ///
1062    /// Controller task will periodically check the disconnect mode
1063    /// and shutdown the client if it is alive
1064    pub async fn disconnect(&self) {
1065        tracing::debug!("Disconnecting");
1066        self.connection_mode
1067            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1068
1069        if let Ok(()) =
1070            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1071                while !self.is_disconnected() {
1072                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1073                        .await;
1074                }
1075
1076                if !self.controller_task.is_finished() {
1077                    self.controller_task.abort();
1078                    log_task_aborted("controller");
1079                }
1080            })
1081            .await
1082        {
1083            tracing::debug!("Controller task finished");
1084        } else {
1085            tracing::error!("Timeout waiting for controller task to finish");
1086            if !self.controller_task.is_finished() {
1087                self.controller_task.abort();
1088                log_task_aborted("controller");
1089            }
1090        }
1091    }
1092
1093    /// Sends the given text `data` to the server.
1094    ///
1095    /// # Errors
1096    ///
1097    /// Returns a websocket error if unable to send.
1098    #[allow(unused_variables)]
1099    pub async fn send_text(
1100        &self,
1101        data: String,
1102        keys: Option<Vec<String>>,
1103    ) -> Result<(), SendError> {
1104        // Check connection state before rate limiting to fail fast
1105        if self.is_closed() || self.is_disconnecting() {
1106            return Err(SendError::Closed);
1107        }
1108
1109        self.rate_limiter.await_keys_ready(keys).await;
1110        self.wait_for_active().await?;
1111
1112        tracing::trace!("Sending text: {data:?}");
1113
1114        let msg = Message::Text(data.into());
1115        self.writer_tx
1116            .send(WriterCommand::Send(msg))
1117            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1118    }
1119
1120    /// Sends a pong frame back to the server.
1121    ///
1122    /// # Errors
1123    ///
1124    /// Returns a websocket error if unable to send.
1125    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1126        self.wait_for_active().await?;
1127
1128        tracing::trace!("Sending pong frame ({} bytes)", data.len());
1129
1130        let msg = Message::Pong(data.into());
1131        self.writer_tx
1132            .send(WriterCommand::Send(msg))
1133            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1134    }
1135
1136    /// Sends the given bytes `data` to the server.
1137    ///
1138    /// # Errors
1139    ///
1140    /// Returns a websocket error if unable to send.
1141    #[allow(unused_variables)]
1142    pub async fn send_bytes(
1143        &self,
1144        data: Vec<u8>,
1145        keys: Option<Vec<String>>,
1146    ) -> Result<(), SendError> {
1147        // Check connection state before rate limiting to fail fast
1148        if self.is_closed() || self.is_disconnecting() {
1149            return Err(SendError::Closed);
1150        }
1151
1152        self.rate_limiter.await_keys_ready(keys).await;
1153        self.wait_for_active().await?;
1154
1155        tracing::trace!("Sending bytes: {data:?}");
1156
1157        let msg = Message::Binary(data.into());
1158        self.writer_tx
1159            .send(WriterCommand::Send(msg))
1160            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1161    }
1162
1163    /// Sends a close message to the server.
1164    ///
1165    /// # Errors
1166    ///
1167    /// Returns a websocket error if unable to send.
1168    pub async fn send_close_message(&self) -> Result<(), SendError> {
1169        self.wait_for_active().await?;
1170
1171        let msg = Message::Close(None);
1172        self.writer_tx
1173            .send(WriterCommand::Send(msg))
1174            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1175    }
1176
    /// Spawns the controller task that owns `inner` and drives the connection state machine.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` milliseconds the task reads
    /// `connection_mode` and acts on it:
    /// - Disconnect: after a short delay, aborts the read and heartbeat tasks, then exits.
    /// - Closed: exits.
    /// - Active with a connection that is no longer alive: switches to Reconnect.
    /// - Reconnect: attempts to reconnect, backing off after a failure and giving up
    ///   (transition to Closed) once `reconnect_max_attempts` is reached.
    ///
    /// After a successful reconnect, and only if the mode is active, the message handler
    /// receives a `RECONNECTED` text message and `post_reconnection` is called.
    ///
    /// On exit the task stores the Closed mode through `inner.connection_mode`.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");
                    break; // Controller finished
                }

                if mode.is_closed() {
                    tracing::debug!("Connection closed");
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Only move Active -> Reconnect; a concurrent change of mode wins
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read so the reconnect branch below sees the mode actually stored
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    // Check if max reconnection attempts exceeded
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnection_attempt_count >= max_attempts
                    {
                        tracing::error!(
                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                            max_attempts
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    inner.reconnection_attempt_count += 1;
                    tracing::debug!(
                        "Reconnection attempt {} of {}",
                        inner.reconnection_attempt_count,
                        inner
                            .reconnect_max_attempts
                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                    );

                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();
                            inner.reconnection_attempt_count = 0; // Reset counter on success

                            // Only invoke callbacks if not in disconnect state
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    tracing::debug!("Sent reconnected message to handler");
                                }

                                // TODO: Retain this legacy callback for use from Python
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Wait out the backoff before the next loop iteration retries
                            let duration = inner.backoff.next_duration();
                            tracing::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnection_attempt_count
                            );
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Whatever ended the loop, the final observable state is Closed
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1313}
1314
1315// Abort controller task on drop to clean up background tasks
1316impl Drop for WebSocketClient {
1317    fn drop(&mut self) {
1318        if !self.controller_task.is_finished() {
1319            self.controller_task.abort();
1320            log_task_aborted("controller");
1321        }
1322    }
1323}
1324
1325#[cfg(test)]
1326#[cfg(not(feature = "turmoil"))]
1327#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1328mod tests {
1329    use std::{num::NonZeroU32, sync::Arc};
1330
1331    use futures_util::{SinkExt, StreamExt};
1332    use tokio::{
1333        net::TcpListener,
1334        task::{self, JoinHandle},
1335    };
1336    use tokio_tungstenite::{
1337        accept_hdr_async,
1338        tungstenite::{
1339            handshake::server::{self, Callback},
1340            http::HeaderValue,
1341        },
1342    };
1343
1344    use crate::{
1345        ratelimiter::quota::Quota,
1346        websocket::{WebSocketClient, WebSocketConfig},
1347    };
1348
    /// Local echo server used by the tests in this module.
    struct TestServer {
        /// Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener was bound to.
        port: u16,
    }
1353
    /// Handshake callback asserting that the request carries an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present.
        key: String,
        /// Value the header must have.
        value: HeaderValue,
    }
1359
1360    impl Callback for TestCallback {
1361        #[allow(clippy::panic_in_result_fn)]
1362        fn on_request(
1363            self,
1364            request: &server::Request,
1365            response: server::Response,
1366        ) -> Result<server::Response, server::ErrorResponse> {
1367            let _ = response;
1368            let value = request.headers().get(&self.key);
1369            assert!(value.is_some());
1370
1371            if let Some(value) = request.headers().get(&self.key) {
1372                assert_eq!(value, self.value);
1373            }
1374
1375            Ok(response)
1376        }
1377    }
1378
1379    impl TestServer {
1380        async fn setup() -> Self {
1381            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1382            let port = TcpListener::local_addr(&server).unwrap().port();
1383
1384            let header_key = "test".to_string();
1385            let header_value = "test".to_string();
1386
1387            let test_call_back = TestCallback {
1388                key: header_key,
1389                value: HeaderValue::from_str(&header_value).unwrap(),
1390            };
1391
1392            let task = task::spawn(async move {
1393                // Keep accepting connections
1394                loop {
1395                    let (conn, _) = server.accept().await.unwrap();
1396                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1397                        .await
1398                        .unwrap();
1399
1400                    task::spawn(async move {
1401                        while let Some(Ok(msg)) = websocket.next().await {
1402                            match msg {
1403                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1404                                    if txt == "close-now" =>
1405                                {
1406                                    tracing::debug!("Forcibly closing from server side");
1407                                    // This sends a close frame, then stops reading
1408                                    let _ = websocket.close(None).await;
1409                                    break;
1410                                }
1411                                // Echo text/binary frames
1412                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1413                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1414                                    if websocket.send(msg).await.is_err() {
1415                                        break;
1416                                    }
1417                                }
1418                                // If the client closes, we also break
1419                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1420                                    _frame,
1421                                ) => {
1422                                    let _ = websocket.close(None).await;
1423                                    break;
1424                                }
1425                                // Ignore pings/pongs
1426                                _ => {}
1427                            }
1428                        }
1429                    });
1430                }
1431            });
1432
1433            Self { task, port }
1434        }
1435    }
1436
    /// Stops the accept loop when the test server goes out of scope.
    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }
1442
1443    async fn setup_test_client(port: u16) -> WebSocketClient {
1444        let config = WebSocketConfig {
1445            url: format!("ws://127.0.0.1:{port}"),
1446            headers: vec![("test".into(), "test".into())],
1447            heartbeat: None,
1448            heartbeat_msg: None,
1449            reconnect_timeout_ms: None,
1450            reconnect_delay_initial_ms: None,
1451            reconnect_backoff_factor: None,
1452            reconnect_delay_max_ms: None,
1453            reconnect_jitter_ms: None,
1454            reconnect_max_attempts: None,
1455        };
1456        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1457            .await
1458            .expect("Failed to connect")
1459    }
1460
    /// Connects to the test server and checks that `disconnect` finishes the client.
    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // Controller task is still running right after connecting
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1471
    /// Keeps a connection open for three seconds and then disconnects cleanly.
    ///
    /// NOTE(review): `setup_test_client` configures `heartbeat: None`, and nothing here
    /// observes what the server receives, so this test presumably does not exercise the
    /// heartbeat task at all; confirm and enable a heartbeat if that is the intent.
    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // Keep the connection idle for ~3s (no heartbeat is configured, see note above)
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1484
    /// Checks that `connect` returns an error when no server is listening.
    ///
    /// NOTE(review): despite its name this exercises the initial connection failure, not
    /// exhausted reconnection attempts (`reconnect_max_attempts` is `None`). It also
    /// assumes nothing is listening on local port 9997.
    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        let config = WebSocketConfig {
            url: "ws://127.0.0.1:9997".into(), // <-- No server
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }
1504
    /// Has the server close the connection and checks that the client survives it.
    ///
    /// The only post-close assertion is that the controller task is still running one
    /// second later; the test does not verify that a new connection was established.
    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1526
    /// Sends three messages through a client configured with a two-per-second quota.
    ///
    /// NOTE(review): the quota is registered under the key `default` while every send
    /// passes `keys = None`; whether the quota applies to such sends is not visible
    /// here, so this test may not exercise rate limiting at all. TODO confirm.
    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        };

        let client = WebSocketClient::connect(
            config,
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // First 2 should succeed
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // Third must succeed as well: `send_text` awaits the rate limiter instead of
        // returning a rate-limit error
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1567
1568    #[tokio::test]
1569    async fn test_concurrent_writers() {
1570        let server = TestServer::setup().await;
1571        let client = Arc::new(setup_test_client(server.port).await);
1572
1573        let mut handles = vec![];
1574        for i in 0..10 {
1575            let client = client.clone();
1576            handles.push(task::spawn(async move {
1577                client.send_text(format!("test{i}"), None).await.unwrap();
1578            }));
1579        }
1580
1581        for handle in handles {
1582            handle.await.unwrap();
1583        }
1584
1585        // Cleanup
1586        client.disconnect().await;
1587        assert!(client.is_disconnected());
1588    }
1589}
1590
1591#[cfg(test)]
1592#[cfg(not(feature = "turmoil"))]
1593mod rust_tests {
1594    use futures_util::StreamExt;
1595    use rstest::rstest;
1596    use tokio::{
1597        net::TcpListener,
1598        task,
1599        time::{Duration, sleep},
1600    };
1601    use tokio_tungstenite::accept_async;
1602
1603    use super::*;
1604    use crate::websocket::types::channel_message_handler;
1605
1606    #[rstest]
1607    #[tokio::test]
1608    async fn test_reconnect_then_disconnect() {
1609        // Bind an ephemeral port
1610        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1611        let port = listener.local_addr().unwrap().port();
1612
1613        // Server task: accept one ws connection then close it
1614        let server = task::spawn(async move {
1615            let (stream, _) = listener.accept().await.unwrap();
1616            let ws = accept_async(stream).await.unwrap();
1617            drop(ws);
1618            // Keep alive briefly
1619            sleep(Duration::from_secs(1)).await;
1620        });
1621
1622        // Build a channel-based message handler for incoming messages (unused here)
1623        let (handler, _rx) = channel_message_handler();
1624
1625        // Configure client with short reconnect backoff
1626        let config = WebSocketConfig {
1627            url: format!("ws://127.0.0.1:{port}"),
1628            headers: vec![],
1629            heartbeat: None,
1630            heartbeat_msg: None,
1631            reconnect_timeout_ms: Some(1_000),
1632            reconnect_delay_initial_ms: Some(50),
1633            reconnect_delay_max_ms: Some(100),
1634            reconnect_backoff_factor: Some(1.0),
1635            reconnect_jitter_ms: Some(0),
1636            reconnect_max_attempts: None,
1637        };
1638
1639        // Connect the client
1640        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1641            .await
1642            .unwrap();
1643
1644        // Allow server to drop connection and client to detect
1645        sleep(Duration::from_millis(100)).await;
1646        // Now immediately disconnect the client
1647        client.disconnect().await;
1648        assert!(client.is_disconnected());
1649        server.abort();
1650    }
1651
1652    #[rstest]
1653    #[tokio::test]
1654    async fn test_reconnect_state_flips_when_reader_stops() {
1655        // Bind an ephemeral port and accept a single websocket connection which we drop.
1656        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1657        let port = listener.local_addr().unwrap().port();
1658
1659        let server = task::spawn(async move {
1660            if let Ok((stream, _)) = listener.accept().await
1661                && let Ok(ws) = accept_async(stream).await
1662            {
1663                drop(ws);
1664            }
1665            sleep(Duration::from_millis(50)).await;
1666        });
1667
1668        let (handler, _rx) = channel_message_handler();
1669
1670        let config = WebSocketConfig {
1671            url: format!("ws://127.0.0.1:{port}"),
1672            headers: vec![],
1673            heartbeat: None,
1674            heartbeat_msg: None,
1675            reconnect_timeout_ms: Some(1_000),
1676            reconnect_delay_initial_ms: Some(50),
1677            reconnect_delay_max_ms: Some(100),
1678            reconnect_backoff_factor: Some(1.0),
1679            reconnect_jitter_ms: Some(0),
1680            reconnect_max_attempts: None,
1681        };
1682
1683        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1684            .await
1685            .unwrap();
1686
1687        tokio::time::timeout(Duration::from_secs(2), async {
1688            loop {
1689                if client.is_reconnecting() {
1690                    break;
1691                }
1692                tokio::time::sleep(Duration::from_millis(10)).await;
1693            }
1694        })
1695        .await
1696        .expect("client did not enter RECONNECT state");
1697
1698        client.disconnect().await;
1699        server.abort();
1700    }
1701
1702    #[rstest]
1703    #[tokio::test]
1704    async fn test_stream_mode_disables_auto_reconnect() {
1705        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1706        // and that reconnect() transitions to CLOSED state for stream mode
1707        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1708        let port = listener.local_addr().unwrap().port();
1709
1710        let server = task::spawn(async move {
1711            if let Ok((stream, _)) = listener.accept().await
1712                && let Ok(_ws) = accept_async(stream).await
1713            {
1714                // Keep connection alive briefly
1715                sleep(Duration::from_millis(100)).await;
1716            }
1717        });
1718
1719        let config = WebSocketConfig {
1720            url: format!("ws://127.0.0.1:{port}"),
1721            headers: vec![],
1722            heartbeat: None,
1723            heartbeat_msg: None,
1724            reconnect_timeout_ms: Some(1_000),
1725            reconnect_delay_initial_ms: Some(50),
1726            reconnect_delay_max_ms: Some(100),
1727            reconnect_backoff_factor: Some(1.0),
1728            reconnect_jitter_ms: Some(0),
1729            reconnect_max_attempts: None,
1730        };
1731
1732        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1733            .await
1734            .unwrap();
1735
1736        // Note: We can't easily test the reconnect behavior from the outside since
1737        // the inner client is private. The key fix is that WebSocketClientInner
1738        // now has is_stream_mode=true for connect_stream, and reconnect() will
1739        // transition to CLOSED state instead of creating a new reader that gets dropped.
1740        // This is tested implicitly by the fact that stream users won't get stuck
1741        // in an infinite reconnect loop.
1742
1743        server.abort();
1744    }
1745
1746    #[rstest]
1747    #[tokio::test]
1748    async fn test_message_handler_mode_allows_auto_reconnect() {
1749        // Test that regular clients (with message handler) can auto-reconnect
1750        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1751        let port = listener.local_addr().unwrap().port();
1752
1753        let server = task::spawn(async move {
1754            // Accept first connection and close it
1755            if let Ok((stream, _)) = listener.accept().await
1756                && let Ok(ws) = accept_async(stream).await
1757            {
1758                drop(ws);
1759            }
1760            sleep(Duration::from_millis(50)).await;
1761        });
1762
1763        let (handler, _rx) = channel_message_handler();
1764
1765        let config = WebSocketConfig {
1766            url: format!("ws://127.0.0.1:{port}"),
1767            headers: vec![],
1768            heartbeat: None,
1769            heartbeat_msg: None,
1770            reconnect_timeout_ms: Some(1_000),
1771            reconnect_delay_initial_ms: Some(50),
1772            reconnect_delay_max_ms: Some(100),
1773            reconnect_backoff_factor: Some(1.0),
1774            reconnect_jitter_ms: Some(0),
1775            reconnect_max_attempts: None,
1776        };
1777
1778        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1779            .await
1780            .unwrap();
1781
1782        // Wait for the connection to be dropped and reconnection to be attempted
1783        tokio::time::timeout(Duration::from_secs(2), async {
1784            loop {
1785                if client.is_reconnecting() || client.is_closed() {
1786                    break;
1787                }
1788                tokio::time::sleep(Duration::from_millis(10)).await;
1789            }
1790        })
1791        .await
1792        .expect("client should attempt reconnection or close");
1793
1794        // Should either be reconnecting or closed (depending on timing)
1795        // The important thing is it's not staying active forever
1796        assert!(
1797            client.is_reconnecting() || client.is_closed(),
1798            "Client with message handler should attempt reconnection"
1799        );
1800
1801        client.disconnect().await;
1802        server.abort();
1803    }
1804
1805    #[rstest]
1806    #[tokio::test]
1807    async fn test_handler_mode_reconnect_with_new_connection() {
1808        // Test that handler mode successfully reconnects and messages continue flowing
1809        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1810        let port = listener.local_addr().unwrap().port();
1811
1812        let server = task::spawn(async move {
1813            // First connection - accept and immediately close
1814            if let Ok((stream, _)) = listener.accept().await
1815                && let Ok(ws) = accept_async(stream).await
1816            {
1817                drop(ws);
1818            }
1819
1820            // Small delay to let client detect disconnection
1821            sleep(Duration::from_millis(100)).await;
1822
1823            // Second connection - accept, send a message, then keep alive
1824            if let Ok((stream, _)) = listener.accept().await
1825                && let Ok(mut ws) = accept_async(stream).await
1826            {
1827                use futures_util::SinkExt;
1828                let _ = ws
1829                    .send(Message::Text("reconnected".to_string().into()))
1830                    .await;
1831                sleep(Duration::from_secs(1)).await;
1832            }
1833        });
1834
1835        let (handler, mut rx) = channel_message_handler();
1836
1837        let config = WebSocketConfig {
1838            url: format!("ws://127.0.0.1:{port}"),
1839            headers: vec![],
1840            heartbeat: None,
1841            heartbeat_msg: None,
1842            reconnect_timeout_ms: Some(2_000),
1843            reconnect_delay_initial_ms: Some(50),
1844            reconnect_delay_max_ms: Some(200),
1845            reconnect_backoff_factor: Some(1.5),
1846            reconnect_jitter_ms: Some(10),
1847            reconnect_max_attempts: None,
1848        };
1849
1850        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1851            .await
1852            .unwrap();
1853
1854        // Wait for reconnection to happen and message to arrive
1855        let result = tokio::time::timeout(Duration::from_secs(5), async {
1856            loop {
1857                if let Ok(msg) = rx.try_recv()
1858                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1859                {
1860                    return true;
1861                }
1862                tokio::time::sleep(Duration::from_millis(10)).await;
1863            }
1864        })
1865        .await;
1866
1867        assert!(
1868            result.is_ok(),
1869            "Should receive message after reconnection within timeout"
1870        );
1871
1872        client.disconnect().await;
1873        server.abort();
1874    }
1875
1876    #[rstest]
1877    #[tokio::test]
1878    async fn test_stream_mode_no_auto_reconnect() {
1879        // Test that stream mode does not automatically reconnect when connection is lost
1880        // The caller owns the reader and is responsible for detecting disconnection
1881        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1882        let port = listener.local_addr().unwrap().port();
1883
1884        let server = task::spawn(async move {
1885            // Accept connection and send one message, then close
1886            if let Ok((stream, _)) = listener.accept().await
1887                && let Ok(mut ws) = accept_async(stream).await
1888            {
1889                use futures_util::SinkExt;
1890                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1891                sleep(Duration::from_millis(50)).await;
1892                // Connection closes when ws is dropped
1893            }
1894        });
1895
1896        let config = WebSocketConfig {
1897            url: format!("ws://127.0.0.1:{port}"),
1898            headers: vec![],
1899            heartbeat: None,
1900            heartbeat_msg: None,
1901            reconnect_timeout_ms: Some(1_000),
1902            reconnect_delay_initial_ms: Some(50),
1903            reconnect_delay_max_ms: Some(100),
1904            reconnect_backoff_factor: Some(1.0),
1905            reconnect_jitter_ms: Some(0),
1906            reconnect_max_attempts: None,
1907        };
1908
1909        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1910            .await
1911            .unwrap();
1912
1913        // Initially active
1914        assert!(client.is_active(), "Client should start as active");
1915
1916        // Read the hello message
1917        let msg = reader.next().await;
1918        assert!(
1919            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1920            "Should receive initial message"
1921        );
1922
1923        // Read until connection closes (reader will return None or error)
1924        while let Some(msg) = reader.next().await {
1925            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1926                break;
1927            }
1928        }
1929
1930        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1931        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1932        sleep(Duration::from_millis(200)).await;
1933
1934        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1935        // This is correct behavior - stream mode doesn't auto-detect disconnection
1936        assert!(
1937            client.is_active() || client.is_closed(),
1938            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1939        );
1940        assert!(
1941            !client.is_reconnecting(),
1942            "Stream mode client should never attempt reconnection"
1943        );
1944
1945        client.disconnect().await;
1946        server.abort();
1947    }
1948
1949    #[rstest]
1950    #[tokio::test]
1951    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1952        // Test that send operations respect the configured reconnect_timeout.
1953        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1954        use nautilus_common::testing::wait_until_async;
1955
1956        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1957        let port = listener.local_addr().unwrap().port();
1958
1959        let server = task::spawn(async move {
1960            // Accept first connection and immediately close it
1961            if let Ok((stream, _)) = listener.accept().await
1962                && let Ok(ws) = accept_async(stream).await
1963            {
1964                drop(ws);
1965            }
1966            // Don't accept second connection - client will be stuck in RECONNECT
1967            sleep(Duration::from_secs(60)).await;
1968        });
1969
1970        let (handler, _rx) = channel_message_handler();
1971
1972        // Configure with SHORT 2s reconnect timeout
1973        let config = WebSocketConfig {
1974            url: format!("ws://127.0.0.1:{port}"),
1975            headers: vec![],
1976            heartbeat: None,
1977            heartbeat_msg: None,
1978            reconnect_timeout_ms: Some(2_000), // 2s timeout
1979            reconnect_delay_initial_ms: Some(50),
1980            reconnect_delay_max_ms: Some(100),
1981            reconnect_backoff_factor: Some(1.0),
1982            reconnect_jitter_ms: Some(0),
1983            reconnect_max_attempts: None,
1984        };
1985
1986        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1987            .await
1988            .unwrap();
1989
1990        // Wait for client to enter RECONNECT state
1991        wait_until_async(
1992            || async { client.is_reconnecting() },
1993            Duration::from_secs(3),
1994        )
1995        .await;
1996
1997        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
1998        let start = std::time::Instant::now();
1999        let send_result = client.send_text("test".to_string(), None).await;
2000        let elapsed = start.elapsed();
2001
2002        assert!(
2003            send_result.is_err(),
2004            "Send should fail when client stuck in RECONNECT"
2005        );
2006        assert!(
2007            matches!(send_result, Err(crate::error::SendError::Timeout)),
2008            "Send should return Timeout error, was: {send_result:?}"
2009        );
2010        // Verify timeout respects configured value (2s), but don't check upper bound
2011        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2012        assert!(
2013            elapsed >= Duration::from_millis(1800),
2014            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2015        );
2016
2017        client.disconnect().await;
2018        server.abort();
2019    }
2020
2021    #[rstest]
2022    #[tokio::test]
2023    async fn test_send_waits_during_reconnection() {
2024        // Test that send operations wait for reconnection to complete (up to timeout)
2025        use nautilus_common::testing::wait_until_async;
2026
2027        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2028        let port = listener.local_addr().unwrap().port();
2029
2030        let server = task::spawn(async move {
2031            // First connection - accept and immediately close
2032            if let Ok((stream, _)) = listener.accept().await
2033                && let Ok(ws) = accept_async(stream).await
2034            {
2035                drop(ws);
2036            }
2037
2038            // Wait a bit before accepting second connection
2039            sleep(Duration::from_millis(500)).await;
2040
2041            // Second connection - accept and keep alive
2042            if let Ok((stream, _)) = listener.accept().await
2043                && let Ok(mut ws) = accept_async(stream).await
2044            {
2045                // Echo messages
2046                while let Some(Ok(msg)) = ws.next().await {
2047                    if ws.send(msg).await.is_err() {
2048                        break;
2049                    }
2050                }
2051            }
2052        });
2053
2054        let (handler, _rx) = channel_message_handler();
2055
2056        let config = WebSocketConfig {
2057            url: format!("ws://127.0.0.1:{port}"),
2058            headers: vec![],
2059            heartbeat: None,
2060            heartbeat_msg: None,
2061            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2062            reconnect_delay_initial_ms: Some(100),
2063            reconnect_delay_max_ms: Some(200),
2064            reconnect_backoff_factor: Some(1.0),
2065            reconnect_jitter_ms: Some(0),
2066            reconnect_max_attempts: None,
2067        };
2068
2069        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2070            .await
2071            .unwrap();
2072
2073        // Wait for reconnection to trigger
2074        wait_until_async(
2075            || async { client.is_reconnecting() },
2076            Duration::from_secs(2),
2077        )
2078        .await;
2079
2080        // Try to send while reconnecting - should wait and succeed after reconnect
2081        let send_result = tokio::time::timeout(
2082            Duration::from_secs(3),
2083            client.send_text("test_message".to_string(), None),
2084        )
2085        .await;
2086
2087        assert!(
2088            send_result.is_ok() && send_result.unwrap().is_ok(),
2089            "Send should succeed after waiting for reconnection"
2090        );
2091
2092        client.disconnect().await;
2093        server.abort();
2094    }
2095
2096    #[rstest]
2097    #[tokio::test]
2098    async fn test_rate_limiter_before_active_wait() {
2099        // Test that rate limiting happens BEFORE active state check.
2100        // This prevents race conditions where connection state changes during rate limit wait.
2101        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2102        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2103        use std::{num::NonZeroU32, sync::Arc};
2104
2105        use nautilus_common::testing::wait_until_async;
2106
2107        use crate::ratelimiter::quota::Quota;
2108
2109        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2110        let port = listener.local_addr().unwrap().port();
2111
2112        let server = task::spawn(async move {
2113            // First connection - accept and close after receiving one message
2114            if let Ok((stream, _)) = listener.accept().await
2115                && let Ok(mut ws) = accept_async(stream).await
2116            {
2117                // Receive first message then close
2118                if let Some(Ok(_)) = ws.next().await {
2119                    drop(ws);
2120                }
2121            }
2122
2123            // Wait before accepting reconnection
2124            sleep(Duration::from_millis(500)).await;
2125
2126            // Second connection - accept and keep alive
2127            if let Ok((stream, _)) = listener.accept().await
2128                && let Ok(mut ws) = accept_async(stream).await
2129            {
2130                while let Some(Ok(msg)) = ws.next().await {
2131                    if ws.send(msg).await.is_err() {
2132                        break;
2133                    }
2134                }
2135            }
2136        });
2137
2138        let (handler, _rx) = channel_message_handler();
2139
2140        let config = WebSocketConfig {
2141            url: format!("ws://127.0.0.1:{port}"),
2142            headers: vec![],
2143            heartbeat: None,
2144            heartbeat_msg: None,
2145            reconnect_timeout_ms: Some(5_000),
2146            reconnect_delay_initial_ms: Some(50),
2147            reconnect_delay_max_ms: Some(100),
2148            reconnect_backoff_factor: Some(1.0),
2149            reconnect_jitter_ms: Some(0),
2150            reconnect_max_attempts: None,
2151        };
2152
2153        // Very restrictive rate limit: 1 request per second, burst of 1
2154        let quota =
2155            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2156
2157        let client = Arc::new(
2158            WebSocketClient::connect(
2159                config,
2160                Some(handler),
2161                None,
2162                None,
2163                vec![("test_key".to_string(), quota)],
2164                None,
2165            )
2166            .await
2167            .unwrap(),
2168        );
2169
2170        // First send exhausts burst capacity and triggers connection close
2171        client
2172            .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2173            .await
2174            .unwrap();
2175
2176        // Wait for client to enter RECONNECT state
2177        wait_until_async(
2178            || async { client.is_reconnecting() },
2179            Duration::from_secs(2),
2180        )
2181        .await;
2182
2183        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2184        let start = std::time::Instant::now();
2185        let send_result = client
2186            .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2187            .await;
2188        let elapsed = start.elapsed();
2189
2190        // Should succeed after both rate limit AND reconnection
2191        assert!(
2192            send_result.is_ok(),
2193            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2194        );
2195        // Total wait should be at least rate limit time (~1s)
2196        // The reconnection completes while rate limiting or after
2197        // Use 850ms threshold to account for timing jitter in CI
2198        assert!(
2199            elapsed >= Duration::from_millis(850),
2200            "Should wait for rate limit (~1s), waited {elapsed:?}"
2201        );
2202
2203        client.disconnect().await;
2204        server.abort();
2205    }
2206
2207    #[rstest]
2208    #[tokio::test]
2209    async fn test_disconnect_during_reconnect_exits_cleanly() {
2210        // Test CAS race condition: disconnect called during reconnection
2211        // Should exit cleanly without spawning new tasks
2212        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2213        let port = listener.local_addr().unwrap().port();
2214
2215        let server = task::spawn(async move {
2216            // Accept first connection and immediately close
2217            if let Ok((stream, _)) = listener.accept().await
2218                && let Ok(ws) = accept_async(stream).await
2219            {
2220                drop(ws);
2221            }
2222            // Don't accept second connection - let reconnect hang
2223            sleep(Duration::from_secs(60)).await;
2224        });
2225
2226        let (handler, _rx) = channel_message_handler();
2227
2228        let config = WebSocketConfig {
2229            url: format!("ws://127.0.0.1:{port}"),
2230            headers: vec![],
2231            heartbeat: None,
2232            heartbeat_msg: None,
2233            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2234            reconnect_delay_initial_ms: Some(100),
2235            reconnect_delay_max_ms: Some(200),
2236            reconnect_backoff_factor: Some(1.0),
2237            reconnect_jitter_ms: Some(0),
2238            reconnect_max_attempts: None,
2239        };
2240
2241        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2242            .await
2243            .unwrap();
2244
2245        // Wait for reconnection to start
2246        tokio::time::timeout(Duration::from_secs(2), async {
2247            while !client.is_reconnecting() {
2248                sleep(Duration::from_millis(10)).await;
2249            }
2250        })
2251        .await
2252        .expect("Client should enter RECONNECT state");
2253
2254        // Disconnect while reconnecting
2255        client.disconnect().await;
2256
2257        // Should be cleanly closed
2258        assert!(
2259            client.is_disconnected(),
2260            "Client should be cleanly disconnected"
2261        );
2262
2263        server.abort();
2264    }
2265
2266    #[rstest]
2267    #[tokio::test]
2268    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2269        // Test that send operations check connection state BEFORE rate limiting,
2270        // preventing unnecessary delays when the connection is already closed.
2271        use std::{num::NonZeroU32, sync::Arc};
2272
2273        use nautilus_common::testing::wait_until_async;
2274
2275        use crate::ratelimiter::quota::Quota;
2276
2277        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2278        let port = listener.local_addr().unwrap().port();
2279
2280        let server = task::spawn(async move {
2281            // Accept connection and immediately close
2282            if let Ok((stream, _)) = listener.accept().await
2283                && let Ok(ws) = accept_async(stream).await
2284            {
2285                drop(ws);
2286            }
2287            sleep(Duration::from_secs(60)).await;
2288        });
2289
2290        let (handler, _rx) = channel_message_handler();
2291
2292        let config = WebSocketConfig {
2293            url: format!("ws://127.0.0.1:{port}"),
2294            headers: vec![],
2295            heartbeat: None,
2296            heartbeat_msg: None,
2297            reconnect_timeout_ms: Some(5_000),
2298            reconnect_delay_initial_ms: Some(50),
2299            reconnect_delay_max_ms: Some(100),
2300            reconnect_backoff_factor: Some(1.0),
2301            reconnect_jitter_ms: Some(0),
2302            reconnect_max_attempts: None,
2303        };
2304
2305        // Very restrictive rate limit: 1 request per 10 seconds
2306        // This ensures that if we wait for rate limit, the test will timeout
2307        let quota = Quota::with_period(Duration::from_secs(10))
2308            .unwrap()
2309            .allow_burst(NonZeroU32::new(1).unwrap());
2310
2311        let client = Arc::new(
2312            WebSocketClient::connect(
2313                config,
2314                Some(handler),
2315                None,
2316                None,
2317                vec![("test_key".to_string(), quota)],
2318                None,
2319            )
2320            .await
2321            .unwrap(),
2322        );
2323
2324        // Wait for disconnection
2325        wait_until_async(
2326            || async { client.is_reconnecting() || client.is_closed() },
2327            Duration::from_secs(2),
2328        )
2329        .await;
2330
2331        // Explicitly disconnect to move away from ACTIVE state
2332        client.disconnect().await;
2333        assert!(
2334            !client.is_active(),
2335            "Client should not be active after disconnect"
2336        );
2337
2338        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2339        let start = std::time::Instant::now();
2340        let result = client
2341            .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2342            .await;
2343        let elapsed = start.elapsed();
2344
2345        // Should fail with Closed error
2346        assert!(result.is_err(), "Send should fail when client is closed");
2347        assert!(
2348            matches!(result, Err(crate::error::SendError::Closed)),
2349            "Send should return Closed error, was: {result:?}"
2350        );
2351
2352        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2353        assert!(
2354            elapsed < Duration::from_millis(100),
2355            "Send should fail fast without rate limiting, took {elapsed:?}"
2356        );
2357
2358        server.abort();
2359    }
2360
2361    #[rstest]
2362    #[tokio::test]
2363    async fn test_connect_rejects_none_message_handler() {
2364        // Test that connect() properly rejects None message_handler
2365        // to prevent zombie connections that appear alive but never detect disconnections
2366
2367        let config = WebSocketConfig {
2368            url: "ws://127.0.0.1:9999".to_string(),
2369            headers: vec![],
2370            heartbeat: None,
2371            heartbeat_msg: None,
2372            reconnect_timeout_ms: Some(1_000),
2373            reconnect_delay_initial_ms: Some(100),
2374            reconnect_delay_max_ms: Some(500),
2375            reconnect_backoff_factor: Some(1.5),
2376            reconnect_jitter_ms: Some(0),
2377            reconnect_max_attempts: None,
2378        };
2379
2380        // Pass None for message_handler - should be rejected
2381        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2382
2383        assert!(
2384            result.is_err(),
2385            "connect() should reject None message_handler"
2386        );
2387
2388        let err = result.unwrap_err();
2389        let err_msg = err.to_string();
2390        assert!(
2391            err_msg.contains("Handler mode requires message_handler"),
2392            "Error should mention missing message_handler, was: {err_msg}"
2393        );
2394    }
2395
2396    #[rstest]
2397    #[tokio::test]
2398    async fn test_client_without_handler_sets_stream_mode() {
2399        // Test that if a client is created without a handler via connect_url,
2400        // it properly sets is_stream_mode=true to prevent zombie connections
2401
2402        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2403        let port = listener.local_addr().unwrap().port();
2404
2405        let server = task::spawn(async move {
2406            // Accept and immediately close to simulate server disconnect
2407            if let Ok((stream, _)) = listener.accept().await
2408                && let Ok(ws) = accept_async(stream).await
2409            {
2410                drop(ws); // Drop connection immediately
2411            }
2412        });
2413
2414        let config = WebSocketConfig {
2415            url: format!("ws://127.0.0.1:{port}"),
2416            headers: vec![],
2417            heartbeat: None,
2418            heartbeat_msg: None,
2419            reconnect_timeout_ms: Some(1_000),
2420            reconnect_delay_initial_ms: Some(100),
2421            reconnect_delay_max_ms: Some(500),
2422            reconnect_backoff_factor: Some(1.5),
2423            reconnect_jitter_ms: Some(0),
2424            reconnect_max_attempts: None,
2425        };
2426
2427        // Create client directly via connect_url with no handler (stream mode)
2428        let inner = WebSocketClientInner::connect_url(config, None, None)
2429            .await
2430            .unwrap();
2431
2432        // Verify is_stream_mode is true when no handler
2433        assert!(
2434            inner.is_stream_mode,
2435            "Client without handler should have is_stream_mode=true"
2436        );
2437
2438        // Verify that when stream mode is enabled, reconnection is disabled
2439        // (documented behavior - stream mode clients close instead of reconnecting)
2440
2441        server.abort();
2442    }
2443}