nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50    config::WebSocketConfig,
51    consts::{
52        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54    },
55    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60    RECONNECTED,
61    backoff::ExponentialBackoff,
62    error::SendError,
63    logging::{log_task_aborted, log_task_started, log_task_stopped},
64    mode::ConnectionMode,
65    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
68/// `WebSocketClient` connects to a websocket server to read and send messages.
69///
70/// The client is opinionated about how messages are read and written. It
71/// assumes that data can only have one reader but multiple writers.
72///
73/// The client splits the connection into read and write halves. It moves
74/// the read half into a tokio task which keeps receiving messages from the
75/// server and calls a handler - a Python function that takes the data
76/// as its parameter. It stores the write half in the struct wrapped
77/// with an Arc Mutex. This way the client struct can be used to write
78/// data to the server from multiple scopes/tasks.
79///
80/// The client also maintains a heartbeat if given a duration in seconds.
81/// It's preferable to set the duration slightly lower - heartbeat more
82/// frequently - than the required amount.
83pub struct WebSocketClientInner {
84    config: WebSocketConfig,
85    read_task: Option<tokio::task::JoinHandle<()>>,
86    write_task: tokio::task::JoinHandle<()>,
87    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
88    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
89    connection_mode: Arc<AtomicU8>,
90    reconnect_timeout: Duration,
91    backoff: ExponentialBackoff,
92    /// True if this is a stream-based client (created via `connect_stream`).
93    /// Stream-based clients disable auto-reconnect because the reader is
94    /// owned by the caller and cannot be replaced during reconnection.
95    is_stream_mode: bool,
96    /// Maximum number of reconnection attempts before giving up (None = unlimited).
97    reconnect_max_attempts: Option<u32>,
98    /// Current count of consecutive reconnection attempts.
99    reconnection_attempt_count: u32,
100}
101
102impl WebSocketClientInner {
103    /// Create an inner websocket client with an existing writer.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the exponential backoff configuration is invalid.
108    pub async fn new_with_writer(
109        config: WebSocketConfig,
110        writer: MessageWriter,
111    ) -> Result<Self, Error> {
112        install_cryptographic_provider();
113
114        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
115
116        // Note: We don't spawn a read task here since the reader is handled externally
117        let read_task = None;
118
119        let backoff = ExponentialBackoff::new(
120            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
121            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
122            config.reconnect_backoff_factor.unwrap_or(1.5),
123            config.reconnect_jitter_ms.unwrap_or(100),
124            true, // immediate-first
125        )
126        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
127
128        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
129        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
130
131        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
132            Some(Self::spawn_heartbeat_task(
133                connection_mode.clone(),
134                heartbeat_interval,
135                config.heartbeat_msg.clone(),
136                writer_tx.clone(),
137            ))
138        } else {
139            None
140        };
141
142        Ok(Self {
143            config: config.clone(),
144            writer_tx,
145            connection_mode,
146            reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
147            heartbeat_task,
148            read_task,
149            write_task,
150            backoff,
151            is_stream_mode: true,
152            reconnect_max_attempts: config.reconnect_max_attempts,
153            reconnection_attempt_count: 0,
154        })
155    }
156
157    /// Create an inner websocket client.
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if:
162    /// - The connection to the server fails.
163    /// - The exponential backoff configuration is invalid.
164    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
165        install_cryptographic_provider();
166
167        let WebSocketConfig {
168            url,
169            message_handler,
170            heartbeat,
171            headers,
172            heartbeat_msg,
173            ping_handler,
174            reconnect_timeout_ms,
175            reconnect_delay_initial_ms,
176            reconnect_delay_max_ms,
177            reconnect_backoff_factor,
178            reconnect_jitter_ms,
179            reconnect_max_attempts,
180        } = &config;
181
182        // Capture whether we're in stream mode before moving config
183        let is_stream_mode = message_handler.is_none();
184        let reconnect_max_attempts = *reconnect_max_attempts;
185
186        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
187
188        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
189
190        let read_task = if message_handler.is_some() {
191            Some(Self::spawn_message_handler_task(
192                connection_mode.clone(),
193                reader,
194                message_handler.as_ref(),
195                ping_handler.as_ref(),
196            ))
197        } else {
198            None
199        };
200
201        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
202        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
203
204        // Optionally spawn a heartbeat task to periodically ping server
205        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
206            Self::spawn_heartbeat_task(
207                connection_mode.clone(),
208                *heartbeat_secs,
209                heartbeat_msg.clone(),
210                writer_tx.clone(),
211            )
212        });
213
214        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
215        let backoff = ExponentialBackoff::new(
216            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
217            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
218            reconnect_backoff_factor.unwrap_or(1.5),
219            reconnect_jitter_ms.unwrap_or(100),
220            true, // immediate-first
221        )
222        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224        Ok(Self {
225            config,
226            read_task,
227            write_task,
228            writer_tx,
229            heartbeat_task,
230            connection_mode,
231            reconnect_timeout,
232            backoff,
233            // Set stream mode when no message handler (reader not managed by client)
234            is_stream_mode,
235            reconnect_max_attempts,
236            reconnection_attempt_count: 0,
237        })
238    }
239
240    /// Connects with the server creating a tokio-tungstenite websocket stream.
241    /// Production version that uses `connect_async_with_config` convenience helper.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if:
246    /// - The URL cannot be parsed into a valid client request.
247    /// - Header values are invalid.
248    /// - The WebSocket connection fails.
249    #[inline]
250    #[cfg(not(feature = "turmoil"))]
251    pub async fn connect_with_server(
252        url: &str,
253        headers: Vec<(String, String)>,
254    ) -> Result<(MessageWriter, MessageReader), Error> {
255        let mut request = url.into_client_request()?;
256        let req_headers = request.headers_mut();
257
258        let mut header_names: Vec<HeaderName> = Vec::new();
259        for (key, val) in headers {
260            let header_value = HeaderValue::from_str(&val)?;
261            let header_name: HeaderName = key.parse()?;
262            header_names.push(header_name.clone());
263            req_headers.insert(header_name, header_value);
264        }
265
266        connect_async_with_config(request, None, true)
267            .await
268            .map(|resp| resp.0.split())
269    }
270
271    /// Connects with the server creating a tokio-tungstenite websocket stream.
272    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if:
277    /// - The URL cannot be parsed into a valid client request.
278    /// - The URL is missing a hostname.
279    /// - Header values are invalid.
280    /// - The TCP connection fails.
281    /// - TLS setup fails (for wss:// URLs).
282    /// - The WebSocket handshake fails.
283    #[inline]
284    #[cfg(feature = "turmoil")]
285    pub async fn connect_with_server(
286        url: &str,
287        headers: Vec<(String, String)>,
288    ) -> Result<(MessageWriter, MessageReader), Error> {
289        use rustls::ClientConfig;
290        use tokio_rustls::TlsConnector;
291
292        let mut request = url.into_client_request()?;
293        let req_headers = request.headers_mut();
294
295        let mut header_names: Vec<HeaderName> = Vec::new();
296        for (key, val) in headers {
297            let header_value = HeaderValue::from_str(&val)?;
298            let header_name: HeaderName = key.parse()?;
299            header_names.push(header_name.clone());
300            req_headers.insert(header_name, header_value);
301        }
302
303        let uri = request.uri();
304        let scheme = uri.scheme_str().unwrap_or("ws");
305        let host = uri.host().ok_or_else(|| {
306            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
307        })?;
308
309        // Determine port: use explicit port if specified, otherwise default based on scheme
310        let port = uri
311            .port_u16()
312            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
313
314        let addr = format!("{host}:{port}");
315
316        // Use the connector to get a turmoil-compatible stream
317        let connector = crate::net::RealTcpConnector;
318        let tcp_stream = connector.connect(&addr).await?;
319        if let Err(e) = tcp_stream.set_nodelay(true) {
320            tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
321        }
322
323        // Wrap stream appropriately based on scheme
324        let maybe_tls_stream = if scheme == "wss" {
325            // Build TLS config with webpki roots
326            let mut root_store = rustls::RootCertStore::empty();
327            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
328
329            let config = ClientConfig::builder()
330                .with_root_certificates(root_store)
331                .with_no_client_auth();
332
333            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
334            let domain =
335                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
336                    Error::Io(std::io::Error::new(
337                        std::io::ErrorKind::InvalidInput,
338                        format!("Invalid DNS name: {e}"),
339                    ))
340                })?;
341
342            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
343            MaybeTlsStream::Rustls(tls_stream)
344        } else {
345            MaybeTlsStream::Plain(tcp_stream)
346        };
347
348        // Use client_async with the stream (plain or TLS)
349        client_async(request, maybe_tls_stream)
350            .await
351            .map(|resp| resp.0.split())
352    }
353
354    /// Reconnect with server.
355    ///
356    /// Make a new connection with server. Use the new read and write halves
357    /// to update self writer and read and heartbeat tasks.
358    ///
359    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
360    /// because the reader is owned by the caller and cannot be replaced. Stream users
361    /// should handle disconnections by creating a new connection.
362    ///
363    /// # Errors
364    ///
365    /// Returns an error if:
366    /// - The reconnection attempt times out.
367    /// - The connection to the server fails.
368    pub async fn reconnect(&mut self) -> Result<(), Error> {
369        tracing::debug!("Reconnecting");
370
371        if self.is_stream_mode {
372            tracing::warn!(
373                "Auto-reconnect disabled for stream-based WebSocket client; \
374                stream users must manually reconnect by creating a new connection"
375            );
376            // Transition to CLOSED state to stop reconnection attempts
377            self.connection_mode
378                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
379            return Ok(());
380        }
381
382        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
383            tracing::debug!("Reconnect aborted due to disconnect state");
384            return Ok(());
385        }
386
387        tokio::time::timeout(self.reconnect_timeout, async {
388            // Attempt to connect; abort early if a disconnect was requested
389            let (new_writer, reader) =
390                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
391
392            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
393                tracing::debug!("Reconnect aborted mid-flight (after connect)");
394                return Ok(());
395            }
396
397            // Use a oneshot channel to synchronize with the writer task.
398            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
399            // to prevent silent message loss if the new connection drops immediately.
400            let (tx, rx) = tokio::sync::oneshot::channel();
401            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
402                tracing::error!("{e}");
403                return Err(Error::Io(std::io::Error::new(
404                    std::io::ErrorKind::BrokenPipe,
405                    format!("Failed to send update command: {e}"),
406                )));
407            }
408
409            // Wait for writer to confirm it has drained the buffer
410            match rx.await {
411                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
412                Ok(false) => {
413                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
414                    // Return error to trigger retry logic in controller
415                    return Err(Error::Io(std::io::Error::other(
416                        "Failed to drain reconnection buffer",
417                    )));
418                }
419                Err(e) => {
420                    tracing::error!("Writer dropped update channel: {e}");
421                    return Err(Error::Io(std::io::Error::new(
422                        std::io::ErrorKind::BrokenPipe,
423                        "Writer task dropped response channel",
424                    )));
425                }
426            }
427
428            // Delay before closing connection
429            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
430
431            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
432                tracing::debug!("Reconnect aborted mid-flight (after delay)");
433                return Ok(());
434            }
435
436            if let Some(ref read_task) = self.read_task.take()
437                && !read_task.is_finished()
438            {
439                read_task.abort();
440                log_task_aborted("read");
441            }
442
443            // Atomically transition from Reconnect to Active
444            // This prevents race condition where disconnect could be requested between check and store
445            if self
446                .connection_mode
447                .compare_exchange(
448                    ConnectionMode::Reconnect.as_u8(),
449                    ConnectionMode::Active.as_u8(),
450                    Ordering::SeqCst,
451                    Ordering::SeqCst,
452                )
453                .is_err()
454            {
455                tracing::debug!("Reconnect aborted (state changed during reconnect)");
456                return Ok(());
457            }
458
459            self.read_task = if self.config.message_handler.is_some() {
460                Some(Self::spawn_message_handler_task(
461                    self.connection_mode.clone(),
462                    reader,
463                    self.config.message_handler.as_ref(),
464                    self.config.ping_handler.as_ref(),
465                ))
466            } else {
467                None
468            };
469
470            tracing::debug!("Reconnect succeeded");
471            Ok(())
472        })
473        .await
474        .map_err(|_| {
475            Error::Io(std::io::Error::new(
476                std::io::ErrorKind::TimedOut,
477                format!(
478                    "reconnection timed out after {}s",
479                    self.reconnect_timeout.as_secs_f64()
480                ),
481            ))
482        })?
483    }
484
485    /// Check if the client is still connected.
486    ///
487    /// The client is connected if the read task has not finished. It is expected
488    /// that in case of any failure client or server side. The read task will be
489    /// shutdown or will receive a `Close` frame which will finish it. There
490    /// might be some delay between the connection being closed and the client
491    /// detecting.
492    #[inline]
493    #[must_use]
494    pub fn is_alive(&self) -> bool {
495        match &self.read_task {
496            Some(read_task) => !read_task.is_finished(),
497            None => true, // Stream is being used directly
498        }
499    }
500
501    fn spawn_message_handler_task(
502        connection_state: Arc<AtomicU8>,
503        mut reader: MessageReader,
504        message_handler: Option<&MessageHandler>,
505        ping_handler: Option<&PingHandler>,
506    ) -> tokio::task::JoinHandle<()> {
507        tracing::debug!("Started message handler task 'read'");
508
509        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
510
511        // Clone Arc handlers for the async task
512        let message_handler = message_handler.cloned();
513        let ping_handler = ping_handler.cloned();
514
515        tokio::task::spawn(async move {
516            loop {
517                if !ConnectionMode::from_atomic(&connection_state).is_active() {
518                    break;
519                }
520
521                match tokio::time::timeout(check_interval, reader.next()).await {
522                    Ok(Some(Ok(Message::Binary(data)))) => {
523                        tracing::trace!("Received message <binary> {} bytes", data.len());
524                        if let Some(ref handler) = message_handler {
525                            handler(Message::Binary(data));
526                        }
527                    }
528                    Ok(Some(Ok(Message::Text(data)))) => {
529                        tracing::trace!("Received message: {data}");
530                        if let Some(ref handler) = message_handler {
531                            handler(Message::Text(data));
532                        }
533                    }
534                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
535                        tracing::trace!("Received ping: {ping_data:?}");
536                        if let Some(ref handler) = ping_handler {
537                            handler(ping_data.to_vec());
538                        }
539                    }
540                    Ok(Some(Ok(Message::Pong(_)))) => {
541                        tracing::trace!("Received pong");
542                    }
543                    Ok(Some(Ok(Message::Close(_)))) => {
544                        tracing::debug!("Received close message - terminating");
545                        break;
546                    }
547                    Ok(Some(Ok(_))) => (),
548                    Ok(Some(Err(e))) => {
549                        tracing::error!("Received error message - terminating: {e}");
550                        break;
551                    }
552                    Ok(None) => {
553                        tracing::debug!("No message received - terminating");
554                        break;
555                    }
556                    Err(_) => {
557                        // Timeout - continue loop and check connection mode
558                        continue;
559                    }
560                }
561            }
562        })
563    }
564
565    /// Attempts to send all buffered messages after reconnection.
566    ///
567    /// Returns `true` if a send error occurred (caller should trigger reconnection).
568    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
569    async fn drain_reconnect_buffer(
570        buffer: &mut VecDeque<Message>,
571        writer: &mut MessageWriter,
572    ) -> bool {
573        if buffer.is_empty() {
574            return false;
575        }
576
577        let initial_buffer_len = buffer.len();
578        tracing::info!(
579            "Sending {} buffered messages after reconnection",
580            initial_buffer_len
581        );
582
583        let mut send_error_occurred = false;
584
585        while let Some(buffered_msg) = buffer.front() {
586            // Clone message before attempting send (to keep in buffer if send fails)
587            let msg_to_send = buffered_msg.clone();
588
589            if let Err(e) = writer.send(msg_to_send).await {
590                tracing::error!(
591                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
592                    buffer.len()
593                );
594                send_error_occurred = true;
595                break; // Stop processing buffer, remaining messages preserved for next reconnection
596            }
597
598            // Only remove from buffer after successful send
599            buffer.pop_front();
600        }
601
602        if buffer.is_empty() {
603            tracing::info!(
604                "Successfully sent all {} buffered messages",
605                initial_buffer_len
606            );
607        }
608
609        send_error_occurred
610    }
611
612    fn spawn_write_task(
613        connection_state: Arc<AtomicU8>,
614        writer: MessageWriter,
615        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
616    ) -> tokio::task::JoinHandle<()> {
617        log_task_started("write");
618
619        // Interval between checking the connection mode
620        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
621
622        tokio::task::spawn(async move {
623            let mut active_writer = writer;
624            // Buffer for messages received during reconnection
625            // VecDeque for efficient pop_front() operations
626            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
627
628            loop {
629                match ConnectionMode::from_atomic(&connection_state) {
630                    ConnectionMode::Disconnect => {
631                        // Log any buffered messages that will be lost
632                        if !reconnect_buffer.is_empty() {
633                            tracing::warn!(
634                                "Discarding {} buffered messages due to disconnect",
635                                reconnect_buffer.len()
636                            );
637                            reconnect_buffer.clear();
638                        }
639
640                        // Attempt to close the writer gracefully before exiting,
641                        // we ignore any error as the writer may already be closed.
642                        _ = tokio::time::timeout(
643                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
644                            active_writer.close(),
645                        )
646                        .await;
647                        break;
648                    }
649                    ConnectionMode::Closed => {
650                        // Log any buffered messages that will be lost
651                        if !reconnect_buffer.is_empty() {
652                            tracing::warn!(
653                                "Discarding {} buffered messages due to closed connection",
654                                reconnect_buffer.len()
655                            );
656                            reconnect_buffer.clear();
657                        }
658                        break;
659                    }
660                    _ => {}
661                }
662
663                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
664                    Ok(Some(msg)) => {
665                        // Re-check connection mode after receiving a message
666                        let mode = ConnectionMode::from_atomic(&connection_state);
667                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
668                            break;
669                        }
670
671                        match msg {
672                            WriterCommand::Update(new_writer, tx) => {
673                                tracing::debug!("Received new writer");
674
675                                // Delay before closing connection
676                                tokio::time::sleep(Duration::from_millis(100)).await;
677
678                                // Attempt to close the writer gracefully on update,
679                                // we ignore any error as the writer may already be closed.
680                                _ = tokio::time::timeout(
681                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
682                                    active_writer.close(),
683                                )
684                                .await;
685
686                                active_writer = new_writer;
687                                tracing::debug!("Updated writer");
688
689                                let send_error = Self::drain_reconnect_buffer(
690                                    &mut reconnect_buffer,
691                                    &mut active_writer,
692                                )
693                                .await;
694
695                                if let Err(e) = tx.send(!send_error) {
696                                    tracing::error!(
697                                        "Failed to report drain status to controller: {e:?}"
698                                    );
699                                }
700                            }
701                            WriterCommand::Send(msg) if mode.is_reconnect() => {
702                                // Buffer messages during reconnection instead of dropping them
703                                tracing::debug!(
704                                    "Buffering message during reconnection (buffer size: {})",
705                                    reconnect_buffer.len() + 1
706                                );
707                                reconnect_buffer.push_back(msg);
708                            }
709                            WriterCommand::Send(msg) => {
710                                if let Err(e) = active_writer.send(msg.clone()).await {
711                                    tracing::error!("Failed to send message: {e}");
712                                    tracing::warn!("Writer triggering reconnect");
713                                    reconnect_buffer.push_back(msg);
714                                    connection_state
715                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
716                                }
717                            }
718                        }
719                    }
720                    Ok(None) => {
721                        // Channel closed - writer task should terminate
722                        tracing::debug!("Writer channel closed, terminating writer task");
723                        break;
724                    }
725                    Err(_) => {
726                        // Timeout - just continue the loop
727                        continue;
728                    }
729                }
730            }
731
732            // Attempt to close the writer gracefully before exiting,
733            // we ignore any error as the writer may already be closed.
734            _ = tokio::time::timeout(
735                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
736                active_writer.close(),
737            )
738            .await;
739
740            log_task_stopped("write");
741        })
742    }
743
744    fn spawn_heartbeat_task(
745        connection_state: Arc<AtomicU8>,
746        heartbeat_secs: u64,
747        message: Option<String>,
748        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
749    ) -> tokio::task::JoinHandle<()> {
750        log_task_started("heartbeat");
751
752        tokio::task::spawn(async move {
753            let interval = Duration::from_secs(heartbeat_secs);
754
755            loop {
756                tokio::time::sleep(interval).await;
757
758                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
759                    ConnectionMode::Active => {
760                        let msg = match &message {
761                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
762                            None => WriterCommand::Send(Message::Ping(vec![].into())),
763                        };
764
765                        match writer_tx.send(msg) {
766                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
767                            Err(e) => {
768                                tracing::error!("Failed to send heartbeat to writer task: {e}");
769                            }
770                        }
771                    }
772                    ConnectionMode::Reconnect => continue,
773                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
774                }
775            }
776
777            log_task_stopped("heartbeat");
778        })
779    }
780}
781
782impl Drop for WebSocketClientInner {
783    fn drop(&mut self) {
784        // Delegate to explicit cleanup handler
785        self.clean_drop();
786    }
787}
788
789/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
790impl CleanDrop for WebSocketClientInner {
791    fn clean_drop(&mut self) {
792        if let Some(ref read_task) = self.read_task.take()
793            && !read_task.is_finished()
794        {
795            read_task.abort();
796            log_task_aborted("read");
797        }
798
799        if !self.write_task.is_finished() {
800            self.write_task.abort();
801            log_task_aborted("write");
802        }
803
804        if let Some(ref handle) = self.heartbeat_task.take()
805            && !handle.is_finished()
806        {
807            handle.abort();
808            log_task_aborted("heartbeat");
809        }
810
811        // Clear handlers to break potential reference cycles
812        self.config.message_handler = None;
813        self.config.ping_handler = None;
814    }
815}
816
817impl Debug for WebSocketClientInner {
818    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819        f.debug_struct("WebSocketClientInner")
820            .field("config", &self.config)
821            .field(
822                "connection_mode",
823                &ConnectionMode::from_atomic(&self.connection_mode),
824            )
825            .field("reconnect_timeout", &self.reconnect_timeout)
826            .field("is_stream_mode", &self.is_stream_mode)
827            .finish()
828    }
829}
830
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle to the controller task spawned at connect time. The task takes ownership of
    /// the inner client, drives reconnection, and shuts down once DISCONNECT is observed.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode, shared through this `Arc` with the inner client.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text` and `send_bytes` before a message is queued.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Sender half of the writer task's command channel (cloned from the inner client).
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
846
847impl Debug for WebSocketClient {
848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849        f.debug_struct(stringify!(WebSocketClient)).finish()
850    }
851}
852
853impl WebSocketClient {
854    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
855    ///
856    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
857    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
858    /// client transitions to CLOSED state and the caller must manually reconnect by calling
859    /// `connect_stream` again.
860    ///
861    /// Use stream mode when you need custom reconnection logic, direct control over message
862    /// reading, or fine-grained backpressure handling.
863    ///
864    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
865    ///
866    /// # Errors
867    ///
868    /// Returns an error if the connection cannot be established.
869    #[allow(clippy::too_many_arguments)]
870    pub async fn connect_stream(
871        config: WebSocketConfig,
872        keyed_quotas: Vec<(String, Quota)>,
873        default_quota: Option<Quota>,
874        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
875    ) -> Result<(MessageReader, Self), Error> {
876        install_cryptographic_provider();
877
878        // Create a single connection and split it, respecting configured headers
879        let (writer, reader) =
880            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
881
882        // Create inner without connecting (we'll provide the writer)
883        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
884
885        let connection_mode = inner.connection_mode.clone();
886        let reconnect_timeout = inner.reconnect_timeout;
887        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
888        let writer_tx = inner.writer_tx.clone();
889
890        let controller_task =
891            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
892
893        Ok((
894            reader,
895            Self {
896                controller_task,
897                connection_mode,
898                reconnect_timeout,
899                rate_limiter,
900                writer_tx,
901            },
902        ))
903    }
904
905    /// Creates a websocket client in **handler mode** with automatic reconnection.
906    ///
907    /// Requires `config.message_handler` to be set. The handler is called for each incoming
908    /// message on an internal task. Automatic reconnection is **enabled** with exponential
909    /// backoff. On disconnection, the client automatically attempts to reconnect and replaces
910    /// the internal reader (the handler continues working seamlessly).
911    ///
912    /// Use handler mode for simplified connection management, automatic reconnection, Python
913    /// bindings, or callback-based message handling.
914    ///
915    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
916    ///
917    /// # Errors
918    ///
919    /// Returns an error if the connection cannot be established or if
920    /// `config.message_handler` is `None` (use `connect_stream` instead).
921    pub async fn connect(
922        config: WebSocketConfig,
923        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
924        keyed_quotas: Vec<(String, Quota)>,
925        default_quota: Option<Quota>,
926    ) -> Result<Self, Error> {
927        // Validate that handler mode has a message handler
928        if config.message_handler.is_none() {
929            return Err(Error::Io(std::io::Error::new(
930                std::io::ErrorKind::InvalidInput,
931                "Handler mode requires config.message_handler to be set. Use connect_stream() for stream mode without a handler.",
932            )));
933        }
934
935        tracing::debug!("Connecting");
936        let inner = WebSocketClientInner::connect_url(config).await?;
937        let connection_mode = inner.connection_mode.clone();
938        let writer_tx = inner.writer_tx.clone();
939        let reconnect_timeout = inner.reconnect_timeout;
940
941        let controller_task =
942            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
943
944        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
945
946        Ok(Self {
947            controller_task,
948            connection_mode,
949            reconnect_timeout,
950            rate_limiter,
951            writer_tx,
952        })
953    }
954
955    /// Returns the current connection mode.
956    #[must_use]
957    pub fn connection_mode(&self) -> ConnectionMode {
958        ConnectionMode::from_atomic(&self.connection_mode)
959    }
960
961    /// Returns a clone of the connection mode atomic for external state tracking.
962    ///
963    /// This allows adapter clients to track connection state across reconnections
964    /// without message-passing delays.
965    #[must_use]
966    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
967        Arc::clone(&self.connection_mode)
968    }
969
970    /// Check if the client connection is active.
971    ///
972    /// Returns `true` if the client is connected and has not been signalled to disconnect.
973    /// The client will automatically retry connection based on its configuration.
974    #[inline]
975    #[must_use]
976    pub fn is_active(&self) -> bool {
977        self.connection_mode().is_active()
978    }
979
980    /// Check if the client is disconnected.
981    #[must_use]
982    pub fn is_disconnected(&self) -> bool {
983        self.controller_task.is_finished()
984    }
985
986    /// Check if the client is reconnecting.
987    ///
988    /// Returns `true` if the client lost connection and is attempting to reestablish it.
989    /// The client will automatically retry connection based on its configuration.
990    #[inline]
991    #[must_use]
992    pub fn is_reconnecting(&self) -> bool {
993        self.connection_mode().is_reconnect()
994    }
995
996    /// Check if the client is disconnecting.
997    ///
998    /// Returns `true` if the client is in disconnect mode.
999    #[inline]
1000    #[must_use]
1001    pub fn is_disconnecting(&self) -> bool {
1002        self.connection_mode().is_disconnect()
1003    }
1004
1005    /// Check if the client is closed.
1006    ///
1007    /// Returns `true` if the client has been explicitly disconnected or reached
1008    /// maximum reconnection attempts. In this state, the client cannot be reused
1009    /// and a new client must be created for further connections.
1010    #[inline]
1011    #[must_use]
1012    pub fn is_closed(&self) -> bool {
1013        self.connection_mode().is_closed()
1014    }
1015
1016    /// Wait for the client to become active before sending.
1017    ///
1018    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1019    async fn wait_for_active(&self) -> Result<(), SendError> {
1020        if self.is_closed() {
1021            return Err(SendError::Closed);
1022        }
1023
1024        let timeout = self.reconnect_timeout;
1025        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1026
1027        if !self.is_active() {
1028            tracing::debug!("Waiting for client to become ACTIVE before sending...");
1029
1030            let inner = tokio::time::timeout(timeout, async {
1031                loop {
1032                    if self.is_active() {
1033                        return Ok(());
1034                    }
1035                    if matches!(
1036                        self.connection_mode(),
1037                        ConnectionMode::Disconnect | ConnectionMode::Closed
1038                    ) {
1039                        return Err(());
1040                    }
1041                    tokio::time::sleep(check_interval).await;
1042                }
1043            })
1044            .await
1045            .map_err(|_| SendError::Timeout)?;
1046            inner.map_err(|()| SendError::Closed)?;
1047        }
1048
1049        Ok(())
1050    }
1051
1052    /// Set disconnect mode to true.
1053    ///
1054    /// Controller task will periodically check the disconnect mode
1055    /// and shutdown the client if it is alive
1056    pub async fn disconnect(&self) {
1057        tracing::debug!("Disconnecting");
1058        self.connection_mode
1059            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1060
1061        if let Ok(()) =
1062            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1063                while !self.is_disconnected() {
1064                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1065                        .await;
1066                }
1067
1068                if !self.controller_task.is_finished() {
1069                    self.controller_task.abort();
1070                    log_task_aborted("controller");
1071                }
1072            })
1073            .await
1074        {
1075            tracing::debug!("Controller task finished");
1076        } else {
1077            tracing::error!("Timeout waiting for controller task to finish");
1078            if !self.controller_task.is_finished() {
1079                self.controller_task.abort();
1080                log_task_aborted("controller");
1081            }
1082        }
1083    }
1084
1085    /// Sends the given text `data` to the server.
1086    ///
1087    /// # Errors
1088    ///
1089    /// Returns a websocket error if unable to send.
1090    #[allow(unused_variables)]
1091    pub async fn send_text(
1092        &self,
1093        data: String,
1094        keys: Option<Vec<String>>,
1095    ) -> Result<(), SendError> {
1096        // Check connection state before rate limiting to fail fast
1097        if self.is_closed() || self.is_disconnecting() {
1098            return Err(SendError::Closed);
1099        }
1100
1101        self.rate_limiter.await_keys_ready(keys).await;
1102        self.wait_for_active().await?;
1103
1104        tracing::trace!("Sending text: {data:?}");
1105
1106        let msg = Message::Text(data.into());
1107        self.writer_tx
1108            .send(WriterCommand::Send(msg))
1109            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1110    }
1111
1112    /// Sends a pong frame back to the server.
1113    ///
1114    /// # Errors
1115    ///
1116    /// Returns a websocket error if unable to send.
1117    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1118        self.wait_for_active().await?;
1119
1120        tracing::trace!("Sending pong frame ({} bytes)", data.len());
1121
1122        let msg = Message::Pong(data.into());
1123        self.writer_tx
1124            .send(WriterCommand::Send(msg))
1125            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1126    }
1127
1128    /// Sends the given bytes `data` to the server.
1129    ///
1130    /// # Errors
1131    ///
1132    /// Returns a websocket error if unable to send.
1133    #[allow(unused_variables)]
1134    pub async fn send_bytes(
1135        &self,
1136        data: Vec<u8>,
1137        keys: Option<Vec<String>>,
1138    ) -> Result<(), SendError> {
1139        // Check connection state before rate limiting to fail fast
1140        if self.is_closed() || self.is_disconnecting() {
1141            return Err(SendError::Closed);
1142        }
1143
1144        self.rate_limiter.await_keys_ready(keys).await;
1145        self.wait_for_active().await?;
1146
1147        tracing::trace!("Sending bytes: {data:?}");
1148
1149        let msg = Message::Binary(data.into());
1150        self.writer_tx
1151            .send(WriterCommand::Send(msg))
1152            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1153    }
1154
1155    /// Sends a close message to the server.
1156    ///
1157    /// # Errors
1158    ///
1159    /// Returns a websocket error if unable to send.
1160    pub async fn send_close_message(&self) -> Result<(), SendError> {
1161        self.wait_for_active().await?;
1162
1163        let msg = Message::Close(None);
1164        self.writer_tx
1165            .send(WriterCommand::Send(msg))
1166            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1167    }
1168
1169    fn spawn_controller_task(
1170        mut inner: WebSocketClientInner,
1171        connection_mode: Arc<AtomicU8>,
1172        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1173    ) -> tokio::task::JoinHandle<()> {
1174        tokio::task::spawn(async move {
1175            log_task_started("controller");
1176
1177            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1178
1179            loop {
1180                tokio::time::sleep(check_interval).await;
1181                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1182
1183                if mode.is_disconnect() {
1184                    tracing::debug!("Disconnecting");
1185
1186                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1187                    if tokio::time::timeout(timeout, async {
1188                        // Delay awaiting graceful shutdown
1189                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1190
1191                        if let Some(task) = &inner.read_task
1192                            && !task.is_finished()
1193                        {
1194                            task.abort();
1195                            log_task_aborted("read");
1196                        }
1197
1198                        if let Some(task) = &inner.heartbeat_task
1199                            && !task.is_finished()
1200                        {
1201                            task.abort();
1202                            log_task_aborted("heartbeat");
1203                        }
1204                    })
1205                    .await
1206                    .is_err()
1207                    {
1208                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1209                    }
1210
1211                    tracing::debug!("Closed");
1212                    break; // Controller finished
1213                }
1214
1215                if mode.is_active() && !inner.is_alive() {
1216                    if connection_mode
1217                        .compare_exchange(
1218                            ConnectionMode::Active.as_u8(),
1219                            ConnectionMode::Reconnect.as_u8(),
1220                            Ordering::SeqCst,
1221                            Ordering::SeqCst,
1222                        )
1223                        .is_ok()
1224                    {
1225                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1226                    }
1227                    mode = ConnectionMode::from_atomic(&connection_mode);
1228                }
1229
1230                if mode.is_reconnect() {
1231                    // Check if max reconnection attempts exceeded
1232                    if let Some(max_attempts) = inner.reconnect_max_attempts
1233                        && inner.reconnection_attempt_count >= max_attempts
1234                    {
1235                        tracing::error!(
1236                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1237                            max_attempts
1238                        );
1239                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1240                        break;
1241                    }
1242
1243                    inner.reconnection_attempt_count += 1;
1244                    tracing::debug!(
1245                        "Reconnection attempt {} of {}",
1246                        inner.reconnection_attempt_count,
1247                        inner
1248                            .reconnect_max_attempts
1249                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1250                    );
1251
1252                    match inner.reconnect().await {
1253                        Ok(()) => {
1254                            inner.backoff.reset();
1255                            inner.reconnection_attempt_count = 0; // Reset counter on success
1256
1257                            // Only invoke callbacks if not in disconnect state
1258                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1259                                if let Some(ref handler) = inner.config.message_handler {
1260                                    let reconnected_msg =
1261                                        Message::Text(RECONNECTED.to_string().into());
1262                                    handler(reconnected_msg);
1263                                    tracing::debug!("Sent reconnected message to handler");
1264                                }
1265
1266                                // TODO: Retain this legacy callback for use from Python
1267                                if let Some(ref callback) = post_reconnection {
1268                                    callback();
1269                                    tracing::debug!("Called `post_reconnection` handler");
1270                                }
1271
1272                                tracing::debug!("Reconnected successfully");
1273                            } else {
1274                                tracing::debug!(
1275                                    "Skipping post_reconnection handlers due to disconnect state"
1276                                );
1277                            }
1278                        }
1279                        Err(e) => {
1280                            let duration = inner.backoff.next_duration();
1281                            tracing::warn!(
1282                                "Reconnect attempt {} failed: {e}",
1283                                inner.reconnection_attempt_count
1284                            );
1285                            if !duration.is_zero() {
1286                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1287                            }
1288                            tokio::time::sleep(duration).await;
1289                        }
1290                    }
1291                }
1292            }
1293            inner
1294                .connection_mode
1295                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1296
1297            log_task_stopped("controller");
1298        })
1299    }
1300}
1301
1302// Abort controller task on drop to clean up background tasks
1303impl Drop for WebSocketClient {
1304    fn drop(&mut self) {
1305        if !self.controller_task.is_finished() {
1306            self.controller_task.abort();
1307            log_task_aborted("controller");
1308        }
1309    }
1310}
1311
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server; the accept task is aborted when the server is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting the request carries header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Test helper: a missing or wrong header should fail the handshake loudly
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            assert_eq!(
                request.headers().get(&self.key),
                Some(&self.value),
                "handshake is missing expected header `{}`",
                self.key
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and echoes text/binary frames on every
        /// accepted connection. The text `close-now` makes the server close that connection.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Handler-mode config for a localhost server on `port`, sending the header the
    /// test server expects. Heartbeats are off and reconnect settings are left unset.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            message_handler: Some(Arc::new(|_| {})),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        }
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        WebSocketClient::connect(test_config(port), None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;

        // Enable a 1s heartbeat; with no heartbeat_msg the client sends ping frames
        let mut config = test_config(server.port);
        config.heartbeat = Some(1);
        let client = WebSocketClient::connect(config, None, vec![], None)
            .await
            .expect("Failed to connect");

        // Wait ~3s => client sends multiple heartbeat pings
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Heartbeats must not have knocked the connection over
        assert!(client.is_active());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port, then release it so nothing is listening there
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let mut config = test_config(port);
        config.headers = vec![];

        let res = WebSocketClient::connect(config, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // The quota is keyed, so sends must name the key for the limiter to apply it
        let keys = Some(vec!["default".to_string()]);

        // First 2 fit within the quota
        client.send_text("test1".into(), keys.clone()).await.unwrap();
        client.send_text("test2".into(), keys.clone()).await.unwrap();

        // Third exceeds the quota: the limiter makes it wait rather than fail
        client.send_text("test3".into(), keys).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1574
1575#[cfg(test)]
1576#[cfg(not(feature = "turmoil"))]
1577mod rust_tests {
1578    use futures_util::StreamExt;
1579    use rstest::rstest;
1580    use tokio::{
1581        net::TcpListener,
1582        task,
1583        time::{Duration, sleep},
1584    };
1585    use tokio_tungstenite::accept_async;
1586
1587    use super::*;
1588    use crate::websocket::types::channel_message_handler;
1589
1590    #[rstest]
1591    #[tokio::test]
1592    async fn test_reconnect_then_disconnect() {
1593        // Bind an ephemeral port
1594        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1595        let port = listener.local_addr().unwrap().port();
1596
1597        // Server task: accept one ws connection then close it
1598        let server = task::spawn(async move {
1599            let (stream, _) = listener.accept().await.unwrap();
1600            let ws = accept_async(stream).await.unwrap();
1601            drop(ws);
1602            // Keep alive briefly
1603            sleep(Duration::from_secs(1)).await;
1604        });
1605
1606        // Build a channel-based message handler for incoming messages (unused here)
1607        let (handler, _rx) = channel_message_handler();
1608
1609        // Configure client with short reconnect backoff
1610        let config = WebSocketConfig {
1611            url: format!("ws://127.0.0.1:{port}"),
1612            headers: vec![],
1613            message_handler: Some(handler),
1614            heartbeat: None,
1615            heartbeat_msg: None,
1616            ping_handler: None,
1617            reconnect_timeout_ms: Some(1_000),
1618            reconnect_delay_initial_ms: Some(50),
1619            reconnect_delay_max_ms: Some(100),
1620            reconnect_backoff_factor: Some(1.0),
1621            reconnect_jitter_ms: Some(0),
1622            reconnect_max_attempts: None,
1623        };
1624
1625        // Connect the client
1626        let client = WebSocketClient::connect(config, None, vec![], None)
1627            .await
1628            .unwrap();
1629
1630        // Allow server to drop connection and client to detect
1631        sleep(Duration::from_millis(100)).await;
1632        // Now immediately disconnect the client
1633        client.disconnect().await;
1634        assert!(client.is_disconnected());
1635        server.abort();
1636    }
1637
1638    #[rstest]
1639    #[tokio::test]
1640    async fn test_reconnect_state_flips_when_reader_stops() {
1641        // Bind an ephemeral port and accept a single websocket connection which we drop.
1642        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1643        let port = listener.local_addr().unwrap().port();
1644
1645        let server = task::spawn(async move {
1646            if let Ok((stream, _)) = listener.accept().await
1647                && let Ok(ws) = accept_async(stream).await
1648            {
1649                drop(ws);
1650            }
1651            sleep(Duration::from_millis(50)).await;
1652        });
1653
1654        let (handler, _rx) = channel_message_handler();
1655
1656        let config = WebSocketConfig {
1657            url: format!("ws://127.0.0.1:{port}"),
1658            headers: vec![],
1659            message_handler: Some(handler),
1660            heartbeat: None,
1661            heartbeat_msg: None,
1662            ping_handler: None,
1663            reconnect_timeout_ms: Some(1_000),
1664            reconnect_delay_initial_ms: Some(50),
1665            reconnect_delay_max_ms: Some(100),
1666            reconnect_backoff_factor: Some(1.0),
1667            reconnect_jitter_ms: Some(0),
1668            reconnect_max_attempts: None,
1669        };
1670
1671        let client = WebSocketClient::connect(config, None, vec![], None)
1672            .await
1673            .unwrap();
1674
1675        tokio::time::timeout(Duration::from_secs(2), async {
1676            loop {
1677                if client.is_reconnecting() {
1678                    break;
1679                }
1680                tokio::time::sleep(Duration::from_millis(10)).await;
1681            }
1682        })
1683        .await
1684        .expect("client did not enter RECONNECT state");
1685
1686        client.disconnect().await;
1687        server.abort();
1688    }
1689
1690    #[rstest]
1691    #[tokio::test]
1692    async fn test_stream_mode_disables_auto_reconnect() {
1693        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1694        // and that reconnect() transitions to CLOSED state for stream mode
1695        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1696        let port = listener.local_addr().unwrap().port();
1697
1698        let server = task::spawn(async move {
1699            if let Ok((stream, _)) = listener.accept().await
1700                && let Ok(_ws) = accept_async(stream).await
1701            {
1702                // Keep connection alive briefly
1703                sleep(Duration::from_millis(100)).await;
1704            }
1705        });
1706
1707        let config = WebSocketConfig {
1708            url: format!("ws://127.0.0.1:{port}"),
1709            headers: vec![],
1710            message_handler: None, // Stream mode - no handler
1711            heartbeat: None,
1712            heartbeat_msg: None,
1713            ping_handler: None,
1714            reconnect_timeout_ms: Some(1_000),
1715            reconnect_delay_initial_ms: Some(50),
1716            reconnect_delay_max_ms: Some(100),
1717            reconnect_backoff_factor: Some(1.0),
1718            reconnect_jitter_ms: Some(0),
1719            reconnect_max_attempts: None,
1720        };
1721
1722        // Create stream-based client
1723        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1724            .await
1725            .unwrap();
1726
1727        // Note: We can't easily test the reconnect behavior from the outside since
1728        // the inner client is private. The key fix is that WebSocketClientInner
1729        // now has is_stream_mode=true for connect_stream, and reconnect() will
1730        // transition to CLOSED state instead of creating a new reader that gets dropped.
1731        // This is tested implicitly by the fact that stream users won't get stuck
1732        // in an infinite reconnect loop.
1733
1734        server.abort();
1735    }
1736
1737    #[rstest]
1738    #[tokio::test]
1739    async fn test_message_handler_mode_allows_auto_reconnect() {
1740        // Test that regular clients (with message handler) can auto-reconnect
1741        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1742        let port = listener.local_addr().unwrap().port();
1743
1744        let server = task::spawn(async move {
1745            // Accept first connection and close it
1746            if let Ok((stream, _)) = listener.accept().await
1747                && let Ok(ws) = accept_async(stream).await
1748            {
1749                drop(ws);
1750            }
1751            sleep(Duration::from_millis(50)).await;
1752        });
1753
1754        let (handler, _rx) = channel_message_handler();
1755
1756        let config = WebSocketConfig {
1757            url: format!("ws://127.0.0.1:{port}"),
1758            headers: vec![],
1759            message_handler: Some(handler), // Has message handler
1760            heartbeat: None,
1761            heartbeat_msg: None,
1762            ping_handler: None,
1763            reconnect_timeout_ms: Some(1_000),
1764            reconnect_delay_initial_ms: Some(50),
1765            reconnect_delay_max_ms: Some(100),
1766            reconnect_backoff_factor: Some(1.0),
1767            reconnect_jitter_ms: Some(0),
1768            reconnect_max_attempts: None,
1769        };
1770
1771        let client = WebSocketClient::connect(config, None, vec![], None)
1772            .await
1773            .unwrap();
1774
1775        // Wait for the connection to be dropped and reconnection to be attempted
1776        tokio::time::timeout(Duration::from_secs(2), async {
1777            loop {
1778                if client.is_reconnecting() || client.is_closed() {
1779                    break;
1780                }
1781                tokio::time::sleep(Duration::from_millis(10)).await;
1782            }
1783        })
1784        .await
1785        .expect("client should attempt reconnection or close");
1786
1787        // Should either be reconnecting or closed (depending on timing)
1788        // The important thing is it's not staying active forever
1789        assert!(
1790            client.is_reconnecting() || client.is_closed(),
1791            "Client with message handler should attempt reconnection"
1792        );
1793
1794        client.disconnect().await;
1795        server.abort();
1796    }
1797
1798    #[rstest]
1799    #[tokio::test]
1800    async fn test_handler_mode_reconnect_with_new_connection() {
1801        // Test that handler mode successfully reconnects and messages continue flowing
1802        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1803        let port = listener.local_addr().unwrap().port();
1804
1805        let server = task::spawn(async move {
1806            // First connection - accept and immediately close
1807            if let Ok((stream, _)) = listener.accept().await
1808                && let Ok(ws) = accept_async(stream).await
1809            {
1810                drop(ws);
1811            }
1812
1813            // Small delay to let client detect disconnection
1814            sleep(Duration::from_millis(100)).await;
1815
1816            // Second connection - accept, send a message, then keep alive
1817            if let Ok((stream, _)) = listener.accept().await
1818                && let Ok(mut ws) = accept_async(stream).await
1819            {
1820                use futures_util::SinkExt;
1821                let _ = ws
1822                    .send(Message::Text("reconnected".to_string().into()))
1823                    .await;
1824                sleep(Duration::from_secs(1)).await;
1825            }
1826        });
1827
1828        let (handler, mut rx) = channel_message_handler();
1829
1830        let config = WebSocketConfig {
1831            url: format!("ws://127.0.0.1:{port}"),
1832            headers: vec![],
1833            message_handler: Some(handler),
1834            heartbeat: None,
1835            heartbeat_msg: None,
1836            ping_handler: None,
1837            reconnect_timeout_ms: Some(2_000),
1838            reconnect_delay_initial_ms: Some(50),
1839            reconnect_delay_max_ms: Some(200),
1840            reconnect_backoff_factor: Some(1.5),
1841            reconnect_jitter_ms: Some(10),
1842            reconnect_max_attempts: None,
1843        };
1844
1845        let client = WebSocketClient::connect(config, None, vec![], None)
1846            .await
1847            .unwrap();
1848
1849        // Wait for reconnection to happen and message to arrive
1850        let result = tokio::time::timeout(Duration::from_secs(5), async {
1851            loop {
1852                if let Ok(msg) = rx.try_recv()
1853                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1854                {
1855                    return true;
1856                }
1857                tokio::time::sleep(Duration::from_millis(10)).await;
1858            }
1859        })
1860        .await;
1861
1862        assert!(
1863            result.is_ok(),
1864            "Should receive message after reconnection within timeout"
1865        );
1866
1867        client.disconnect().await;
1868        server.abort();
1869    }
1870
1871    #[rstest]
1872    #[tokio::test]
1873    async fn test_stream_mode_no_auto_reconnect() {
1874        // Test that stream mode does not automatically reconnect when connection is lost
1875        // The caller owns the reader and is responsible for detecting disconnection
1876        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1877        let port = listener.local_addr().unwrap().port();
1878
1879        let server = task::spawn(async move {
1880            // Accept connection and send one message, then close
1881            if let Ok((stream, _)) = listener.accept().await
1882                && let Ok(mut ws) = accept_async(stream).await
1883            {
1884                use futures_util::SinkExt;
1885                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1886                sleep(Duration::from_millis(50)).await;
1887                // Connection closes when ws is dropped
1888            }
1889        });
1890
1891        let config = WebSocketConfig {
1892            url: format!("ws://127.0.0.1:{port}"),
1893            headers: vec![],
1894            message_handler: None, // Stream mode
1895            heartbeat: None,
1896            heartbeat_msg: None,
1897            ping_handler: None,
1898            reconnect_timeout_ms: Some(1_000),
1899            reconnect_delay_initial_ms: Some(50),
1900            reconnect_delay_max_ms: Some(100),
1901            reconnect_backoff_factor: Some(1.0),
1902            reconnect_jitter_ms: Some(0),
1903            reconnect_max_attempts: None,
1904        };
1905
1906        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1907            .await
1908            .unwrap();
1909
1910        // Initially active
1911        assert!(client.is_active(), "Client should start as active");
1912
1913        // Read the hello message
1914        let msg = reader.next().await;
1915        assert!(
1916            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1917            "Should receive initial message"
1918        );
1919
1920        // Read until connection closes (reader will return None or error)
1921        while let Some(msg) = reader.next().await {
1922            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1923                break;
1924            }
1925        }
1926
1927        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1928        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1929        sleep(Duration::from_millis(200)).await;
1930
1931        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1932        // This is correct behavior - stream mode doesn't auto-detect disconnection
1933        assert!(
1934            client.is_active() || client.is_closed(),
1935            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1936        );
1937        assert!(
1938            !client.is_reconnecting(),
1939            "Stream mode client should never attempt reconnection"
1940        );
1941
1942        client.disconnect().await;
1943        server.abort();
1944    }
1945
1946    #[rstest]
1947    #[tokio::test]
1948    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1949        // Test that send operations respect the configured reconnect_timeout.
1950        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1951        use nautilus_common::testing::wait_until_async;
1952
1953        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1954        let port = listener.local_addr().unwrap().port();
1955
1956        let server = task::spawn(async move {
1957            // Accept first connection and immediately close it
1958            if let Ok((stream, _)) = listener.accept().await
1959                && let Ok(ws) = accept_async(stream).await
1960            {
1961                drop(ws);
1962            }
1963            // Don't accept second connection - client will be stuck in RECONNECT
1964            sleep(Duration::from_secs(60)).await;
1965        });
1966
1967        let (handler, _rx) = channel_message_handler();
1968
1969        // Configure with SHORT 2s reconnect timeout
1970        let config = WebSocketConfig {
1971            url: format!("ws://127.0.0.1:{port}"),
1972            headers: vec![],
1973            message_handler: Some(handler),
1974            heartbeat: None,
1975            heartbeat_msg: None,
1976            ping_handler: None,
1977            reconnect_timeout_ms: Some(2_000), // 2s timeout
1978            reconnect_delay_initial_ms: Some(50),
1979            reconnect_delay_max_ms: Some(100),
1980            reconnect_backoff_factor: Some(1.0),
1981            reconnect_jitter_ms: Some(0),
1982            reconnect_max_attempts: None,
1983        };
1984
1985        let client = WebSocketClient::connect(config, None, vec![], None)
1986            .await
1987            .unwrap();
1988
1989        // Wait for client to enter RECONNECT state
1990        wait_until_async(
1991            || async { client.is_reconnecting() },
1992            Duration::from_secs(3),
1993        )
1994        .await;
1995
1996        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
1997        let start = std::time::Instant::now();
1998        let send_result = client.send_text("test".to_string(), None).await;
1999        let elapsed = start.elapsed();
2000
2001        assert!(
2002            send_result.is_err(),
2003            "Send should fail when client stuck in RECONNECT"
2004        );
2005        assert!(
2006            matches!(send_result, Err(crate::error::SendError::Timeout)),
2007            "Send should return Timeout error, was: {send_result:?}"
2008        );
2009        // Verify timeout respects configured value (2s), but don't check upper bound
2010        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2011        assert!(
2012            elapsed >= Duration::from_millis(1800),
2013            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2014        );
2015
2016        client.disconnect().await;
2017        server.abort();
2018    }
2019
2020    #[rstest]
2021    #[tokio::test]
2022    async fn test_send_waits_during_reconnection() {
2023        // Test that send operations wait for reconnection to complete (up to timeout)
2024        use nautilus_common::testing::wait_until_async;
2025
2026        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2027        let port = listener.local_addr().unwrap().port();
2028
2029        let server = task::spawn(async move {
2030            // First connection - accept and immediately close
2031            if let Ok((stream, _)) = listener.accept().await
2032                && let Ok(ws) = accept_async(stream).await
2033            {
2034                drop(ws);
2035            }
2036
2037            // Wait a bit before accepting second connection
2038            sleep(Duration::from_millis(500)).await;
2039
2040            // Second connection - accept and keep alive
2041            if let Ok((stream, _)) = listener.accept().await
2042                && let Ok(mut ws) = accept_async(stream).await
2043            {
2044                // Echo messages
2045                while let Some(Ok(msg)) = ws.next().await {
2046                    if ws.send(msg).await.is_err() {
2047                        break;
2048                    }
2049                }
2050            }
2051        });
2052
2053        let (handler, _rx) = channel_message_handler();
2054
2055        let config = WebSocketConfig {
2056            url: format!("ws://127.0.0.1:{port}"),
2057            headers: vec![],
2058            message_handler: Some(handler),
2059            heartbeat: None,
2060            heartbeat_msg: None,
2061            ping_handler: None,
2062            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2063            reconnect_delay_initial_ms: Some(100),
2064            reconnect_delay_max_ms: Some(200),
2065            reconnect_backoff_factor: Some(1.0),
2066            reconnect_jitter_ms: Some(0),
2067            reconnect_max_attempts: None,
2068        };
2069
2070        let client = WebSocketClient::connect(config, None, vec![], None)
2071            .await
2072            .unwrap();
2073
2074        // Wait for reconnection to trigger
2075        wait_until_async(
2076            || async { client.is_reconnecting() },
2077            Duration::from_secs(2),
2078        )
2079        .await;
2080
2081        // Try to send while reconnecting - should wait and succeed after reconnect
2082        let send_result = tokio::time::timeout(
2083            Duration::from_secs(3),
2084            client.send_text("test_message".to_string(), None),
2085        )
2086        .await;
2087
2088        assert!(
2089            send_result.is_ok() && send_result.unwrap().is_ok(),
2090            "Send should succeed after waiting for reconnection"
2091        );
2092
2093        client.disconnect().await;
2094        server.abort();
2095    }
2096
2097    #[rstest]
2098    #[tokio::test]
2099    async fn test_rate_limiter_before_active_wait() {
2100        // Test that rate limiting happens BEFORE active state check.
2101        // This prevents race conditions where connection state changes during rate limit wait.
2102        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2103        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2104        use std::{num::NonZeroU32, sync::Arc};
2105
2106        use nautilus_common::testing::wait_until_async;
2107
2108        use crate::ratelimiter::quota::Quota;
2109
2110        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2111        let port = listener.local_addr().unwrap().port();
2112
2113        let server = task::spawn(async move {
2114            // First connection - accept and close after receiving one message
2115            if let Ok((stream, _)) = listener.accept().await
2116                && let Ok(mut ws) = accept_async(stream).await
2117            {
2118                // Receive first message then close
2119                if let Some(Ok(_)) = ws.next().await {
2120                    drop(ws);
2121                }
2122            }
2123
2124            // Wait before accepting reconnection
2125            sleep(Duration::from_millis(500)).await;
2126
2127            // Second connection - accept and keep alive
2128            if let Ok((stream, _)) = listener.accept().await
2129                && let Ok(mut ws) = accept_async(stream).await
2130            {
2131                while let Some(Ok(msg)) = ws.next().await {
2132                    if ws.send(msg).await.is_err() {
2133                        break;
2134                    }
2135                }
2136            }
2137        });
2138
2139        let (handler, _rx) = channel_message_handler();
2140
2141        let config = WebSocketConfig {
2142            url: format!("ws://127.0.0.1:{port}"),
2143            headers: vec![],
2144            message_handler: Some(handler),
2145            heartbeat: None,
2146            heartbeat_msg: None,
2147            ping_handler: None,
2148            reconnect_timeout_ms: Some(5_000),
2149            reconnect_delay_initial_ms: Some(50),
2150            reconnect_delay_max_ms: Some(100),
2151            reconnect_backoff_factor: Some(1.0),
2152            reconnect_jitter_ms: Some(0),
2153            reconnect_max_attempts: None,
2154        };
2155
2156        // Very restrictive rate limit: 1 request per second, burst of 1
2157        let quota =
2158            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2159
2160        let client = Arc::new(
2161            WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2162                .await
2163                .unwrap(),
2164        );
2165
2166        // First send exhausts burst capacity and triggers connection close
2167        client
2168            .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2169            .await
2170            .unwrap();
2171
2172        // Wait for client to enter RECONNECT state
2173        wait_until_async(
2174            || async { client.is_reconnecting() },
2175            Duration::from_secs(2),
2176        )
2177        .await;
2178
2179        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2180        let start = std::time::Instant::now();
2181        let send_result = client
2182            .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2183            .await;
2184        let elapsed = start.elapsed();
2185
2186        // Should succeed after both rate limit AND reconnection
2187        assert!(
2188            send_result.is_ok(),
2189            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2190        );
2191        // Total wait should be at least rate limit time (~1s)
2192        // The reconnection completes while rate limiting or after
2193        // Use 850ms threshold to account for timing jitter in CI
2194        assert!(
2195            elapsed >= Duration::from_millis(850),
2196            "Should wait for rate limit (~1s), waited {elapsed:?}"
2197        );
2198
2199        client.disconnect().await;
2200        server.abort();
2201    }
2202
2203    #[rstest]
2204    #[tokio::test]
2205    async fn test_disconnect_during_reconnect_exits_cleanly() {
2206        // Test CAS race condition: disconnect called during reconnection
2207        // Should exit cleanly without spawning new tasks
2208        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2209        let port = listener.local_addr().unwrap().port();
2210
2211        let server = task::spawn(async move {
2212            // Accept first connection and immediately close
2213            if let Ok((stream, _)) = listener.accept().await
2214                && let Ok(ws) = accept_async(stream).await
2215            {
2216                drop(ws);
2217            }
2218            // Don't accept second connection - let reconnect hang
2219            sleep(Duration::from_secs(60)).await;
2220        });
2221
2222        let (handler, _rx) = channel_message_handler();
2223
2224        let config = WebSocketConfig {
2225            url: format!("ws://127.0.0.1:{port}"),
2226            headers: vec![],
2227            message_handler: Some(handler),
2228            heartbeat: None,
2229            heartbeat_msg: None,
2230            ping_handler: None,
2231            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2232            reconnect_delay_initial_ms: Some(100),
2233            reconnect_delay_max_ms: Some(200),
2234            reconnect_backoff_factor: Some(1.0),
2235            reconnect_jitter_ms: Some(0),
2236            reconnect_max_attempts: None,
2237        };
2238
2239        let client = WebSocketClient::connect(config, None, vec![], None)
2240            .await
2241            .unwrap();
2242
2243        // Wait for reconnection to start
2244        tokio::time::timeout(Duration::from_secs(2), async {
2245            while !client.is_reconnecting() {
2246                sleep(Duration::from_millis(10)).await;
2247            }
2248        })
2249        .await
2250        .expect("Client should enter RECONNECT state");
2251
2252        // Disconnect while reconnecting
2253        client.disconnect().await;
2254
2255        // Should be cleanly closed
2256        assert!(
2257            client.is_disconnected(),
2258            "Client should be cleanly disconnected"
2259        );
2260
2261        server.abort();
2262    }
2263
2264    #[rstest]
2265    #[tokio::test]
2266    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2267        // Test that send operations check connection state BEFORE rate limiting,
2268        // preventing unnecessary delays when the connection is already closed.
2269        use std::{num::NonZeroU32, sync::Arc};
2270
2271        use nautilus_common::testing::wait_until_async;
2272
2273        use crate::ratelimiter::quota::Quota;
2274
2275        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2276        let port = listener.local_addr().unwrap().port();
2277
2278        let server = task::spawn(async move {
2279            // Accept connection and immediately close
2280            if let Ok((stream, _)) = listener.accept().await
2281                && let Ok(ws) = accept_async(stream).await
2282            {
2283                drop(ws);
2284            }
2285            sleep(Duration::from_secs(60)).await;
2286        });
2287
2288        let (handler, _rx) = channel_message_handler();
2289
2290        let config = WebSocketConfig {
2291            url: format!("ws://127.0.0.1:{port}"),
2292            headers: vec![],
2293            message_handler: Some(handler),
2294            heartbeat: None,
2295            heartbeat_msg: None,
2296            ping_handler: None,
2297            reconnect_timeout_ms: Some(5_000),
2298            reconnect_delay_initial_ms: Some(50),
2299            reconnect_delay_max_ms: Some(100),
2300            reconnect_backoff_factor: Some(1.0),
2301            reconnect_jitter_ms: Some(0),
2302            reconnect_max_attempts: None,
2303        };
2304
2305        // Very restrictive rate limit: 1 request per 10 seconds
2306        // This ensures that if we wait for rate limit, the test will timeout
2307        let quota = Quota::with_period(Duration::from_secs(10))
2308            .unwrap()
2309            .allow_burst(NonZeroU32::new(1).unwrap());
2310
2311        let client = Arc::new(
2312            WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2313                .await
2314                .unwrap(),
2315        );
2316
2317        // Wait for disconnection
2318        wait_until_async(
2319            || async { client.is_reconnecting() || client.is_closed() },
2320            Duration::from_secs(2),
2321        )
2322        .await;
2323
2324        // Explicitly disconnect to move away from ACTIVE state
2325        client.disconnect().await;
2326        assert!(
2327            !client.is_active(),
2328            "Client should not be active after disconnect"
2329        );
2330
2331        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2332        let start = std::time::Instant::now();
2333        let result = client
2334            .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2335            .await;
2336        let elapsed = start.elapsed();
2337
2338        // Should fail with Closed error
2339        assert!(result.is_err(), "Send should fail when client is closed");
2340        assert!(
2341            matches!(result, Err(crate::error::SendError::Closed)),
2342            "Send should return Closed error, was: {result:?}"
2343        );
2344
2345        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2346        assert!(
2347            elapsed < Duration::from_millis(100),
2348            "Send should fail fast without rate limiting, took {elapsed:?}"
2349        );
2350
2351        server.abort();
2352    }
2353
2354    #[rstest]
2355    #[tokio::test]
2356    async fn test_connect_rejects_config_without_message_handler() {
2357        // Test that connect() properly rejects configs without a message handler
2358        // to prevent zombie connections that appear alive but never detect disconnections
2359
2360        let config = WebSocketConfig {
2361            url: "ws://127.0.0.1:9999".to_string(),
2362            headers: vec![],
2363            message_handler: None, // No handler provided
2364            heartbeat: None,
2365            heartbeat_msg: None,
2366            ping_handler: None,
2367            reconnect_timeout_ms: Some(1_000),
2368            reconnect_delay_initial_ms: Some(100),
2369            reconnect_delay_max_ms: Some(500),
2370            reconnect_backoff_factor: Some(1.5),
2371            reconnect_jitter_ms: Some(0),
2372            reconnect_max_attempts: None,
2373        };
2374
2375        let result = WebSocketClient::connect(config, None, vec![], None).await;
2376
2377        assert!(
2378            result.is_err(),
2379            "connect() should reject configs without message_handler"
2380        );
2381
2382        let err = result.unwrap_err();
2383        let err_msg = err.to_string();
2384        assert!(
2385            err_msg.contains("Handler mode requires config.message_handler"),
2386            "Error should mention missing message_handler, was: {err_msg}"
2387        );
2388    }
2389
2390    #[rstest]
2391    #[tokio::test]
2392    async fn test_client_without_handler_sets_stream_mode() {
2393        // Test that if a client is somehow created without a handler,
2394        // it properly sets is_stream_mode=true to prevent zombie connections
2395
2396        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2397        let port = listener.local_addr().unwrap().port();
2398
2399        let server = task::spawn(async move {
2400            // Accept and immediately close to simulate server disconnect
2401            if let Ok((stream, _)) = listener.accept().await
2402                && let Ok(ws) = accept_async(stream).await
2403            {
2404                drop(ws); // Drop connection immediately
2405            }
2406        });
2407
2408        let config = WebSocketConfig {
2409            url: format!("ws://127.0.0.1:{port}"),
2410            headers: vec![],
2411            message_handler: None, // No handler
2412            heartbeat: None,
2413            heartbeat_msg: None,
2414            ping_handler: None,
2415            reconnect_timeout_ms: Some(1_000),
2416            reconnect_delay_initial_ms: Some(100),
2417            reconnect_delay_max_ms: Some(500),
2418            reconnect_backoff_factor: Some(1.5),
2419            reconnect_jitter_ms: Some(0),
2420            reconnect_max_attempts: None,
2421        };
2422
2423        // Create client directly via connect_url to bypass validation
2424        let inner = WebSocketClientInner::connect_url(config).await.unwrap();
2425
2426        // Verify is_stream_mode is true when no handler
2427        assert!(
2428            inner.is_stream_mode,
2429            "Client without handler should have is_stream_mode=true"
2430        );
2431
2432        // Verify that when stream mode is enabled, reconnection is disabled
2433        // (documented behavior - stream mode clients close instead of reconnecting)
2434
2435        server.abort();
2436    }
2437}