nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50    config::WebSocketConfig,
51    consts::{
52        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54    },
55    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60    RECONNECTED,
61    backoff::ExponentialBackoff,
62    error::SendError,
63    logging::{log_task_aborted, log_task_started, log_task_stopped},
64    mode::ConnectionMode,
65    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
/// `WebSocketClientInner` owns a single websocket connection and the background
/// tasks that service it.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The connection is split into read and write halves:
///
/// - The read half is moved into a tokio task (the "read" task) which keeps
///   receiving messages from the server and passes them to the configured
///   message handler. No read task is spawned when the reader is handled
///   externally or no message handler is configured.
/// - The write half is moved into a tokio task (the "write" task). Outbound
///   messages and writer replacements are delivered to that task as
///   [`WriterCommand`]s over the unbounded `writer_tx` channel, so any holder
///   of a sender clone can write without locking.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
pub struct WebSocketClientInner {
    /// Configuration the client was created with; the URL, headers and handlers
    /// are read again when reconnecting.
    config: WebSocketConfig,
    /// Handle of the read task, or `None` when no read task is managed by this client.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the write task that owns the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the command channel consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current [`ConnectionMode`] stored as a `u8`, shared with all background tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff state built from the `reconnect_*` delay settings.
    // NOTE(review): not consumed in this part of the file - presumably driven by
    // the reconnection controller; confirm against the caller.
    backoff: ExponentialBackoff,
    /// True if the reader is not managed by this client (set by
    /// `new_with_writer`, and by `connect_url` when no message handler is
    /// configured). Such clients disable auto-reconnect because the reader is
    /// owned by the caller and cannot be replaced during reconnection.
    is_stream_mode: bool,
    /// Maximum number of reconnection attempts before giving up (None = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Current count of consecutive reconnection attempts.
    reconnection_attempt_count: u32,
}
101
102impl WebSocketClientInner {
103    /// Create an inner websocket client with an existing writer.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the exponential backoff configuration is invalid.
108    pub async fn new_with_writer(
109        config: WebSocketConfig,
110        writer: MessageWriter,
111    ) -> Result<Self, Error> {
112        install_cryptographic_provider();
113
114        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
115
116        // Note: We don't spawn a read task here since the reader is handled externally
117        let read_task = None;
118
119        let backoff = ExponentialBackoff::new(
120            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
121            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
122            config.reconnect_backoff_factor.unwrap_or(1.5),
123            config.reconnect_jitter_ms.unwrap_or(100),
124            true, // immediate-first
125        )
126        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
127
128        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
129        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
130
131        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
132            Some(Self::spawn_heartbeat_task(
133                connection_mode.clone(),
134                heartbeat_interval,
135                config.heartbeat_msg.clone(),
136                writer_tx.clone(),
137            ))
138        } else {
139            None
140        };
141
142        Ok(Self {
143            config: config.clone(),
144            writer_tx,
145            connection_mode,
146            reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
147            heartbeat_task,
148            read_task,
149            write_task,
150            backoff,
151            is_stream_mode: true,
152            reconnect_max_attempts: config.reconnect_max_attempts,
153            reconnection_attempt_count: 0,
154        })
155    }
156
157    /// Create an inner websocket client.
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if:
162    /// - The connection to the server fails.
163    /// - The exponential backoff configuration is invalid.
164    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
165        install_cryptographic_provider();
166
167        let WebSocketConfig {
168            url,
169            message_handler,
170            heartbeat,
171            headers,
172            heartbeat_msg,
173            ping_handler,
174            reconnect_timeout_ms,
175            reconnect_delay_initial_ms,
176            reconnect_delay_max_ms,
177            reconnect_backoff_factor,
178            reconnect_jitter_ms,
179            reconnect_max_attempts,
180        } = &config;
181
182        // Capture whether we're in stream mode before moving config
183        let is_stream_mode = message_handler.is_none();
184        let reconnect_max_attempts = *reconnect_max_attempts;
185
186        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
187
188        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
189
190        let read_task = if message_handler.is_some() {
191            Some(Self::spawn_message_handler_task(
192                connection_mode.clone(),
193                reader,
194                message_handler.as_ref(),
195                ping_handler.as_ref(),
196            ))
197        } else {
198            None
199        };
200
201        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
202        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
203
204        // Optionally spawn a heartbeat task to periodically ping server
205        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
206            Self::spawn_heartbeat_task(
207                connection_mode.clone(),
208                *heartbeat_secs,
209                heartbeat_msg.clone(),
210                writer_tx.clone(),
211            )
212        });
213
214        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
215        let backoff = ExponentialBackoff::new(
216            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
217            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
218            reconnect_backoff_factor.unwrap_or(1.5),
219            reconnect_jitter_ms.unwrap_or(100),
220            true, // immediate-first
221        )
222        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224        Ok(Self {
225            config,
226            read_task,
227            write_task,
228            writer_tx,
229            heartbeat_task,
230            connection_mode,
231            reconnect_timeout,
232            backoff,
233            // Set stream mode when no message handler (reader not managed by client)
234            is_stream_mode,
235            reconnect_max_attempts,
236            reconnection_attempt_count: 0,
237        })
238    }
239
240    /// Connects with the server creating a tokio-tungstenite websocket stream.
241    /// Production version that uses `connect_async_with_config` convenience helper.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if:
246    /// - The URL cannot be parsed into a valid client request.
247    /// - Header values are invalid.
248    /// - The WebSocket connection fails.
249    #[inline]
250    #[cfg(not(feature = "turmoil"))]
251    pub async fn connect_with_server(
252        url: &str,
253        headers: Vec<(String, String)>,
254    ) -> Result<(MessageWriter, MessageReader), Error> {
255        let mut request = url.into_client_request()?;
256        let req_headers = request.headers_mut();
257
258        let mut header_names: Vec<HeaderName> = Vec::new();
259        for (key, val) in headers {
260            let header_value = HeaderValue::from_str(&val)?;
261            let header_name: HeaderName = key.parse()?;
262            header_names.push(header_name.clone());
263            req_headers.insert(header_name, header_value);
264        }
265
266        connect_async_with_config(request, None, true)
267            .await
268            .map(|resp| resp.0.split())
269    }
270
271    /// Connects with the server creating a tokio-tungstenite websocket stream.
272    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if:
277    /// - The URL cannot be parsed into a valid client request.
278    /// - The URL is missing a hostname.
279    /// - Header values are invalid.
280    /// - The TCP connection fails.
281    /// - TLS setup fails (for wss:// URLs).
282    /// - The WebSocket handshake fails.
283    #[inline]
284    #[cfg(feature = "turmoil")]
285    pub async fn connect_with_server(
286        url: &str,
287        headers: Vec<(String, String)>,
288    ) -> Result<(MessageWriter, MessageReader), Error> {
289        use rustls::ClientConfig;
290        use tokio_rustls::TlsConnector;
291
292        let mut request = url.into_client_request()?;
293        let req_headers = request.headers_mut();
294
295        let mut header_names: Vec<HeaderName> = Vec::new();
296        for (key, val) in headers {
297            let header_value = HeaderValue::from_str(&val)?;
298            let header_name: HeaderName = key.parse()?;
299            header_names.push(header_name.clone());
300            req_headers.insert(header_name, header_value);
301        }
302
303        let uri = request.uri();
304        let scheme = uri.scheme_str().unwrap_or("ws");
305        let host = uri.host().ok_or_else(|| {
306            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
307        })?;
308
309        // Determine port: use explicit port if specified, otherwise default based on scheme
310        let port = uri
311            .port_u16()
312            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
313
314        let addr = format!("{host}:{port}");
315
316        // Use the connector to get a turmoil-compatible stream
317        let connector = crate::net::RealTcpConnector;
318        let tcp_stream = connector.connect(&addr).await?;
319        if let Err(e) = tcp_stream.set_nodelay(true) {
320            tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
321        }
322
323        // Wrap stream appropriately based on scheme
324        let maybe_tls_stream = if scheme == "wss" {
325            // Build TLS config with webpki roots
326            let mut root_store = rustls::RootCertStore::empty();
327            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
328
329            let config = ClientConfig::builder()
330                .with_root_certificates(root_store)
331                .with_no_client_auth();
332
333            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
334            let domain =
335                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
336                    Error::Io(std::io::Error::new(
337                        std::io::ErrorKind::InvalidInput,
338                        format!("Invalid DNS name: {e}"),
339                    ))
340                })?;
341
342            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
343            MaybeTlsStream::Rustls(tls_stream)
344        } else {
345            MaybeTlsStream::Plain(tcp_stream)
346        };
347
348        // Use client_async with the stream (plain or TLS)
349        client_async(request, maybe_tls_stream)
350            .await
351            .map(|resp| resp.0.split())
352    }
353
    /// Reconnect with server.
    ///
    /// Makes a new connection with the server, hands the new write half to the
    /// running write task (which swaps it in and flushes any messages buffered
    /// while disconnected), then replaces the read task with one bound to the
    /// new read half. The heartbeat task is not touched here.
    ///
    /// The call is a no-op returning `Ok(())` when:
    /// - The client is stream-based: reconnection is disabled because the reader
    ///   is owned by the caller and cannot be replaced. The connection mode is
    ///   set to `Closed` and stream users should handle disconnections by
    ///   creating a new connection.
    /// - A disconnect has been requested, either before the attempt starts or at
    ///   one of the checkpoints during it.
    /// - The connection mode is no longer `Reconnect` at the final transition.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The reconnection attempt times out.
    /// - The connection to the server fails.
    /// - The new writer cannot be delivered to the write task, or the write task
    ///   drops the acknowledgement channel.
    /// - The write task reports that it failed to flush the buffered messages.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if self.is_stream_mode {
            tracing::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            // Transition to CLOSED state to stop reconnection attempts
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        // Everything inside the async block, including the writer handshake and
        // the grace delay, must complete within `reconnect_timeout`.
        tokio::time::timeout(self.reconnect_timeout, async {
            // Attempt to connect; abort early if a disconnect was requested
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Use a oneshot channel to synchronize with the writer task.
            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
            // to prevent silent message loss if the new connection drops immediately.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                tracing::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it has drained the buffer
            match rx.await {
                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                    // Return error to trigger retry logic in controller
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    // The sender half was dropped without a reply
                    tracing::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Delay before closing connection
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Retire the previous read task: the handle is removed from `self`
            // either way and the task is aborted only if it is still running.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // The mode is Active again, so a read task spawned now will keep running
            self.read_task = if self.config.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    reader,
                    self.config.message_handler.as_ref(),
                    self.config.ping_handler.as_ref(),
                ))
            } else {
                None
            };

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Outer `Err` means the timeout elapsed; `?` unwraps it so that the
        // inner result of the async block becomes the return value.
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
484
485    /// Check if the client is still connected.
486    ///
487    /// The client is connected if the read task has not finished. It is expected
488    /// that in case of any failure client or server side. The read task will be
489    /// shutdown or will receive a `Close` frame which will finish it. There
490    /// might be some delay between the connection being closed and the client
491    /// detecting.
492    #[inline]
493    #[must_use]
494    pub fn is_alive(&self) -> bool {
495        match &self.read_task {
496            Some(read_task) => !read_task.is_finished(),
497            None => true, // Stream is being used directly
498        }
499    }
500
501    fn spawn_message_handler_task(
502        connection_state: Arc<AtomicU8>,
503        mut reader: MessageReader,
504        message_handler: Option<&MessageHandler>,
505        ping_handler: Option<&PingHandler>,
506    ) -> tokio::task::JoinHandle<()> {
507        tracing::debug!("Started message handler task 'read'");
508
509        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
510
511        // Clone Arc handlers for the async task
512        let message_handler = message_handler.cloned();
513        let ping_handler = ping_handler.cloned();
514
515        tokio::task::spawn(async move {
516            loop {
517                if !ConnectionMode::from_atomic(&connection_state).is_active() {
518                    break;
519                }
520
521                match tokio::time::timeout(check_interval, reader.next()).await {
522                    Ok(Some(Ok(Message::Binary(data)))) => {
523                        tracing::trace!("Received message <binary> {} bytes", data.len());
524                        if let Some(ref handler) = message_handler {
525                            handler(Message::Binary(data));
526                        }
527                    }
528                    Ok(Some(Ok(Message::Text(data)))) => {
529                        tracing::trace!("Received message: {data}");
530                        if let Some(ref handler) = message_handler {
531                            handler(Message::Text(data));
532                        }
533                    }
534                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
535                        tracing::trace!("Received ping: {ping_data:?}");
536                        if let Some(ref handler) = ping_handler {
537                            handler(ping_data.to_vec());
538                        }
539                    }
540                    Ok(Some(Ok(Message::Pong(_)))) => {
541                        tracing::trace!("Received pong");
542                    }
543                    Ok(Some(Ok(Message::Close(_)))) => {
544                        tracing::debug!("Received close message - terminating");
545                        break;
546                    }
547                    Ok(Some(Ok(_))) => (),
548                    Ok(Some(Err(e))) => {
549                        tracing::error!("Received error message - terminating: {e}");
550                        break;
551                    }
552                    Ok(None) => {
553                        tracing::debug!("No message received - terminating");
554                        break;
555                    }
556                    Err(_) => {
557                        // Timeout - continue loop and check connection mode
558                        continue;
559                    }
560                }
561            }
562        })
563    }
564
565    /// Attempts to send all buffered messages after reconnection.
566    ///
567    /// Returns `true` if a send error occurred (caller should trigger reconnection).
568    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
569    async fn drain_reconnect_buffer(
570        buffer: &mut VecDeque<Message>,
571        writer: &mut MessageWriter,
572    ) -> bool {
573        if buffer.is_empty() {
574            return false;
575        }
576
577        let initial_buffer_len = buffer.len();
578        tracing::info!(
579            "Sending {} buffered messages after reconnection",
580            initial_buffer_len
581        );
582
583        let mut send_error_occurred = false;
584
585        while let Some(buffered_msg) = buffer.front() {
586            // Clone message before attempting send (to keep in buffer if send fails)
587            let msg_to_send = buffered_msg.clone();
588
589            if let Err(e) = writer.send(msg_to_send).await {
590                tracing::error!(
591                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
592                    buffer.len()
593                );
594                send_error_occurred = true;
595                break; // Stop processing buffer, remaining messages preserved for next reconnection
596            }
597
598            // Only remove from buffer after successful send
599            buffer.pop_front();
600        }
601
602        if buffer.is_empty() {
603            tracing::info!(
604                "Successfully sent all {} buffered messages",
605                initial_buffer_len
606            );
607        }
608
609        send_error_occurred
610    }
611
612    fn spawn_write_task(
613        connection_state: Arc<AtomicU8>,
614        writer: MessageWriter,
615        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
616    ) -> tokio::task::JoinHandle<()> {
617        log_task_started("write");
618
619        // Interval between checking the connection mode
620        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
621
622        tokio::task::spawn(async move {
623            let mut active_writer = writer;
624            // Buffer for messages received during reconnection
625            // VecDeque for efficient pop_front() operations
626            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
627
628            loop {
629                match ConnectionMode::from_atomic(&connection_state) {
630                    ConnectionMode::Disconnect => {
631                        // Log any buffered messages that will be lost
632                        if !reconnect_buffer.is_empty() {
633                            tracing::warn!(
634                                "Discarding {} buffered messages due to disconnect",
635                                reconnect_buffer.len()
636                            );
637                            reconnect_buffer.clear();
638                        }
639
640                        // Attempt to close the writer gracefully before exiting,
641                        // we ignore any error as the writer may already be closed.
642                        _ = tokio::time::timeout(
643                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
644                            active_writer.close(),
645                        )
646                        .await;
647                        break;
648                    }
649                    ConnectionMode::Closed => {
650                        // Log any buffered messages that will be lost
651                        if !reconnect_buffer.is_empty() {
652                            tracing::warn!(
653                                "Discarding {} buffered messages due to closed connection",
654                                reconnect_buffer.len()
655                            );
656                            reconnect_buffer.clear();
657                        }
658                        break;
659                    }
660                    _ => {}
661                }
662
663                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
664                    Ok(Some(msg)) => {
665                        // Re-check connection mode after receiving a message
666                        let mode = ConnectionMode::from_atomic(&connection_state);
667                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
668                            break;
669                        }
670
671                        match msg {
672                            WriterCommand::Update(new_writer, tx) => {
673                                tracing::debug!("Received new writer");
674
675                                // Delay before closing connection
676                                tokio::time::sleep(Duration::from_millis(100)).await;
677
678                                // Attempt to close the writer gracefully on update,
679                                // we ignore any error as the writer may already be closed.
680                                _ = tokio::time::timeout(
681                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
682                                    active_writer.close(),
683                                )
684                                .await;
685
686                                active_writer = new_writer;
687                                tracing::debug!("Updated writer");
688
689                                let send_error = Self::drain_reconnect_buffer(
690                                    &mut reconnect_buffer,
691                                    &mut active_writer,
692                                )
693                                .await;
694
695                                if let Err(e) = tx.send(!send_error) {
696                                    tracing::error!(
697                                        "Failed to report drain status to controller: {e:?}"
698                                    );
699                                }
700                            }
701                            WriterCommand::Send(msg) if mode.is_reconnect() => {
702                                // Buffer messages during reconnection instead of dropping them
703                                tracing::debug!(
704                                    "Buffering message during reconnection (buffer size: {})",
705                                    reconnect_buffer.len() + 1
706                                );
707                                reconnect_buffer.push_back(msg);
708                            }
709                            WriterCommand::Send(msg) => {
710                                if let Err(e) = active_writer.send(msg.clone()).await {
711                                    tracing::error!("Failed to send message: {e}");
712                                    tracing::warn!("Writer triggering reconnect");
713                                    reconnect_buffer.push_back(msg);
714                                    connection_state
715                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
716                                }
717                            }
718                        }
719                    }
720                    Ok(None) => {
721                        // Channel closed - writer task should terminate
722                        tracing::debug!("Writer channel closed, terminating writer task");
723                        break;
724                    }
725                    Err(_) => {
726                        // Timeout - just continue the loop
727                        continue;
728                    }
729                }
730            }
731
732            // Attempt to close the writer gracefully before exiting,
733            // we ignore any error as the writer may already be closed.
734            _ = tokio::time::timeout(
735                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
736                active_writer.close(),
737            )
738            .await;
739
740            log_task_stopped("write");
741        })
742    }
743
744    fn spawn_heartbeat_task(
745        connection_state: Arc<AtomicU8>,
746        heartbeat_secs: u64,
747        message: Option<String>,
748        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
749    ) -> tokio::task::JoinHandle<()> {
750        log_task_started("heartbeat");
751
752        tokio::task::spawn(async move {
753            let interval = Duration::from_secs(heartbeat_secs);
754
755            loop {
756                tokio::time::sleep(interval).await;
757
758                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
759                    ConnectionMode::Active => {
760                        let msg = match &message {
761                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
762                            None => WriterCommand::Send(Message::Ping(vec![].into())),
763                        };
764
765                        match writer_tx.send(msg) {
766                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
767                            Err(e) => {
768                                tracing::error!("Failed to send heartbeat to writer task: {e}");
769                            }
770                        }
771                    }
772                    ConnectionMode::Reconnect => continue,
773                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
774                }
775            }
776
777            log_task_stopped("heartbeat");
778        })
779    }
780}
781
782impl Drop for WebSocketClientInner {
783    fn drop(&mut self) {
784        // Delegate to explicit cleanup handler
785        self.clean_drop();
786    }
787}
788
789/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
790impl CleanDrop for WebSocketClientInner {
791    fn clean_drop(&mut self) {
792        if let Some(ref read_task) = self.read_task.take()
793            && !read_task.is_finished()
794        {
795            read_task.abort();
796            log_task_aborted("read");
797        }
798
799        if !self.write_task.is_finished() {
800            self.write_task.abort();
801            log_task_aborted("write");
802        }
803
804        if let Some(ref handle) = self.heartbeat_task.take()
805            && !handle.is_finished()
806        {
807            handle.abort();
808            log_task_aborted("heartbeat");
809        }
810
811        // Clear handlers to break potential reference cycles
812        self.config.message_handler = None;
813        self.config.ping_handler = None;
814    }
815}
816
817impl Debug for WebSocketClientInner {
818    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819        f.debug_struct("WebSocketClientInner")
820            .field("config", &self.config)
821            .field(
822                "connection_mode",
823                &ConnectionMode::from_atomic(&self.connection_mode),
824            )
825            .field("reconnect_timeout", &self.reconnect_timeout)
826            .field("is_stream_mode", &self.is_stream_mode)
827            .finish()
828    }
829}
830
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the spawned controller task; aborted when the client is dropped.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller task, stored as a `ConnectionMode` value.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text` and `send_bytes` before queueing a message.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Sender half of the channel used to queue `WriterCommand`s for the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
846
847impl Debug for WebSocketClient {
848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849        f.debug_struct(stringify!(WebSocketClient)).finish()
850    }
851}
852
853impl WebSocketClient {
854    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
855    ///
856    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
857    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
858    /// client transitions to CLOSED state and the caller must manually reconnect by calling
859    /// `connect_stream` again.
860    ///
861    /// Use stream mode when you need custom reconnection logic, direct control over message
862    /// reading, or fine-grained backpressure handling.
863    ///
864    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
865    ///
866    /// # Errors
867    ///
868    /// Returns an error if the connection cannot be established.
869    #[allow(clippy::too_many_arguments)]
870    pub async fn connect_stream(
871        config: WebSocketConfig,
872        keyed_quotas: Vec<(String, Quota)>,
873        default_quota: Option<Quota>,
874        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
875    ) -> Result<(MessageReader, Self), Error> {
876        install_cryptographic_provider();
877
878        // Create a single connection and split it, respecting configured headers
879        let (writer, reader) =
880            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
881
882        // Create inner without connecting (we'll provide the writer)
883        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
884
885        let connection_mode = inner.connection_mode.clone();
886        let reconnect_timeout = inner.reconnect_timeout;
887        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
888        let writer_tx = inner.writer_tx.clone();
889
890        let controller_task =
891            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
892
893        Ok((
894            reader,
895            Self {
896                controller_task,
897                connection_mode,
898                reconnect_timeout,
899                rate_limiter,
900                writer_tx,
901            },
902        ))
903    }
904
905    /// Creates a websocket client in **handler mode** with automatic reconnection.
906    ///
907    /// Requires `config.message_handler` to be set. The handler is called for each incoming
908    /// message on an internal task. Automatic reconnection is **enabled** with exponential
909    /// backoff. On disconnection, the client automatically attempts to reconnect and replaces
910    /// the internal reader (the handler continues working seamlessly).
911    ///
912    /// Use handler mode for simplified connection management, automatic reconnection, Python
913    /// bindings, or callback-based message handling.
914    ///
915    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
916    ///
917    /// # Errors
918    ///
919    /// Returns an error if the connection cannot be established or if
920    /// `config.message_handler` is `None` (use `connect_stream` instead).
921    pub async fn connect(
922        config: WebSocketConfig,
923        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
924        keyed_quotas: Vec<(String, Quota)>,
925        default_quota: Option<Quota>,
926    ) -> Result<Self, Error> {
927        // Validate that handler mode has a message handler
928        if config.message_handler.is_none() {
929            return Err(Error::Io(std::io::Error::new(
930                std::io::ErrorKind::InvalidInput,
931                "Handler mode requires config.message_handler to be set. Use connect_stream() for stream mode without a handler.",
932            )));
933        }
934
935        tracing::debug!("Connecting");
936        let inner = WebSocketClientInner::connect_url(config).await?;
937        let connection_mode = inner.connection_mode.clone();
938        let writer_tx = inner.writer_tx.clone();
939        let reconnect_timeout = inner.reconnect_timeout;
940
941        let controller_task =
942            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
943
944        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
945
946        Ok(Self {
947            controller_task,
948            connection_mode,
949            reconnect_timeout,
950            rate_limiter,
951            writer_tx,
952        })
953    }
954
955    /// Returns the current connection mode.
956    #[must_use]
957    pub fn connection_mode(&self) -> ConnectionMode {
958        ConnectionMode::from_atomic(&self.connection_mode)
959    }
960
961    /// Returns a clone of the connection mode atomic for external state tracking.
962    ///
963    /// This allows adapter clients to track connection state across reconnections
964    /// without message-passing delays.
965    #[must_use]
966    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
967        Arc::clone(&self.connection_mode)
968    }
969
970    /// Check if the client connection is active.
971    ///
972    /// Returns `true` if the client is connected and has not been signalled to disconnect.
973    /// The client will automatically retry connection based on its configuration.
974    #[inline]
975    #[must_use]
976    pub fn is_active(&self) -> bool {
977        self.connection_mode().is_active()
978    }
979
980    /// Check if the client is disconnected.
981    #[must_use]
982    pub fn is_disconnected(&self) -> bool {
983        self.controller_task.is_finished()
984    }
985
986    /// Check if the client is reconnecting.
987    ///
988    /// Returns `true` if the client lost connection and is attempting to reestablish it.
989    /// The client will automatically retry connection based on its configuration.
990    #[inline]
991    #[must_use]
992    pub fn is_reconnecting(&self) -> bool {
993        self.connection_mode().is_reconnect()
994    }
995
996    /// Check if the client is disconnecting.
997    ///
998    /// Returns `true` if the client is in disconnect mode.
999    #[inline]
1000    #[must_use]
1001    pub fn is_disconnecting(&self) -> bool {
1002        self.connection_mode().is_disconnect()
1003    }
1004
1005    /// Check if the client is closed.
1006    ///
1007    /// Returns `true` if the client has been explicitly disconnected or reached
1008    /// maximum reconnection attempts. In this state, the client cannot be reused
1009    /// and a new client must be created for further connections.
1010    #[inline]
1011    #[must_use]
1012    pub fn is_closed(&self) -> bool {
1013        self.connection_mode().is_closed()
1014    }
1015
1016    /// Wait for the client to become active before sending.
1017    ///
1018    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1019    async fn wait_for_active(&self) -> Result<(), SendError> {
1020        if self.is_closed() {
1021            return Err(SendError::Closed);
1022        }
1023
1024        let timeout = self.reconnect_timeout;
1025        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1026
1027        if !self.is_active() {
1028            tracing::debug!("Waiting for client to become ACTIVE before sending...");
1029
1030            let inner = tokio::time::timeout(timeout, async {
1031                loop {
1032                    if self.is_active() {
1033                        return Ok(());
1034                    }
1035                    if matches!(
1036                        self.connection_mode(),
1037                        ConnectionMode::Disconnect | ConnectionMode::Closed
1038                    ) {
1039                        return Err(());
1040                    }
1041                    tokio::time::sleep(check_interval).await;
1042                }
1043            })
1044            .await
1045            .map_err(|_| SendError::Timeout)?;
1046            inner.map_err(|()| SendError::Closed)?;
1047        }
1048
1049        Ok(())
1050    }
1051
1052    /// Set disconnect mode to true.
1053    ///
1054    /// Controller task will periodically check the disconnect mode
1055    /// and shutdown the client if it is alive
1056    pub async fn disconnect(&self) {
1057        tracing::debug!("Disconnecting");
1058        self.connection_mode
1059            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1060
1061        if let Ok(()) =
1062            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1063                while !self.is_disconnected() {
1064                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1065                        .await;
1066                }
1067
1068                if !self.controller_task.is_finished() {
1069                    self.controller_task.abort();
1070                    log_task_aborted("controller");
1071                }
1072            })
1073            .await
1074        {
1075            tracing::debug!("Controller task finished");
1076        } else {
1077            tracing::error!("Timeout waiting for controller task to finish");
1078            if !self.controller_task.is_finished() {
1079                self.controller_task.abort();
1080                log_task_aborted("controller");
1081            }
1082        }
1083    }
1084
1085    /// Sends the given text `data` to the server.
1086    ///
1087    /// # Errors
1088    ///
1089    /// Returns a websocket error if unable to send.
1090    #[allow(unused_variables)]
1091    pub async fn send_text(
1092        &self,
1093        data: String,
1094        keys: Option<Vec<String>>,
1095    ) -> Result<(), SendError> {
1096        // Check connection state before rate limiting to fail fast
1097        if self.is_closed() || self.is_disconnecting() {
1098            return Err(SendError::Closed);
1099        }
1100
1101        self.rate_limiter.await_keys_ready(keys).await;
1102        self.wait_for_active().await?;
1103
1104        tracing::trace!("Sending text: {data:?}");
1105
1106        let msg = Message::Text(data.into());
1107        self.writer_tx
1108            .send(WriterCommand::Send(msg))
1109            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1110    }
1111
1112    /// Sends a pong frame back to the server.
1113    ///
1114    /// # Errors
1115    ///
1116    /// Returns a websocket error if unable to send.
1117    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1118        self.wait_for_active().await?;
1119
1120        tracing::trace!("Sending pong frame ({} bytes)", data.len());
1121
1122        let msg = Message::Pong(data.into());
1123        self.writer_tx
1124            .send(WriterCommand::Send(msg))
1125            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1126    }
1127
1128    /// Sends the given bytes `data` to the server.
1129    ///
1130    /// # Errors
1131    ///
1132    /// Returns a websocket error if unable to send.
1133    #[allow(unused_variables)]
1134    pub async fn send_bytes(
1135        &self,
1136        data: Vec<u8>,
1137        keys: Option<Vec<String>>,
1138    ) -> Result<(), SendError> {
1139        // Check connection state before rate limiting to fail fast
1140        if self.is_closed() || self.is_disconnecting() {
1141            return Err(SendError::Closed);
1142        }
1143
1144        self.rate_limiter.await_keys_ready(keys).await;
1145        self.wait_for_active().await?;
1146
1147        tracing::trace!("Sending bytes: {data:?}");
1148
1149        let msg = Message::Binary(data.into());
1150        self.writer_tx
1151            .send(WriterCommand::Send(msg))
1152            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1153    }
1154
1155    /// Sends a close message to the server.
1156    ///
1157    /// # Errors
1158    ///
1159    /// Returns a websocket error if unable to send.
1160    pub async fn send_close_message(&self) -> Result<(), SendError> {
1161        self.wait_for_active().await?;
1162
1163        let msg = Message::Close(None);
1164        self.writer_tx
1165            .send(WriterCommand::Send(msg))
1166            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1167    }
1168
1169    fn spawn_controller_task(
1170        mut inner: WebSocketClientInner,
1171        connection_mode: Arc<AtomicU8>,
1172        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1173    ) -> tokio::task::JoinHandle<()> {
1174        tokio::task::spawn(async move {
1175            log_task_started("controller");
1176
1177            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1178
1179            loop {
1180                tokio::time::sleep(check_interval).await;
1181                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1182
1183                if mode.is_disconnect() {
1184                    tracing::debug!("Disconnecting");
1185
1186                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1187                    if tokio::time::timeout(timeout, async {
1188                        // Delay awaiting graceful shutdown
1189                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1190
1191                        if let Some(task) = &inner.read_task
1192                            && !task.is_finished()
1193                        {
1194                            task.abort();
1195                            log_task_aborted("read");
1196                        }
1197
1198                        if let Some(task) = &inner.heartbeat_task
1199                            && !task.is_finished()
1200                        {
1201                            task.abort();
1202                            log_task_aborted("heartbeat");
1203                        }
1204                    })
1205                    .await
1206                    .is_err()
1207                    {
1208                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1209                    }
1210
1211                    tracing::debug!("Closed");
1212                    break; // Controller finished
1213                }
1214
1215                if mode.is_active() && !inner.is_alive() {
1216                    if connection_mode
1217                        .compare_exchange(
1218                            ConnectionMode::Active.as_u8(),
1219                            ConnectionMode::Reconnect.as_u8(),
1220                            Ordering::SeqCst,
1221                            Ordering::SeqCst,
1222                        )
1223                        .is_ok()
1224                    {
1225                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1226                    }
1227                    mode = ConnectionMode::from_atomic(&connection_mode);
1228                }
1229
1230                if mode.is_reconnect() {
1231                    // Check if max reconnection attempts exceeded
1232                    if let Some(max_attempts) = inner.reconnect_max_attempts
1233                        && inner.reconnection_attempt_count >= max_attempts
1234                    {
1235                        tracing::error!(
1236                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1237                            max_attempts
1238                        );
1239                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1240                        break;
1241                    }
1242
1243                    inner.reconnection_attempt_count += 1;
1244                    tracing::debug!(
1245                        "Reconnection attempt {} of {}",
1246                        inner.reconnection_attempt_count,
1247                        inner
1248                            .reconnect_max_attempts
1249                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1250                    );
1251
1252                    match inner.reconnect().await {
1253                        Ok(()) => {
1254                            inner.backoff.reset();
1255                            inner.reconnection_attempt_count = 0; // Reset counter on success
1256
1257                            // Only invoke callbacks if not in disconnect state
1258                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1259                                if let Some(ref handler) = inner.config.message_handler {
1260                                    let reconnected_msg =
1261                                        Message::Text(RECONNECTED.to_string().into());
1262                                    handler(reconnected_msg);
1263                                    tracing::debug!("Sent reconnected message to handler");
1264                                }
1265
1266                                // TODO: Retain this legacy callback for use from Python
1267                                if let Some(ref callback) = post_reconnection {
1268                                    callback();
1269                                    tracing::debug!("Called `post_reconnection` handler");
1270                                }
1271
1272                                tracing::debug!("Reconnected successfully");
1273                            } else {
1274                                tracing::debug!(
1275                                    "Skipping post_reconnection handlers due to disconnect state"
1276                                );
1277                            }
1278                        }
1279                        Err(e) => {
1280                            let duration = inner.backoff.next_duration();
1281                            tracing::warn!(
1282                                "Reconnect attempt {} failed: {e}",
1283                                inner.reconnection_attempt_count
1284                            );
1285                            if !duration.is_zero() {
1286                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1287                            }
1288                            tokio::time::sleep(duration).await;
1289                        }
1290                    }
1291                }
1292            }
1293            inner
1294                .connection_mode
1295                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1296
1297            log_task_stopped("controller");
1298        })
1299    }
1300}
1301
1302// Abort controller task on drop to clean up background tasks
1303impl Drop for WebSocketClient {
1304    fn drop(&mut self) {
1305        if !self.controller_task.is_finished() {
1306            self.controller_task.abort();
1307            log_task_aborted("controller");
1308        }
1309    }
1310}
1311
1312////////////////////////////////////////////////////////////////////////////////
1313// Tests
1314////////////////////////////////////////////////////////////////////////////////
1315
1316#[cfg(test)]
1317#[cfg(not(feature = "turmoil"))]
1318#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1319mod tests {
1320    use std::{num::NonZeroU32, sync::Arc};
1321
1322    use futures_util::{SinkExt, StreamExt};
1323    use tokio::{
1324        net::TcpListener,
1325        task::{self, JoinHandle},
1326    };
1327    use tokio_tungstenite::{
1328        accept_hdr_async,
1329        tungstenite::{
1330            handshake::server::{self, Callback},
1331            http::HeaderValue,
1332        },
1333    };
1334
1335    use crate::{
1336        ratelimiter::quota::Quota,
1337        websocket::{WebSocketClient, WebSocketConfig},
1338    };
1339
    /// Local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Handle of the spawned accept loop; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
    }
1344
    /// Handshake callback that checks a single expected request header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the upgrade request.
        key: String,
        /// Value that header is required to carry.
        value: HeaderValue,
    }
1350
1351    impl Callback for TestCallback {
1352        #[allow(clippy::panic_in_result_fn)]
1353        fn on_request(
1354            self,
1355            request: &server::Request,
1356            response: server::Response,
1357        ) -> Result<server::Response, server::ErrorResponse> {
1358            let _ = response;
1359            let value = request.headers().get(&self.key);
1360            assert!(value.is_some());
1361
1362            if let Some(value) = request.headers().get(&self.key) {
1363                assert_eq!(value, self.value);
1364            }
1365
1366            Ok(response)
1367        }
1368    }
1369
1370    impl TestServer {
1371        async fn setup() -> Self {
1372            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1373            let port = TcpListener::local_addr(&server).unwrap().port();
1374
1375            let header_key = "test".to_string();
1376            let header_value = "test".to_string();
1377
1378            let test_call_back = TestCallback {
1379                key: header_key,
1380                value: HeaderValue::from_str(&header_value).unwrap(),
1381            };
1382
1383            let task = task::spawn(async move {
1384                // Keep accepting connections
1385                loop {
1386                    let (conn, _) = server.accept().await.unwrap();
1387                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1388                        .await
1389                        .unwrap();
1390
1391                    task::spawn(async move {
1392                        while let Some(Ok(msg)) = websocket.next().await {
1393                            match msg {
1394                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1395                                    if txt == "close-now" =>
1396                                {
1397                                    tracing::debug!("Forcibly closing from server side");
1398                                    // This sends a close frame, then stops reading
1399                                    let _ = websocket.close(None).await;
1400                                    break;
1401                                }
1402                                // Echo text/binary frames
1403                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1404                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1405                                    if websocket.send(msg).await.is_err() {
1406                                        break;
1407                                    }
1408                                }
1409                                // If the client closes, we also break
1410                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1411                                    _frame,
1412                                ) => {
1413                                    let _ = websocket.close(None).await;
1414                                    break;
1415                                }
1416                                // Ignore pings/pongs
1417                                _ => {}
1418                            }
1419                        }
1420                    });
1421                }
1422            });
1423
1424            Self { task, port }
1425        }
1426    }
1427
1428    impl Drop for TestServer {
1429        fn drop(&mut self) {
1430            self.task.abort();
1431        }
1432    }
1433
1434    async fn setup_test_client(port: u16) -> WebSocketClient {
1435        let config = WebSocketConfig {
1436            url: format!("ws://127.0.0.1:{port}"),
1437            headers: vec![("test".into(), "test".into())],
1438            message_handler: Some(Arc::new(|_| {})),
1439            heartbeat: None,
1440            heartbeat_msg: None,
1441            ping_handler: None,
1442            reconnect_timeout_ms: None,
1443            reconnect_delay_initial_ms: None,
1444            reconnect_backoff_factor: None,
1445            reconnect_delay_max_ms: None,
1446            reconnect_jitter_ms: None,
1447            reconnect_max_attempts: None,
1448        };
1449        WebSocketClient::connect(config, None, vec![], None)
1450            .await
1451            .expect("Failed to connect")
1452    }
1453
1454    #[tokio::test]
1455    async fn test_websocket_basic() {
1456        let server = TestServer::setup().await;
1457        let client = setup_test_client(server.port).await;
1458
1459        assert!(!client.is_disconnected());
1460
1461        client.disconnect().await;
1462        assert!(client.is_disconnected());
1463    }
1464
1465    #[tokio::test]
1466    async fn test_websocket_heartbeat() {
1467        let server = TestServer::setup().await;
1468        let client = setup_test_client(server.port).await;
1469
1470        // Wait ~3s => server should see multiple "ping"
1471        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1472
1473        // Cleanup
1474        client.disconnect().await;
1475        assert!(client.is_disconnected());
1476    }
1477
1478    #[tokio::test]
1479    async fn test_websocket_reconnect_exhausted() {
1480        let config = WebSocketConfig {
1481            url: "ws://127.0.0.1:9997".into(), // <-- No server
1482            headers: vec![],
1483            message_handler: Some(Arc::new(|_| {})),
1484            heartbeat: None,
1485            heartbeat_msg: None,
1486            ping_handler: None,
1487            reconnect_timeout_ms: None,
1488            reconnect_delay_initial_ms: None,
1489            reconnect_backoff_factor: None,
1490            reconnect_delay_max_ms: None,
1491            reconnect_jitter_ms: None,
1492            reconnect_max_attempts: None,
1493        };
1494        let res = WebSocketClient::connect(config, None, vec![], None).await;
1495        assert!(res.is_err(), "Should fail quickly with no server");
1496    }
1497
1498    #[tokio::test]
1499    async fn test_websocket_forced_close_reconnect() {
1500        let server = TestServer::setup().await;
1501        let client = setup_test_client(server.port).await;
1502
1503        // 1) Send normal message
1504        client.send_text("Hello".into(), None).await.unwrap();
1505
1506        // 2) Trigger forced close from server
1507        client.send_text("close-now".into(), None).await.unwrap();
1508
1509        // 3) Wait a bit => read loop sees close => reconnect
1510        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1511
1512        // Confirm not disconnected
1513        assert!(!client.is_disconnected());
1514
1515        // Cleanup
1516        client.disconnect().await;
1517        assert!(client.is_disconnected());
1518    }
1519
1520    #[tokio::test]
1521    async fn test_rate_limiter() {
1522        let server = TestServer::setup().await;
1523        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1524
1525        let config = WebSocketConfig {
1526            url: format!("ws://127.0.0.1:{}", server.port),
1527            headers: vec![("test".into(), "test".into())],
1528            message_handler: Some(Arc::new(|_| {})),
1529            heartbeat: None,
1530            heartbeat_msg: None,
1531            ping_handler: None,
1532            reconnect_timeout_ms: None,
1533            reconnect_delay_initial_ms: None,
1534            reconnect_backoff_factor: None,
1535            reconnect_delay_max_ms: None,
1536            reconnect_jitter_ms: None,
1537            reconnect_max_attempts: None,
1538        };
1539
1540        let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1541            .await
1542            .unwrap();
1543
1544        // First 2 should succeed
1545        client.send_text("test1".into(), None).await.unwrap();
1546        client.send_text("test2".into(), None).await.unwrap();
1547
1548        // Third should error
1549        client.send_text("test3".into(), None).await.unwrap();
1550
1551        // Cleanup
1552        client.disconnect().await;
1553        assert!(client.is_disconnected());
1554    }
1555
1556    #[tokio::test]
1557    async fn test_concurrent_writers() {
1558        let server = TestServer::setup().await;
1559        let client = Arc::new(setup_test_client(server.port).await);
1560
1561        let mut handles = vec![];
1562        for i in 0..10 {
1563            let client = client.clone();
1564            handles.push(task::spawn(async move {
1565                client.send_text(format!("test{i}"), None).await.unwrap();
1566            }));
1567        }
1568
1569        for handle in handles {
1570            handle.await.unwrap();
1571        }
1572
1573        // Cleanup
1574        client.disconnect().await;
1575        assert!(client.is_disconnected());
1576    }
1577}
1578
1579////////////////////////////////////////////////////////////////////////////////
1580// Tests
1581////////////////////////////////////////////////////////////////////////////////
1582
1583#[cfg(test)]
1584#[cfg(not(feature = "turmoil"))]
1585mod rust_tests {
1586    use futures_util::StreamExt;
1587    use rstest::rstest;
1588    use tokio::{
1589        net::TcpListener,
1590        task,
1591        time::{Duration, sleep},
1592    };
1593    use tokio_tungstenite::accept_async;
1594
1595    use super::*;
1596    use crate::websocket::types::channel_message_handler;
1597
1598    #[rstest]
1599    #[tokio::test]
1600    async fn test_reconnect_then_disconnect() {
1601        // Bind an ephemeral port
1602        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1603        let port = listener.local_addr().unwrap().port();
1604
1605        // Server task: accept one ws connection then close it
1606        let server = task::spawn(async move {
1607            let (stream, _) = listener.accept().await.unwrap();
1608            let ws = accept_async(stream).await.unwrap();
1609            drop(ws);
1610            // Keep alive briefly
1611            sleep(Duration::from_secs(1)).await;
1612        });
1613
1614        // Build a channel-based message handler for incoming messages (unused here)
1615        let (handler, _rx) = channel_message_handler();
1616
1617        // Configure client with short reconnect backoff
1618        let config = WebSocketConfig {
1619            url: format!("ws://127.0.0.1:{port}"),
1620            headers: vec![],
1621            message_handler: Some(handler),
1622            heartbeat: None,
1623            heartbeat_msg: None,
1624            ping_handler: None,
1625            reconnect_timeout_ms: Some(1_000),
1626            reconnect_delay_initial_ms: Some(50),
1627            reconnect_delay_max_ms: Some(100),
1628            reconnect_backoff_factor: Some(1.0),
1629            reconnect_jitter_ms: Some(0),
1630            reconnect_max_attempts: None,
1631        };
1632
1633        // Connect the client
1634        let client = WebSocketClient::connect(config, None, vec![], None)
1635            .await
1636            .unwrap();
1637
1638        // Allow server to drop connection and client to detect
1639        sleep(Duration::from_millis(100)).await;
1640        // Now immediately disconnect the client
1641        client.disconnect().await;
1642        assert!(client.is_disconnected());
1643        server.abort();
1644    }
1645
1646    #[rstest]
1647    #[tokio::test]
1648    async fn test_reconnect_state_flips_when_reader_stops() {
1649        // Bind an ephemeral port and accept a single websocket connection which we drop.
1650        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1651        let port = listener.local_addr().unwrap().port();
1652
1653        let server = task::spawn(async move {
1654            if let Ok((stream, _)) = listener.accept().await
1655                && let Ok(ws) = accept_async(stream).await
1656            {
1657                drop(ws);
1658            }
1659            sleep(Duration::from_millis(50)).await;
1660        });
1661
1662        let (handler, _rx) = channel_message_handler();
1663
1664        let config = WebSocketConfig {
1665            url: format!("ws://127.0.0.1:{port}"),
1666            headers: vec![],
1667            message_handler: Some(handler),
1668            heartbeat: None,
1669            heartbeat_msg: None,
1670            ping_handler: None,
1671            reconnect_timeout_ms: Some(1_000),
1672            reconnect_delay_initial_ms: Some(50),
1673            reconnect_delay_max_ms: Some(100),
1674            reconnect_backoff_factor: Some(1.0),
1675            reconnect_jitter_ms: Some(0),
1676            reconnect_max_attempts: None,
1677        };
1678
1679        let client = WebSocketClient::connect(config, None, vec![], None)
1680            .await
1681            .unwrap();
1682
1683        tokio::time::timeout(Duration::from_secs(2), async {
1684            loop {
1685                if client.is_reconnecting() {
1686                    break;
1687                }
1688                tokio::time::sleep(Duration::from_millis(10)).await;
1689            }
1690        })
1691        .await
1692        .expect("client did not enter RECONNECT state");
1693
1694        client.disconnect().await;
1695        server.abort();
1696    }
1697
1698    #[rstest]
1699    #[tokio::test]
1700    async fn test_stream_mode_disables_auto_reconnect() {
1701        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1702        // and that reconnect() transitions to CLOSED state for stream mode
1703        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1704        let port = listener.local_addr().unwrap().port();
1705
1706        let server = task::spawn(async move {
1707            if let Ok((stream, _)) = listener.accept().await
1708                && let Ok(_ws) = accept_async(stream).await
1709            {
1710                // Keep connection alive briefly
1711                sleep(Duration::from_millis(100)).await;
1712            }
1713        });
1714
1715        let config = WebSocketConfig {
1716            url: format!("ws://127.0.0.1:{port}"),
1717            headers: vec![],
1718            message_handler: None, // Stream mode - no handler
1719            heartbeat: None,
1720            heartbeat_msg: None,
1721            ping_handler: None,
1722            reconnect_timeout_ms: Some(1_000),
1723            reconnect_delay_initial_ms: Some(50),
1724            reconnect_delay_max_ms: Some(100),
1725            reconnect_backoff_factor: Some(1.0),
1726            reconnect_jitter_ms: Some(0),
1727            reconnect_max_attempts: None,
1728        };
1729
1730        // Create stream-based client
1731        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1732            .await
1733            .unwrap();
1734
1735        // Note: We can't easily test the reconnect behavior from the outside since
1736        // the inner client is private. The key fix is that WebSocketClientInner
1737        // now has is_stream_mode=true for connect_stream, and reconnect() will
1738        // transition to CLOSED state instead of creating a new reader that gets dropped.
1739        // This is tested implicitly by the fact that stream users won't get stuck
1740        // in an infinite reconnect loop.
1741
1742        server.abort();
1743    }
1744
1745    #[rstest]
1746    #[tokio::test]
1747    async fn test_message_handler_mode_allows_auto_reconnect() {
1748        // Test that regular clients (with message handler) can auto-reconnect
1749        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1750        let port = listener.local_addr().unwrap().port();
1751
1752        let server = task::spawn(async move {
1753            // Accept first connection and close it
1754            if let Ok((stream, _)) = listener.accept().await
1755                && let Ok(ws) = accept_async(stream).await
1756            {
1757                drop(ws);
1758            }
1759            sleep(Duration::from_millis(50)).await;
1760        });
1761
1762        let (handler, _rx) = channel_message_handler();
1763
1764        let config = WebSocketConfig {
1765            url: format!("ws://127.0.0.1:{port}"),
1766            headers: vec![],
1767            message_handler: Some(handler), // Has message handler
1768            heartbeat: None,
1769            heartbeat_msg: None,
1770            ping_handler: None,
1771            reconnect_timeout_ms: Some(1_000),
1772            reconnect_delay_initial_ms: Some(50),
1773            reconnect_delay_max_ms: Some(100),
1774            reconnect_backoff_factor: Some(1.0),
1775            reconnect_jitter_ms: Some(0),
1776            reconnect_max_attempts: None,
1777        };
1778
1779        let client = WebSocketClient::connect(config, None, vec![], None)
1780            .await
1781            .unwrap();
1782
1783        // Wait for the connection to be dropped and reconnection to be attempted
1784        tokio::time::timeout(Duration::from_secs(2), async {
1785            loop {
1786                if client.is_reconnecting() || client.is_closed() {
1787                    break;
1788                }
1789                tokio::time::sleep(Duration::from_millis(10)).await;
1790            }
1791        })
1792        .await
1793        .expect("client should attempt reconnection or close");
1794
1795        // Should either be reconnecting or closed (depending on timing)
1796        // The important thing is it's not staying active forever
1797        assert!(
1798            client.is_reconnecting() || client.is_closed(),
1799            "Client with message handler should attempt reconnection"
1800        );
1801
1802        client.disconnect().await;
1803        server.abort();
1804    }
1805
1806    #[rstest]
1807    #[tokio::test]
1808    async fn test_handler_mode_reconnect_with_new_connection() {
1809        // Test that handler mode successfully reconnects and messages continue flowing
1810        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1811        let port = listener.local_addr().unwrap().port();
1812
1813        let server = task::spawn(async move {
1814            // First connection - accept and immediately close
1815            if let Ok((stream, _)) = listener.accept().await
1816                && let Ok(ws) = accept_async(stream).await
1817            {
1818                drop(ws);
1819            }
1820
1821            // Small delay to let client detect disconnection
1822            sleep(Duration::from_millis(100)).await;
1823
1824            // Second connection - accept, send a message, then keep alive
1825            if let Ok((stream, _)) = listener.accept().await
1826                && let Ok(mut ws) = accept_async(stream).await
1827            {
1828                use futures_util::SinkExt;
1829                let _ = ws
1830                    .send(Message::Text("reconnected".to_string().into()))
1831                    .await;
1832                sleep(Duration::from_secs(1)).await;
1833            }
1834        });
1835
1836        let (handler, mut rx) = channel_message_handler();
1837
1838        let config = WebSocketConfig {
1839            url: format!("ws://127.0.0.1:{port}"),
1840            headers: vec![],
1841            message_handler: Some(handler),
1842            heartbeat: None,
1843            heartbeat_msg: None,
1844            ping_handler: None,
1845            reconnect_timeout_ms: Some(2_000),
1846            reconnect_delay_initial_ms: Some(50),
1847            reconnect_delay_max_ms: Some(200),
1848            reconnect_backoff_factor: Some(1.5),
1849            reconnect_jitter_ms: Some(10),
1850            reconnect_max_attempts: None,
1851        };
1852
1853        let client = WebSocketClient::connect(config, None, vec![], None)
1854            .await
1855            .unwrap();
1856
1857        // Wait for reconnection to happen and message to arrive
1858        let result = tokio::time::timeout(Duration::from_secs(5), async {
1859            loop {
1860                if let Ok(msg) = rx.try_recv()
1861                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1862                {
1863                    return true;
1864                }
1865                tokio::time::sleep(Duration::from_millis(10)).await;
1866            }
1867        })
1868        .await;
1869
1870        assert!(
1871            result.is_ok(),
1872            "Should receive message after reconnection within timeout"
1873        );
1874
1875        client.disconnect().await;
1876        server.abort();
1877    }
1878
1879    #[rstest]
1880    #[tokio::test]
1881    async fn test_stream_mode_no_auto_reconnect() {
1882        // Test that stream mode does not automatically reconnect when connection is lost
1883        // The caller owns the reader and is responsible for detecting disconnection
1884        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1885        let port = listener.local_addr().unwrap().port();
1886
1887        let server = task::spawn(async move {
1888            // Accept connection and send one message, then close
1889            if let Ok((stream, _)) = listener.accept().await
1890                && let Ok(mut ws) = accept_async(stream).await
1891            {
1892                use futures_util::SinkExt;
1893                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1894                sleep(Duration::from_millis(50)).await;
1895                // Connection closes when ws is dropped
1896            }
1897        });
1898
1899        let config = WebSocketConfig {
1900            url: format!("ws://127.0.0.1:{port}"),
1901            headers: vec![],
1902            message_handler: None, // Stream mode
1903            heartbeat: None,
1904            heartbeat_msg: None,
1905            ping_handler: None,
1906            reconnect_timeout_ms: Some(1_000),
1907            reconnect_delay_initial_ms: Some(50),
1908            reconnect_delay_max_ms: Some(100),
1909            reconnect_backoff_factor: Some(1.0),
1910            reconnect_jitter_ms: Some(0),
1911            reconnect_max_attempts: None,
1912        };
1913
1914        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1915            .await
1916            .unwrap();
1917
1918        // Initially active
1919        assert!(client.is_active(), "Client should start as active");
1920
1921        // Read the hello message
1922        let msg = reader.next().await;
1923        assert!(
1924            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1925            "Should receive initial message"
1926        );
1927
1928        // Read until connection closes (reader will return None or error)
1929        while let Some(msg) = reader.next().await {
1930            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1931                break;
1932            }
1933        }
1934
1935        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1936        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1937        sleep(Duration::from_millis(200)).await;
1938
1939        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1940        // This is correct behavior - stream mode doesn't auto-detect disconnection
1941        assert!(
1942            client.is_active() || client.is_closed(),
1943            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1944        );
1945        assert!(
1946            !client.is_reconnecting(),
1947            "Stream mode client should never attempt reconnection"
1948        );
1949
1950        client.disconnect().await;
1951        server.abort();
1952    }
1953
1954    #[rstest]
1955    #[tokio::test]
1956    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1957        // Test that send operations respect the configured reconnect_timeout.
1958        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1959        use nautilus_common::testing::wait_until_async;
1960
1961        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1962        let port = listener.local_addr().unwrap().port();
1963
1964        let server = task::spawn(async move {
1965            // Accept first connection and immediately close it
1966            if let Ok((stream, _)) = listener.accept().await
1967                && let Ok(ws) = accept_async(stream).await
1968            {
1969                drop(ws);
1970            }
1971            // Don't accept second connection - client will be stuck in RECONNECT
1972            sleep(Duration::from_secs(60)).await;
1973        });
1974
1975        let (handler, _rx) = channel_message_handler();
1976
1977        // Configure with SHORT 2s reconnect timeout
1978        let config = WebSocketConfig {
1979            url: format!("ws://127.0.0.1:{port}"),
1980            headers: vec![],
1981            message_handler: Some(handler),
1982            heartbeat: None,
1983            heartbeat_msg: None,
1984            ping_handler: None,
1985            reconnect_timeout_ms: Some(2_000), // 2s timeout
1986            reconnect_delay_initial_ms: Some(50),
1987            reconnect_delay_max_ms: Some(100),
1988            reconnect_backoff_factor: Some(1.0),
1989            reconnect_jitter_ms: Some(0),
1990            reconnect_max_attempts: None,
1991        };
1992
1993        let client = WebSocketClient::connect(config, None, vec![], None)
1994            .await
1995            .unwrap();
1996
1997        // Wait for client to enter RECONNECT state
1998        wait_until_async(
1999            || async { client.is_reconnecting() },
2000            Duration::from_secs(3),
2001        )
2002        .await;
2003
2004        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
2005        let start = std::time::Instant::now();
2006        let send_result = client.send_text("test".to_string(), None).await;
2007        let elapsed = start.elapsed();
2008
2009        assert!(
2010            send_result.is_err(),
2011            "Send should fail when client stuck in RECONNECT"
2012        );
2013        assert!(
2014            matches!(send_result, Err(crate::error::SendError::Timeout)),
2015            "Send should return Timeout error, was: {:?}",
2016            send_result
2017        );
2018        // Verify timeout respects configured value (2s), but don't check upper bound
2019        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2020        assert!(
2021            elapsed >= Duration::from_millis(1800),
2022            "Send should timeout after at least 2s (configured timeout), took {:?}",
2023            elapsed
2024        );
2025
2026        client.disconnect().await;
2027        server.abort();
2028    }
2029
2030    #[rstest]
2031    #[tokio::test]
2032    async fn test_send_waits_during_reconnection() {
2033        // Test that send operations wait for reconnection to complete (up to timeout)
2034        use nautilus_common::testing::wait_until_async;
2035
2036        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2037        let port = listener.local_addr().unwrap().port();
2038
2039        let server = task::spawn(async move {
2040            // First connection - accept and immediately close
2041            if let Ok((stream, _)) = listener.accept().await
2042                && let Ok(ws) = accept_async(stream).await
2043            {
2044                drop(ws);
2045            }
2046
2047            // Wait a bit before accepting second connection
2048            sleep(Duration::from_millis(500)).await;
2049
2050            // Second connection - accept and keep alive
2051            if let Ok((stream, _)) = listener.accept().await
2052                && let Ok(mut ws) = accept_async(stream).await
2053            {
2054                // Echo messages
2055                while let Some(Ok(msg)) = ws.next().await {
2056                    if ws.send(msg).await.is_err() {
2057                        break;
2058                    }
2059                }
2060            }
2061        });
2062
2063        let (handler, _rx) = channel_message_handler();
2064
2065        let config = WebSocketConfig {
2066            url: format!("ws://127.0.0.1:{port}"),
2067            headers: vec![],
2068            message_handler: Some(handler),
2069            heartbeat: None,
2070            heartbeat_msg: None,
2071            ping_handler: None,
2072            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2073            reconnect_delay_initial_ms: Some(100),
2074            reconnect_delay_max_ms: Some(200),
2075            reconnect_backoff_factor: Some(1.0),
2076            reconnect_jitter_ms: Some(0),
2077            reconnect_max_attempts: None,
2078        };
2079
2080        let client = WebSocketClient::connect(config, None, vec![], None)
2081            .await
2082            .unwrap();
2083
2084        // Wait for reconnection to trigger
2085        wait_until_async(
2086            || async { client.is_reconnecting() },
2087            Duration::from_secs(2),
2088        )
2089        .await;
2090
2091        // Try to send while reconnecting - should wait and succeed after reconnect
2092        let send_result = tokio::time::timeout(
2093            Duration::from_secs(3),
2094            client.send_text("test_message".to_string(), None),
2095        )
2096        .await;
2097
2098        assert!(
2099            send_result.is_ok() && send_result.unwrap().is_ok(),
2100            "Send should succeed after waiting for reconnection"
2101        );
2102
2103        client.disconnect().await;
2104        server.abort();
2105    }
2106
2107    #[rstest]
2108    #[tokio::test]
2109    async fn test_rate_limiter_before_active_wait() {
2110        // Test that rate limiting happens BEFORE active state check.
2111        // This prevents race conditions where connection state changes during rate limit wait.
2112        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2113        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2114        use std::{num::NonZeroU32, sync::Arc};
2115
2116        use nautilus_common::testing::wait_until_async;
2117
2118        use crate::ratelimiter::quota::Quota;
2119
2120        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2121        let port = listener.local_addr().unwrap().port();
2122
2123        let server = task::spawn(async move {
2124            // First connection - accept and close after receiving one message
2125            if let Ok((stream, _)) = listener.accept().await
2126                && let Ok(mut ws) = accept_async(stream).await
2127            {
2128                // Receive first message then close
2129                if let Some(Ok(_)) = ws.next().await {
2130                    drop(ws);
2131                }
2132            }
2133
2134            // Wait before accepting reconnection
2135            sleep(Duration::from_millis(500)).await;
2136
2137            // Second connection - accept and keep alive
2138            if let Ok((stream, _)) = listener.accept().await
2139                && let Ok(mut ws) = accept_async(stream).await
2140            {
2141                while let Some(Ok(msg)) = ws.next().await {
2142                    if ws.send(msg).await.is_err() {
2143                        break;
2144                    }
2145                }
2146            }
2147        });
2148
2149        let (handler, _rx) = channel_message_handler();
2150
2151        let config = WebSocketConfig {
2152            url: format!("ws://127.0.0.1:{port}"),
2153            headers: vec![],
2154            message_handler: Some(handler),
2155            heartbeat: None,
2156            heartbeat_msg: None,
2157            ping_handler: None,
2158            reconnect_timeout_ms: Some(5_000),
2159            reconnect_delay_initial_ms: Some(50),
2160            reconnect_delay_max_ms: Some(100),
2161            reconnect_backoff_factor: Some(1.0),
2162            reconnect_jitter_ms: Some(0),
2163            reconnect_max_attempts: None,
2164        };
2165
2166        // Very restrictive rate limit: 1 request per second, burst of 1
2167        let quota =
2168            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2169
2170        let client = Arc::new(
2171            WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2172                .await
2173                .unwrap(),
2174        );
2175
2176        // First send exhausts burst capacity and triggers connection close
2177        client
2178            .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2179            .await
2180            .unwrap();
2181
2182        // Wait for client to enter RECONNECT state
2183        wait_until_async(
2184            || async { client.is_reconnecting() },
2185            Duration::from_secs(2),
2186        )
2187        .await;
2188
2189        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2190        let start = std::time::Instant::now();
2191        let send_result = client
2192            .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2193            .await;
2194        let elapsed = start.elapsed();
2195
2196        // Should succeed after both rate limit AND reconnection
2197        assert!(
2198            send_result.is_ok(),
2199            "Send should succeed after rate limit + reconnection, was: {:?}",
2200            send_result
2201        );
2202        // Total wait should be at least rate limit time (~1s)
2203        // The reconnection completes while rate limiting or after
2204        // Use 850ms threshold to account for timing jitter in CI
2205        assert!(
2206            elapsed >= Duration::from_millis(850),
2207            "Should wait for rate limit (~1s), waited {:?}",
2208            elapsed
2209        );
2210
2211        client.disconnect().await;
2212        server.abort();
2213    }
2214
2215    #[rstest]
2216    #[tokio::test]
2217    async fn test_disconnect_during_reconnect_exits_cleanly() {
2218        // Test CAS race condition: disconnect called during reconnection
2219        // Should exit cleanly without spawning new tasks
2220        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2221        let port = listener.local_addr().unwrap().port();
2222
2223        let server = task::spawn(async move {
2224            // Accept first connection and immediately close
2225            if let Ok((stream, _)) = listener.accept().await
2226                && let Ok(ws) = accept_async(stream).await
2227            {
2228                drop(ws);
2229            }
2230            // Don't accept second connection - let reconnect hang
2231            sleep(Duration::from_secs(60)).await;
2232        });
2233
2234        let (handler, _rx) = channel_message_handler();
2235
2236        let config = WebSocketConfig {
2237            url: format!("ws://127.0.0.1:{port}"),
2238            headers: vec![],
2239            message_handler: Some(handler),
2240            heartbeat: None,
2241            heartbeat_msg: None,
2242            ping_handler: None,
2243            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2244            reconnect_delay_initial_ms: Some(100),
2245            reconnect_delay_max_ms: Some(200),
2246            reconnect_backoff_factor: Some(1.0),
2247            reconnect_jitter_ms: Some(0),
2248            reconnect_max_attempts: None,
2249        };
2250
2251        let client = WebSocketClient::connect(config, None, vec![], None)
2252            .await
2253            .unwrap();
2254
2255        // Wait for reconnection to start
2256        tokio::time::timeout(Duration::from_secs(2), async {
2257            while !client.is_reconnecting() {
2258                sleep(Duration::from_millis(10)).await;
2259            }
2260        })
2261        .await
2262        .expect("Client should enter RECONNECT state");
2263
2264        // Disconnect while reconnecting
2265        client.disconnect().await;
2266
2267        // Should be cleanly closed
2268        assert!(
2269            client.is_disconnected(),
2270            "Client should be cleanly disconnected"
2271        );
2272
2273        server.abort();
2274    }
2275
2276    #[rstest]
2277    #[tokio::test]
2278    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2279        // Test that send operations check connection state BEFORE rate limiting,
2280        // preventing unnecessary delays when the connection is already closed.
2281        use std::{num::NonZeroU32, sync::Arc};
2282
2283        use nautilus_common::testing::wait_until_async;
2284
2285        use crate::ratelimiter::quota::Quota;
2286
2287        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2288        let port = listener.local_addr().unwrap().port();
2289
2290        let server = task::spawn(async move {
2291            // Accept connection and immediately close
2292            if let Ok((stream, _)) = listener.accept().await
2293                && let Ok(ws) = accept_async(stream).await
2294            {
2295                drop(ws);
2296            }
2297            sleep(Duration::from_secs(60)).await;
2298        });
2299
2300        let (handler, _rx) = channel_message_handler();
2301
2302        let config = WebSocketConfig {
2303            url: format!("ws://127.0.0.1:{port}"),
2304            headers: vec![],
2305            message_handler: Some(handler),
2306            heartbeat: None,
2307            heartbeat_msg: None,
2308            ping_handler: None,
2309            reconnect_timeout_ms: Some(5_000),
2310            reconnect_delay_initial_ms: Some(50),
2311            reconnect_delay_max_ms: Some(100),
2312            reconnect_backoff_factor: Some(1.0),
2313            reconnect_jitter_ms: Some(0),
2314            reconnect_max_attempts: None,
2315        };
2316
2317        // Very restrictive rate limit: 1 request per 10 seconds
2318        // This ensures that if we wait for rate limit, the test will timeout
2319        let quota = Quota::with_period(Duration::from_secs(10))
2320            .unwrap()
2321            .allow_burst(NonZeroU32::new(1).unwrap());
2322
2323        let client = Arc::new(
2324            WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2325                .await
2326                .unwrap(),
2327        );
2328
2329        // Wait for disconnection
2330        wait_until_async(
2331            || async { client.is_reconnecting() || client.is_closed() },
2332            Duration::from_secs(2),
2333        )
2334        .await;
2335
2336        // Explicitly disconnect to move away from ACTIVE state
2337        client.disconnect().await;
2338        assert!(
2339            !client.is_active(),
2340            "Client should not be active after disconnect"
2341        );
2342
2343        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2344        let start = std::time::Instant::now();
2345        let result = client
2346            .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2347            .await;
2348        let elapsed = start.elapsed();
2349
2350        // Should fail with Closed error
2351        assert!(result.is_err(), "Send should fail when client is closed");
2352        assert!(
2353            matches!(result, Err(crate::error::SendError::Closed)),
2354            "Send should return Closed error, was: {:?}",
2355            result
2356        );
2357
2358        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2359        assert!(
2360            elapsed < Duration::from_millis(100),
2361            "Send should fail fast without rate limiting, took {:?}",
2362            elapsed
2363        );
2364
2365        server.abort();
2366    }
2367
2368    #[rstest]
2369    #[tokio::test]
2370    async fn test_connect_rejects_config_without_message_handler() {
2371        // Test that connect() properly rejects configs without a message handler
2372        // to prevent zombie connections that appear alive but never detect disconnections
2373
2374        let config = WebSocketConfig {
2375            url: "ws://127.0.0.1:9999".to_string(),
2376            headers: vec![],
2377            message_handler: None, // No handler provided
2378            heartbeat: None,
2379            heartbeat_msg: None,
2380            ping_handler: None,
2381            reconnect_timeout_ms: Some(1_000),
2382            reconnect_delay_initial_ms: Some(100),
2383            reconnect_delay_max_ms: Some(500),
2384            reconnect_backoff_factor: Some(1.5),
2385            reconnect_jitter_ms: Some(0),
2386            reconnect_max_attempts: None,
2387        };
2388
2389        let result = WebSocketClient::connect(config, None, vec![], None).await;
2390
2391        assert!(
2392            result.is_err(),
2393            "connect() should reject configs without message_handler"
2394        );
2395
2396        let err = result.unwrap_err();
2397        let err_msg = err.to_string();
2398        assert!(
2399            err_msg.contains("Handler mode requires config.message_handler"),
2400            "Error should mention missing message_handler, was: {err_msg}"
2401        );
2402    }
2403
2404    #[rstest]
2405    #[tokio::test]
2406    async fn test_client_without_handler_sets_stream_mode() {
2407        // Test that if a client is somehow created without a handler,
2408        // it properly sets is_stream_mode=true to prevent zombie connections
2409
2410        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2411        let port = listener.local_addr().unwrap().port();
2412
2413        let server = task::spawn(async move {
2414            // Accept and immediately close to simulate server disconnect
2415            if let Ok((stream, _)) = listener.accept().await
2416                && let Ok(ws) = accept_async(stream).await
2417            {
2418                drop(ws); // Drop connection immediately
2419            }
2420        });
2421
2422        let config = WebSocketConfig {
2423            url: format!("ws://127.0.0.1:{port}"),
2424            headers: vec![],
2425            message_handler: None, // No handler
2426            heartbeat: None,
2427            heartbeat_msg: None,
2428            ping_handler: None,
2429            reconnect_timeout_ms: Some(1_000),
2430            reconnect_delay_initial_ms: Some(100),
2431            reconnect_delay_max_ms: Some(500),
2432            reconnect_backoff_factor: Some(1.5),
2433            reconnect_jitter_ms: Some(0),
2434            reconnect_max_attempts: None,
2435        };
2436
2437        // Create client directly via connect_url to bypass validation
2438        let inner = WebSocketClientInner::connect_url(config).await.unwrap();
2439
2440        // Verify is_stream_mode is true when no handler
2441        assert!(
2442            inner.is_stream_mode,
2443            "Client without handler should have is_stream_mode=true"
2444        );
2445
2446        // Verify that when stream mode is enabled, reconnection is disabled
2447        // (documented behavior - stream mode clients close instead of reconnecting)
2448
2449        server.abort();
2450    }
2451}