nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50    config::WebSocketConfig,
51    consts::{
52        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54    },
55    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60    RECONNECTED,
61    backoff::ExponentialBackoff,
62    error::SendError,
63    logging::{log_task_aborted, log_task_started, log_task_stopped},
64    mode::ConnectionMode,
65    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
68/// `WebSocketClient` connects to a websocket server to read and send messages.
69///
70/// The client is opinionated about how messages are read and written. It
71/// assumes that data can only have one reader but multiple writers.
72///
73/// The client splits the connection into read and write halves. It moves
74/// the read half into a tokio task which keeps receiving messages from the
75/// server and calls a handler - a Python function that takes the data
76/// as its parameter. It stores the write half in the struct wrapped
77/// with an Arc Mutex. This way the client struct can be used to write
78/// data to the server from multiple scopes/tasks.
79///
80/// The client also maintains a heartbeat if given a duration in seconds.
81/// It's preferable to set the duration slightly lower - heartbeat more
82/// frequently - than the required amount.
83pub struct WebSocketClientInner {
84    config: WebSocketConfig,
85    /// The function to handle incoming messages (stored separately from config).
86    message_handler: Option<MessageHandler>,
87    /// The handler for incoming pings (stored separately from config).
88    ping_handler: Option<PingHandler>,
89    read_task: Option<tokio::task::JoinHandle<()>>,
90    write_task: tokio::task::JoinHandle<()>,
91    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
92    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
93    connection_mode: Arc<AtomicU8>,
94    reconnect_timeout: Duration,
95    backoff: ExponentialBackoff,
96    /// True if this is a stream-based client (created via `connect_stream`).
97    /// Stream-based clients disable auto-reconnect because the reader is
98    /// owned by the caller and cannot be replaced during reconnection.
99    is_stream_mode: bool,
100    /// Maximum number of reconnection attempts before giving up (None = unlimited).
101    reconnect_max_attempts: Option<u32>,
102    /// Current count of consecutive reconnection attempts.
103    reconnection_attempt_count: u32,
104}
105
106impl WebSocketClientInner {
107    /// Create an inner websocket client with an existing writer.
108    ///
109    /// This is used for stream mode where the reader is owned by the caller.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if the exponential backoff configuration is invalid.
114    pub async fn new_with_writer(
115        config: WebSocketConfig,
116        writer: MessageWriter,
117    ) -> Result<Self, Error> {
118        install_cryptographic_provider();
119
120        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
121
122        // Note: We don't spawn a read task here since the reader is handled externally
123        let read_task = None;
124
125        let backoff = ExponentialBackoff::new(
126            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
127            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
128            config.reconnect_backoff_factor.unwrap_or(1.5),
129            config.reconnect_jitter_ms.unwrap_or(100),
130            true, // immediate-first
131        )
132        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
133
134        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
135        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
136
137        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
138            Some(Self::spawn_heartbeat_task(
139                connection_mode.clone(),
140                heartbeat_interval,
141                config.heartbeat_msg.clone(),
142                writer_tx.clone(),
143            ))
144        } else {
145            None
146        };
147
148        let reconnect_max_attempts = config.reconnect_max_attempts;
149        let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
150
151        Ok(Self {
152            config,
153            message_handler: None, // Stream mode has no handler
154            ping_handler: None,
155            writer_tx,
156            connection_mode,
157            reconnect_timeout,
158            heartbeat_task,
159            read_task,
160            write_task,
161            backoff,
162            is_stream_mode: true,
163            reconnect_max_attempts,
164            reconnection_attempt_count: 0,
165        })
166    }
167
168    /// Create an inner websocket client.
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if:
173    /// - The connection to the server fails.
174    /// - The exponential backoff configuration is invalid.
175    pub async fn connect_url(
176        config: WebSocketConfig,
177        message_handler: Option<MessageHandler>,
178        ping_handler: Option<PingHandler>,
179    ) -> Result<Self, Error> {
180        install_cryptographic_provider();
181
182        // Capture whether we're in stream mode before moving config
183        let is_stream_mode = message_handler.is_none();
184        let reconnect_max_attempts = config.reconnect_max_attempts;
185
186        let (writer, reader) =
187            Self::connect_with_server(&config.url, config.headers.clone()).await?;
188
189        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
190
191        let read_task = if message_handler.is_some() {
192            Some(Self::spawn_message_handler_task(
193                connection_mode.clone(),
194                reader,
195                message_handler.as_ref(),
196                ping_handler.as_ref(),
197            ))
198        } else {
199            None
200        };
201
202        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
203        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
204
205        // Optionally spawn a heartbeat task to periodically ping server
206        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
207            Self::spawn_heartbeat_task(
208                connection_mode.clone(),
209                heartbeat_secs,
210                config.heartbeat_msg.clone(),
211                writer_tx.clone(),
212            )
213        });
214
215        let reconnect_timeout =
216            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
217        let backoff = ExponentialBackoff::new(
218            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
219            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
220            config.reconnect_backoff_factor.unwrap_or(1.5),
221            config.reconnect_jitter_ms.unwrap_or(100),
222            true, // immediate-first
223        )
224        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
225
226        Ok(Self {
227            config,
228            message_handler,
229            ping_handler,
230            read_task,
231            write_task,
232            writer_tx,
233            heartbeat_task,
234            connection_mode,
235            reconnect_timeout,
236            backoff,
237            // Set stream mode when no message handler (reader not managed by client)
238            is_stream_mode,
239            reconnect_max_attempts,
240            reconnection_attempt_count: 0,
241        })
242    }
243
244    /// Connects with the server creating a tokio-tungstenite websocket stream.
245    /// Production version that uses `connect_async_with_config` convenience helper.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if:
250    /// - The URL cannot be parsed into a valid client request.
251    /// - Header values are invalid.
252    /// - The WebSocket connection fails.
253    #[inline]
254    #[cfg(not(feature = "turmoil"))]
255    pub async fn connect_with_server(
256        url: &str,
257        headers: Vec<(String, String)>,
258    ) -> Result<(MessageWriter, MessageReader), Error> {
259        let mut request = url.into_client_request()?;
260        let req_headers = request.headers_mut();
261
262        let mut header_names: Vec<HeaderName> = Vec::new();
263        for (key, val) in headers {
264            let header_value = HeaderValue::from_str(&val)?;
265            let header_name: HeaderName = key.parse()?;
266            header_names.push(header_name.clone());
267            req_headers.insert(header_name, header_value);
268        }
269
270        connect_async_with_config(request, None, true)
271            .await
272            .map(|resp| resp.0.split())
273    }
274
275    /// Connects with the server creating a tokio-tungstenite websocket stream.
276    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
277    ///
278    /// # Errors
279    ///
280    /// Returns an error if:
281    /// - The URL cannot be parsed into a valid client request.
282    /// - The URL is missing a hostname.
283    /// - Header values are invalid.
284    /// - The TCP connection fails.
285    /// - TLS setup fails (for wss:// URLs).
286    /// - The WebSocket handshake fails.
287    #[inline]
288    #[cfg(feature = "turmoil")]
289    pub async fn connect_with_server(
290        url: &str,
291        headers: Vec<(String, String)>,
292    ) -> Result<(MessageWriter, MessageReader), Error> {
293        use rustls::ClientConfig;
294        use tokio_rustls::TlsConnector;
295
296        let mut request = url.into_client_request()?;
297        let req_headers = request.headers_mut();
298
299        let mut header_names: Vec<HeaderName> = Vec::new();
300        for (key, val) in headers {
301            let header_value = HeaderValue::from_str(&val)?;
302            let header_name: HeaderName = key.parse()?;
303            header_names.push(header_name.clone());
304            req_headers.insert(header_name, header_value);
305        }
306
307        let uri = request.uri();
308        let scheme = uri.scheme_str().unwrap_or("ws");
309        let host = uri.host().ok_or_else(|| {
310            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
311        })?;
312
313        // Determine port: use explicit port if specified, otherwise default based on scheme
314        let port = uri
315            .port_u16()
316            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
317
318        let addr = format!("{host}:{port}");
319
320        // Use the connector to get a turmoil-compatible stream
321        let connector = crate::net::RealTcpConnector;
322        let tcp_stream = connector.connect(&addr).await?;
323        if let Err(e) = tcp_stream.set_nodelay(true) {
324            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
325        }
326
327        // Wrap stream appropriately based on scheme
328        let maybe_tls_stream = if scheme == "wss" {
329            // Build TLS config with webpki roots
330            let mut root_store = rustls::RootCertStore::empty();
331            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
332
333            let config = ClientConfig::builder()
334                .with_root_certificates(root_store)
335                .with_no_client_auth();
336
337            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
338            let domain =
339                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
340                    Error::Io(std::io::Error::new(
341                        std::io::ErrorKind::InvalidInput,
342                        format!("Invalid DNS name: {e}"),
343                    ))
344                })?;
345
346            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
347            MaybeTlsStream::Rustls(tls_stream)
348        } else {
349            MaybeTlsStream::Plain(tcp_stream)
350        };
351
352        // Use client_async with the stream (plain or TLS)
353        client_async(request, maybe_tls_stream)
354            .await
355            .map(|resp| resp.0.split())
356    }
357
358    /// Reconnect with server.
359    ///
360    /// Make a new connection with server. Use the new read and write halves
361    /// to update self writer and read and heartbeat tasks.
362    ///
363    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
364    /// because the reader is owned by the caller and cannot be replaced. Stream users
365    /// should handle disconnections by creating a new connection.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if:
370    /// - The reconnection attempt times out.
371    /// - The connection to the server fails.
372    pub async fn reconnect(&mut self) -> Result<(), Error> {
373        log::debug!("Reconnecting");
374
375        if self.is_stream_mode {
376            log::warn!(
377                "Auto-reconnect disabled for stream-based WebSocket client; \
378                stream users must manually reconnect by creating a new connection"
379            );
380            // Transition to CLOSED state to stop reconnection attempts
381            self.connection_mode
382                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
383            return Ok(());
384        }
385
386        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
387            log::debug!("Reconnect aborted due to disconnect state");
388            return Ok(());
389        }
390
391        tokio::time::timeout(self.reconnect_timeout, async {
392            // Attempt to connect; abort early if a disconnect was requested
393            let (new_writer, reader) =
394                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
395
396            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
397                log::debug!("Reconnect aborted mid-flight (after connect)");
398                return Ok(());
399            }
400
401            // Use a oneshot channel to synchronize with the writer task.
402            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
403            // to prevent silent message loss if the new connection drops immediately.
404            let (tx, rx) = tokio::sync::oneshot::channel();
405            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
406                log::error!("{e}");
407                return Err(Error::Io(std::io::Error::new(
408                    std::io::ErrorKind::BrokenPipe,
409                    format!("Failed to send update command: {e}"),
410                )));
411            }
412
413            // Wait for writer to confirm it has drained the buffer
414            match rx.await {
415                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
416                Ok(false) => {
417                    log::warn!("Writer failed to drain buffer, aborting reconnect");
418                    // Return error to trigger retry logic in controller
419                    return Err(Error::Io(std::io::Error::other(
420                        "Failed to drain reconnection buffer",
421                    )));
422                }
423                Err(e) => {
424                    log::error!("Writer dropped update channel: {e}");
425                    return Err(Error::Io(std::io::Error::new(
426                        std::io::ErrorKind::BrokenPipe,
427                        "Writer task dropped response channel",
428                    )));
429                }
430            }
431
432            // Delay before closing connection
433            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
434
435            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
436                log::debug!("Reconnect aborted mid-flight (after delay)");
437                return Ok(());
438            }
439
440            if let Some(ref read_task) = self.read_task.take()
441                && !read_task.is_finished()
442            {
443                read_task.abort();
444                log_task_aborted("read");
445            }
446
447            // Atomically transition from Reconnect to Active
448            // This prevents race condition where disconnect could be requested between check and store
449            if self
450                .connection_mode
451                .compare_exchange(
452                    ConnectionMode::Reconnect.as_u8(),
453                    ConnectionMode::Active.as_u8(),
454                    Ordering::SeqCst,
455                    Ordering::SeqCst,
456                )
457                .is_err()
458            {
459                log::debug!("Reconnect aborted (state changed during reconnect)");
460                return Ok(());
461            }
462
463            self.read_task = if self.message_handler.is_some() {
464                Some(Self::spawn_message_handler_task(
465                    self.connection_mode.clone(),
466                    reader,
467                    self.message_handler.as_ref(),
468                    self.ping_handler.as_ref(),
469                ))
470            } else {
471                None
472            };
473
474            log::debug!("Reconnect succeeded");
475            Ok(())
476        })
477        .await
478        .map_err(|_| {
479            Error::Io(std::io::Error::new(
480                std::io::ErrorKind::TimedOut,
481                format!(
482                    "reconnection timed out after {}s",
483                    self.reconnect_timeout.as_secs_f64()
484                ),
485            ))
486        })?
487    }
488
489    /// Check if the client is still connected.
490    ///
491    /// The client is connected if the read task has not finished. It is expected
492    /// that in case of any failure client or server side. The read task will be
493    /// shutdown or will receive a `Close` frame which will finish it. There
494    /// might be some delay between the connection being closed and the client
495    /// detecting.
496    #[inline]
497    #[must_use]
498    pub fn is_alive(&self) -> bool {
499        match &self.read_task {
500            Some(read_task) => !read_task.is_finished(),
501            None => true, // Stream is being used directly
502        }
503    }
504
505    fn spawn_message_handler_task(
506        connection_state: Arc<AtomicU8>,
507        mut reader: MessageReader,
508        message_handler: Option<&MessageHandler>,
509        ping_handler: Option<&PingHandler>,
510    ) -> tokio::task::JoinHandle<()> {
511        log::debug!("Started message handler task 'read'");
512
513        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
514
515        // Clone Arc handlers for the async task
516        let message_handler = message_handler.cloned();
517        let ping_handler = ping_handler.cloned();
518
519        tokio::task::spawn(async move {
520            loop {
521                if !ConnectionMode::from_atomic(&connection_state).is_active() {
522                    break;
523                }
524
525                match tokio::time::timeout(check_interval, reader.next()).await {
526                    Ok(Some(Ok(Message::Binary(data)))) => {
527                        log::trace!("Received message <binary> {} bytes", data.len());
528                        if let Some(ref handler) = message_handler {
529                            handler(Message::Binary(data));
530                        }
531                    }
532                    Ok(Some(Ok(Message::Text(data)))) => {
533                        log::trace!("Received message: {data}");
534                        if let Some(ref handler) = message_handler {
535                            handler(Message::Text(data));
536                        }
537                    }
538                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
539                        log::trace!("Received ping: {ping_data:?}");
540                        if let Some(ref handler) = ping_handler {
541                            handler(ping_data.to_vec());
542                        }
543                    }
544                    Ok(Some(Ok(Message::Pong(_)))) => {
545                        log::trace!("Received pong");
546                    }
547                    Ok(Some(Ok(Message::Close(_)))) => {
548                        log::debug!("Received close message - terminating");
549                        break;
550                    }
551                    Ok(Some(Ok(_))) => (),
552                    Ok(Some(Err(e))) => {
553                        log::error!("Received error message - terminating: {e}");
554                        break;
555                    }
556                    Ok(None) => {
557                        log::debug!("No message received - terminating");
558                        break;
559                    }
560                    Err(_) => {
561                        // Timeout - continue loop and check connection mode
562                        continue;
563                    }
564                }
565            }
566        })
567    }
568
569    /// Attempts to send all buffered messages after reconnection.
570    ///
571    /// Returns `true` if a send error occurred (caller should trigger reconnection).
572    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
573    async fn drain_reconnect_buffer(
574        buffer: &mut VecDeque<Message>,
575        writer: &mut MessageWriter,
576    ) -> bool {
577        if buffer.is_empty() {
578            return false;
579        }
580
581        let initial_buffer_len = buffer.len();
582        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
583
584        let mut send_error_occurred = false;
585
586        while let Some(buffered_msg) = buffer.front() {
587            // Clone message before attempting send (to keep in buffer if send fails)
588            let msg_to_send = buffered_msg.clone();
589
590            if let Err(e) = writer.send(msg_to_send).await {
591                log::error!(
592                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
593                    buffer.len()
594                );
595                send_error_occurred = true;
596                break; // Stop processing buffer, remaining messages preserved for next reconnection
597            }
598
599            // Only remove from buffer after successful send
600            buffer.pop_front();
601        }
602
603        if buffer.is_empty() {
604            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
605        }
606
607        send_error_occurred
608    }
609
610    fn spawn_write_task(
611        connection_state: Arc<AtomicU8>,
612        writer: MessageWriter,
613        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
614    ) -> tokio::task::JoinHandle<()> {
615        log_task_started("write");
616
617        // Interval between checking the connection mode
618        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
619
620        tokio::task::spawn(async move {
621            let mut active_writer = writer;
622            // Buffer for messages received during reconnection
623            // VecDeque for efficient pop_front() operations
624            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
625
626            loop {
627                match ConnectionMode::from_atomic(&connection_state) {
628                    ConnectionMode::Disconnect => {
629                        // Log any buffered messages that will be lost
630                        if !reconnect_buffer.is_empty() {
631                            log::warn!(
632                                "Discarding {} buffered messages due to disconnect",
633                                reconnect_buffer.len()
634                            );
635                            reconnect_buffer.clear();
636                        }
637
638                        // Attempt to close the writer gracefully before exiting,
639                        // we ignore any error as the writer may already be closed.
640                        _ = tokio::time::timeout(
641                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
642                            active_writer.close(),
643                        )
644                        .await;
645                        break;
646                    }
647                    ConnectionMode::Closed => {
648                        // Log any buffered messages that will be lost
649                        if !reconnect_buffer.is_empty() {
650                            log::warn!(
651                                "Discarding {} buffered messages due to closed connection",
652                                reconnect_buffer.len()
653                            );
654                            reconnect_buffer.clear();
655                        }
656                        break;
657                    }
658                    _ => {}
659                }
660
661                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
662                    Ok(Some(msg)) => {
663                        // Re-check connection mode after receiving a message
664                        let mode = ConnectionMode::from_atomic(&connection_state);
665                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
666                            break;
667                        }
668
669                        match msg {
670                            WriterCommand::Update(new_writer, tx) => {
671                                log::debug!("Received new writer");
672
673                                // Delay before closing connection
674                                tokio::time::sleep(Duration::from_millis(100)).await;
675
676                                // Attempt to close the writer gracefully on update,
677                                // we ignore any error as the writer may already be closed.
678                                _ = tokio::time::timeout(
679                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
680                                    active_writer.close(),
681                                )
682                                .await;
683
684                                active_writer = new_writer;
685                                log::debug!("Updated writer");
686
687                                let send_error = Self::drain_reconnect_buffer(
688                                    &mut reconnect_buffer,
689                                    &mut active_writer,
690                                )
691                                .await;
692
693                                if let Err(e) = tx.send(!send_error) {
694                                    log::error!(
695                                        "Failed to report drain status to controller: {e:?}"
696                                    );
697                                }
698                            }
699                            WriterCommand::Send(msg) if mode.is_reconnect() => {
700                                // Buffer messages during reconnection instead of dropping them
701                                log::debug!(
702                                    "Buffering message during reconnection (buffer size: {})",
703                                    reconnect_buffer.len() + 1
704                                );
705                                reconnect_buffer.push_back(msg);
706                            }
707                            WriterCommand::Send(msg) => {
708                                if let Err(e) = active_writer.send(msg.clone()).await {
709                                    log::error!("Failed to send message: {e}");
710                                    log::warn!("Writer triggering reconnect");
711                                    reconnect_buffer.push_back(msg);
712                                    connection_state
713                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
714                                }
715                            }
716                        }
717                    }
718                    Ok(None) => {
719                        // Channel closed - writer task should terminate
720                        log::debug!("Writer channel closed, terminating writer task");
721                        break;
722                    }
723                    Err(_) => {
724                        // Timeout - just continue the loop
725                        continue;
726                    }
727                }
728            }
729
730            // Attempt to close the writer gracefully before exiting,
731            // we ignore any error as the writer may already be closed.
732            _ = tokio::time::timeout(
733                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
734                active_writer.close(),
735            )
736            .await;
737
738            log_task_stopped("write");
739        })
740    }
741
742    fn spawn_heartbeat_task(
743        connection_state: Arc<AtomicU8>,
744        heartbeat_secs: u64,
745        message: Option<String>,
746        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
747    ) -> tokio::task::JoinHandle<()> {
748        log_task_started("heartbeat");
749
750        tokio::task::spawn(async move {
751            let interval = Duration::from_secs(heartbeat_secs);
752
753            loop {
754                tokio::time::sleep(interval).await;
755
756                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
757                    ConnectionMode::Active => {
758                        let msg = match &message {
759                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
760                            None => WriterCommand::Send(Message::Ping(vec![].into())),
761                        };
762
763                        match writer_tx.send(msg) {
764                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
765                            Err(e) => {
766                                log::error!("Failed to send heartbeat to writer task: {e}");
767                            }
768                        }
769                    }
770                    ConnectionMode::Reconnect => continue,
771                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
772                }
773            }
774
775            log_task_stopped("heartbeat");
776        })
777    }
778}
779
780impl Drop for WebSocketClientInner {
781    fn drop(&mut self) {
782        // Delegate to explicit cleanup handler
783        self.clean_drop();
784    }
785}
786
787/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
788impl CleanDrop for WebSocketClientInner {
789    fn clean_drop(&mut self) {
790        if let Some(ref read_task) = self.read_task.take()
791            && !read_task.is_finished()
792        {
793            read_task.abort();
794            log_task_aborted("read");
795        }
796
797        if !self.write_task.is_finished() {
798            self.write_task.abort();
799            log_task_aborted("write");
800        }
801
802        if let Some(ref handle) = self.heartbeat_task.take()
803            && !handle.is_finished()
804        {
805            handle.abort();
806            log_task_aborted("heartbeat");
807        }
808
809        // Clear handlers to break potential reference cycles
810        self.message_handler = None;
811        self.ping_handler = None;
812    }
813}
814
815impl Debug for WebSocketClientInner {
816    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817        f.debug_struct(stringify!(WebSocketClientInner))
818            .field("config", &self.config)
819            .field(
820                "connection_mode",
821                &ConnectionMode::from_atomic(&self.connection_mode),
822            )
823            .field("reconnect_timeout", &self.reconnect_timeout)
824            .field("is_stream_mode", &self.is_stream_mode)
825            .finish()
826    }
827}
828
829/// WebSocket client with automatic reconnection.
830///
831/// Handles connection state, callbacks, and rate limiting.
832/// See module docs for architecture details.
833#[cfg_attr(
834    feature = "python",
835    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
836)]
837pub struct WebSocketClient {
838    pub(crate) controller_task: tokio::task::JoinHandle<()>,
839    pub(crate) connection_mode: Arc<AtomicU8>,
840    pub(crate) reconnect_timeout: Duration,
841    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
842    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
843}
844
845impl Debug for WebSocketClient {
846    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847        f.debug_struct(stringify!(WebSocketClient)).finish()
848    }
849}
850
851impl WebSocketClient {
852    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
853    ///
854    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
855    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
856    /// client transitions to CLOSED state and the caller must manually reconnect by calling
857    /// `connect_stream` again.
858    ///
859    /// Use stream mode when you need custom reconnection logic, direct control over message
860    /// reading, or fine-grained backpressure handling.
861    ///
862    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
863    ///
864    /// # Errors
865    ///
866    /// Returns an error if the connection cannot be established.
867    #[allow(clippy::too_many_arguments)]
868    pub async fn connect_stream(
869        config: WebSocketConfig,
870        keyed_quotas: Vec<(String, Quota)>,
871        default_quota: Option<Quota>,
872        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
873    ) -> Result<(MessageReader, Self), Error> {
874        install_cryptographic_provider();
875
876        // Create a single connection and split it, respecting configured headers
877        let (writer, reader) =
878            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
879
880        // Create inner without connecting (we'll provide the writer)
881        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
882
883        let connection_mode = inner.connection_mode.clone();
884        let reconnect_timeout = inner.reconnect_timeout;
885        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
886        let writer_tx = inner.writer_tx.clone();
887
888        let controller_task =
889            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
890
891        Ok((
892            reader,
893            Self {
894                controller_task,
895                connection_mode,
896                reconnect_timeout,
897                rate_limiter,
898                writer_tx,
899            },
900        ))
901    }
902
903    /// Creates a websocket client in **handler mode** with automatic reconnection.
904    ///
905    /// The handler is called for each incoming message on an internal task.
906    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
907    /// the client automatically attempts to reconnect and replaces the internal reader
908    /// (the handler continues working seamlessly).
909    ///
910    /// Use handler mode for simplified connection management, automatic reconnection, Python
911    /// bindings, or callback-based message handling.
912    ///
913    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
914    ///
915    /// # Errors
916    ///
917    /// Returns an error if:
918    /// - The connection cannot be established.
919    /// - `message_handler` is `None` (use `connect_stream` instead).
920    pub async fn connect(
921        config: WebSocketConfig,
922        message_handler: Option<MessageHandler>,
923        ping_handler: Option<PingHandler>,
924        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
925        keyed_quotas: Vec<(String, Quota)>,
926        default_quota: Option<Quota>,
927    ) -> Result<Self, Error> {
928        // Validate that handler mode has a message handler
929        if message_handler.is_none() {
930            return Err(Error::Io(std::io::Error::new(
931                std::io::ErrorKind::InvalidInput,
932                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
933            )));
934        }
935
936        log::debug!("Connecting");
937        let inner =
938            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
939        let connection_mode = inner.connection_mode.clone();
940        let writer_tx = inner.writer_tx.clone();
941        let reconnect_timeout = inner.reconnect_timeout;
942
943        let controller_task =
944            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
945
946        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
947
948        Ok(Self {
949            controller_task,
950            connection_mode,
951            reconnect_timeout,
952            rate_limiter,
953            writer_tx,
954        })
955    }
956
957    /// Returns the current connection mode.
958    #[must_use]
959    pub fn connection_mode(&self) -> ConnectionMode {
960        ConnectionMode::from_atomic(&self.connection_mode)
961    }
962
963    /// Returns a clone of the connection mode atomic for external state tracking.
964    ///
965    /// This allows adapter clients to track connection state across reconnections
966    /// without message-passing delays.
967    #[must_use]
968    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
969        Arc::clone(&self.connection_mode)
970    }
971
972    /// Check if the client connection is active.
973    ///
974    /// Returns `true` if the client is connected and has not been signalled to disconnect.
975    /// The client will automatically retry connection based on its configuration.
976    #[inline]
977    #[must_use]
978    pub fn is_active(&self) -> bool {
979        self.connection_mode().is_active()
980    }
981
982    /// Check if the client is disconnected.
983    #[must_use]
984    pub fn is_disconnected(&self) -> bool {
985        self.controller_task.is_finished()
986    }
987
988    /// Check if the client is reconnecting.
989    ///
990    /// Returns `true` if the client lost connection and is attempting to reestablish it.
991    /// The client will automatically retry connection based on its configuration.
992    #[inline]
993    #[must_use]
994    pub fn is_reconnecting(&self) -> bool {
995        self.connection_mode().is_reconnect()
996    }
997
998    /// Check if the client is disconnecting.
999    ///
1000    /// Returns `true` if the client is in disconnect mode.
1001    #[inline]
1002    #[must_use]
1003    pub fn is_disconnecting(&self) -> bool {
1004        self.connection_mode().is_disconnect()
1005    }
1006
1007    /// Check if the client is closed.
1008    ///
1009    /// Returns `true` if the client has been explicitly disconnected or reached
1010    /// maximum reconnection attempts. In this state, the client cannot be reused
1011    /// and a new client must be created for further connections.
1012    #[inline]
1013    #[must_use]
1014    pub fn is_closed(&self) -> bool {
1015        self.connection_mode().is_closed()
1016    }
1017
1018    /// Wait for the client to become active before sending.
1019    ///
1020    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1021    async fn wait_for_active(&self) -> Result<(), SendError> {
1022        if self.is_closed() {
1023            return Err(SendError::Closed);
1024        }
1025
1026        let timeout = self.reconnect_timeout;
1027        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1028
1029        if !self.is_active() {
1030            log::debug!("Waiting for client to become ACTIVE before sending...");
1031
1032            let inner = tokio::time::timeout(timeout, async {
1033                loop {
1034                    if self.is_active() {
1035                        return Ok(());
1036                    }
1037                    if matches!(
1038                        self.connection_mode(),
1039                        ConnectionMode::Disconnect | ConnectionMode::Closed
1040                    ) {
1041                        return Err(());
1042                    }
1043                    tokio::time::sleep(check_interval).await;
1044                }
1045            })
1046            .await
1047            .map_err(|_| SendError::Timeout)?;
1048            inner.map_err(|()| SendError::Closed)?;
1049        }
1050
1051        Ok(())
1052    }
1053
1054    /// Set disconnect mode to true.
1055    ///
1056    /// Controller task will periodically check the disconnect mode
1057    /// and shutdown the client if it is alive
1058    pub async fn disconnect(&self) {
1059        log::debug!("Disconnecting");
1060        self.connection_mode
1061            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1062
1063        if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1064            while !self.is_disconnected() {
1065                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1066            }
1067
1068            if !self.controller_task.is_finished() {
1069                self.controller_task.abort();
1070                log_task_aborted("controller");
1071            }
1072        })
1073        .await
1074            == Ok(())
1075        {
1076            log::debug!("Controller task finished");
1077        } else {
1078            log::error!("Timeout waiting for controller task to finish");
1079            if !self.controller_task.is_finished() {
1080                self.controller_task.abort();
1081                log_task_aborted("controller");
1082            }
1083        }
1084    }
1085
1086    /// Sends the given text `data` to the server.
1087    ///
1088    /// # Errors
1089    ///
1090    /// Returns a websocket error if unable to send.
1091    #[allow(unused_variables)]
1092    pub async fn send_text(
1093        &self,
1094        data: String,
1095        keys: Option<Vec<String>>,
1096    ) -> Result<(), SendError> {
1097        // Check connection state before rate limiting to fail fast
1098        if self.is_closed() || self.is_disconnecting() {
1099            return Err(SendError::Closed);
1100        }
1101
1102        self.rate_limiter.await_keys_ready(keys).await;
1103        self.wait_for_active().await?;
1104
1105        log::trace!("Sending text: {data:?}");
1106
1107        let msg = Message::Text(data.into());
1108        self.writer_tx
1109            .send(WriterCommand::Send(msg))
1110            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1111    }
1112
1113    /// Sends a pong frame back to the server.
1114    ///
1115    /// # Errors
1116    ///
1117    /// Returns a websocket error if unable to send.
1118    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1119        self.wait_for_active().await?;
1120
1121        log::trace!("Sending pong frame ({} bytes)", data.len());
1122
1123        let msg = Message::Pong(data.into());
1124        self.writer_tx
1125            .send(WriterCommand::Send(msg))
1126            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1127    }
1128
1129    /// Sends the given bytes `data` to the server.
1130    ///
1131    /// # Errors
1132    ///
1133    /// Returns a websocket error if unable to send.
1134    #[allow(unused_variables)]
1135    pub async fn send_bytes(
1136        &self,
1137        data: Vec<u8>,
1138        keys: Option<Vec<String>>,
1139    ) -> Result<(), SendError> {
1140        // Check connection state before rate limiting to fail fast
1141        if self.is_closed() || self.is_disconnecting() {
1142            return Err(SendError::Closed);
1143        }
1144
1145        self.rate_limiter.await_keys_ready(keys).await;
1146        self.wait_for_active().await?;
1147
1148        log::trace!("Sending bytes: {data:?}");
1149
1150        let msg = Message::Binary(data.into());
1151        self.writer_tx
1152            .send(WriterCommand::Send(msg))
1153            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1154    }
1155
1156    /// Sends a close message to the server.
1157    ///
1158    /// # Errors
1159    ///
1160    /// Returns a websocket error if unable to send.
1161    pub async fn send_close_message(&self) -> Result<(), SendError> {
1162        self.wait_for_active().await?;
1163
1164        let msg = Message::Close(None);
1165        self.writer_tx
1166            .send(WriterCommand::Send(msg))
1167            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1168    }
1169
1170    fn spawn_controller_task(
1171        mut inner: WebSocketClientInner,
1172        connection_mode: Arc<AtomicU8>,
1173        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1174    ) -> tokio::task::JoinHandle<()> {
1175        tokio::task::spawn(async move {
1176            log_task_started("controller");
1177
1178            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1179
1180            loop {
1181                tokio::time::sleep(check_interval).await;
1182                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1183
1184                if mode.is_disconnect() {
1185                    log::debug!("Disconnecting");
1186
1187                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1188                    if tokio::time::timeout(timeout, async {
1189                        // Delay awaiting graceful shutdown
1190                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1191
1192                        if let Some(task) = &inner.read_task
1193                            && !task.is_finished()
1194                        {
1195                            task.abort();
1196                            log_task_aborted("read");
1197                        }
1198
1199                        if let Some(task) = &inner.heartbeat_task
1200                            && !task.is_finished()
1201                        {
1202                            task.abort();
1203                            log_task_aborted("heartbeat");
1204                        }
1205                    })
1206                    .await
1207                    .is_err()
1208                    {
1209                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1210                    }
1211
1212                    log::debug!("Closed");
1213                    break; // Controller finished
1214                }
1215
1216                if mode.is_closed() {
1217                    log::debug!("Connection closed");
1218                    break;
1219                }
1220
1221                if mode.is_active() && !inner.is_alive() {
1222                    if connection_mode
1223                        .compare_exchange(
1224                            ConnectionMode::Active.as_u8(),
1225                            ConnectionMode::Reconnect.as_u8(),
1226                            Ordering::SeqCst,
1227                            Ordering::SeqCst,
1228                        )
1229                        .is_ok()
1230                    {
1231                        log::debug!("Detected dead read task, transitioning to RECONNECT");
1232                    }
1233                    mode = ConnectionMode::from_atomic(&connection_mode);
1234                }
1235
1236                if mode.is_reconnect() {
1237                    // Check if max reconnection attempts exceeded
1238                    if let Some(max_attempts) = inner.reconnect_max_attempts
1239                        && inner.reconnection_attempt_count >= max_attempts
1240                    {
1241                        log::error!(
1242                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1243                        );
1244                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1245                        break;
1246                    }
1247
1248                    inner.reconnection_attempt_count += 1;
1249                    log::debug!(
1250                        "Reconnection attempt {} of {}",
1251                        inner.reconnection_attempt_count,
1252                        inner
1253                            .reconnect_max_attempts
1254                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1255                    );
1256
1257                    match inner.reconnect().await {
1258                        Ok(()) => {
1259                            inner.backoff.reset();
1260                            inner.reconnection_attempt_count = 0; // Reset counter on success
1261
1262                            // Only invoke callbacks if not in disconnect state
1263                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1264                                if let Some(ref handler) = inner.message_handler {
1265                                    let reconnected_msg =
1266                                        Message::Text(RECONNECTED.to_string().into());
1267                                    handler(reconnected_msg);
1268                                    log::debug!("Sent reconnected message to handler");
1269                                }
1270
1271                                // TODO: Retain this legacy callback for use from Python
1272                                if let Some(ref callback) = post_reconnection {
1273                                    callback();
1274                                    log::debug!("Called `post_reconnection` handler");
1275                                }
1276
1277                                log::debug!("Reconnected successfully");
1278                            } else {
1279                                log::debug!(
1280                                    "Skipping post_reconnection handlers due to disconnect state"
1281                                );
1282                            }
1283                        }
1284                        Err(e) => {
1285                            let duration = inner.backoff.next_duration();
1286                            log::warn!(
1287                                "Reconnect attempt {} failed: {e}",
1288                                inner.reconnection_attempt_count
1289                            );
1290                            if !duration.is_zero() {
1291                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1292                            }
1293                            tokio::time::sleep(duration).await;
1294                        }
1295                    }
1296                }
1297            }
1298            inner
1299                .connection_mode
1300                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1301
1302            log_task_stopped("controller");
1303        })
1304    }
1305}
1306
1307// Abort controller task on drop to clean up background tasks
1308impl Drop for WebSocketClient {
1309    fn drop(&mut self) {
1310        if !self.controller_task.is_finished() {
1311            self.controller_task.abort();
1312            log_task_aborted("controller");
1313        }
1314    }
1315}
1316
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server on an ephemeral port; the accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking on a missing or wrong header is how this callback fails the test
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let Some(value) = request.headers().get(&self.key) else {
                panic!("handshake request is missing the `{}` header", self.key);
            };
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    log::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Config for a local test server: sends the `test: test` header the server expects,
    /// with no heartbeat and default reconnection settings.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        }
    }

    /// Connects a handler-mode client with a no-op message handler and no rate limits.
    async fn connect_test_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_test_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = connect_test_client(config).await;

        // Wait ~3s => heartbeat task sends multiple pings (the server ignores them)
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Heartbeats must not have taken the client down
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it, so nothing is listening there
        // (a hard-coded port could be taken by an unrelated service)
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // Sends must name the keyed quota, otherwise it is never consulted
        let keys = || Some(vec!["default".to_string()]);

        // First 2 fit within the quota
        client.send_text("test1".into(), keys()).await.unwrap();
        client.send_text("test2".into(), keys()).await.unwrap();

        // Third waits for the quota to replenish, then still succeeds
        client.send_text("test3".into(), keys()).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1582
1583#[cfg(test)]
1584#[cfg(not(feature = "turmoil"))]
1585mod rust_tests {
1586    use futures_util::StreamExt;
1587    use rstest::rstest;
1588    use tokio::{
1589        net::TcpListener,
1590        task,
1591        time::{Duration, sleep},
1592    };
1593    use tokio_tungstenite::accept_async;
1594
1595    use super::*;
1596    use crate::websocket::types::channel_message_handler;
1597
1598    #[rstest]
1599    #[tokio::test]
1600    async fn test_reconnect_then_disconnect() {
1601        // Bind an ephemeral port
1602        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1603        let port = listener.local_addr().unwrap().port();
1604
1605        // Server task: accept one ws connection then close it
1606        let server = task::spawn(async move {
1607            let (stream, _) = listener.accept().await.unwrap();
1608            let ws = accept_async(stream).await.unwrap();
1609            drop(ws);
1610            // Keep alive briefly
1611            sleep(Duration::from_secs(1)).await;
1612        });
1613
1614        // Build a channel-based message handler for incoming messages (unused here)
1615        let (handler, _rx) = channel_message_handler();
1616
1617        // Configure client with short reconnect backoff
1618        let config = WebSocketConfig {
1619            url: format!("ws://127.0.0.1:{port}"),
1620            headers: vec![],
1621            heartbeat: None,
1622            heartbeat_msg: None,
1623            reconnect_timeout_ms: Some(1_000),
1624            reconnect_delay_initial_ms: Some(50),
1625            reconnect_delay_max_ms: Some(100),
1626            reconnect_backoff_factor: Some(1.0),
1627            reconnect_jitter_ms: Some(0),
1628            reconnect_max_attempts: None,
1629        };
1630
1631        // Connect the client
1632        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1633            .await
1634            .unwrap();
1635
1636        // Allow server to drop connection and client to detect
1637        sleep(Duration::from_millis(100)).await;
1638        // Now immediately disconnect the client
1639        client.disconnect().await;
1640        assert!(client.is_disconnected());
1641        server.abort();
1642    }
1643
1644    #[rstest]
1645    #[tokio::test]
1646    async fn test_reconnect_state_flips_when_reader_stops() {
1647        // Bind an ephemeral port and accept a single websocket connection which we drop.
1648        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1649        let port = listener.local_addr().unwrap().port();
1650
1651        let server = task::spawn(async move {
1652            if let Ok((stream, _)) = listener.accept().await
1653                && let Ok(ws) = accept_async(stream).await
1654            {
1655                drop(ws);
1656            }
1657            sleep(Duration::from_millis(50)).await;
1658        });
1659
1660        let (handler, _rx) = channel_message_handler();
1661
1662        let config = WebSocketConfig {
1663            url: format!("ws://127.0.0.1:{port}"),
1664            headers: vec![],
1665            heartbeat: None,
1666            heartbeat_msg: None,
1667            reconnect_timeout_ms: Some(1_000),
1668            reconnect_delay_initial_ms: Some(50),
1669            reconnect_delay_max_ms: Some(100),
1670            reconnect_backoff_factor: Some(1.0),
1671            reconnect_jitter_ms: Some(0),
1672            reconnect_max_attempts: None,
1673        };
1674
1675        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1676            .await
1677            .unwrap();
1678
1679        tokio::time::timeout(Duration::from_secs(2), async {
1680            loop {
1681                if client.is_reconnecting() {
1682                    break;
1683                }
1684                tokio::time::sleep(Duration::from_millis(10)).await;
1685            }
1686        })
1687        .await
1688        .expect("client did not enter RECONNECT state");
1689
1690        client.disconnect().await;
1691        server.abort();
1692    }
1693
1694    #[rstest]
1695    #[tokio::test]
1696    async fn test_stream_mode_disables_auto_reconnect() {
1697        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1698        // and that reconnect() transitions to CLOSED state for stream mode
1699        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1700        let port = listener.local_addr().unwrap().port();
1701
1702        let server = task::spawn(async move {
1703            if let Ok((stream, _)) = listener.accept().await
1704                && let Ok(_ws) = accept_async(stream).await
1705            {
1706                // Keep connection alive briefly
1707                sleep(Duration::from_millis(100)).await;
1708            }
1709        });
1710
1711        let config = WebSocketConfig {
1712            url: format!("ws://127.0.0.1:{port}"),
1713            headers: vec![],
1714            heartbeat: None,
1715            heartbeat_msg: None,
1716            reconnect_timeout_ms: Some(1_000),
1717            reconnect_delay_initial_ms: Some(50),
1718            reconnect_delay_max_ms: Some(100),
1719            reconnect_backoff_factor: Some(1.0),
1720            reconnect_jitter_ms: Some(0),
1721            reconnect_max_attempts: None,
1722        };
1723
1724        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1725            .await
1726            .unwrap();
1727
1728        // Note: We can't easily test the reconnect behavior from the outside since
1729        // the inner client is private. The key fix is that WebSocketClientInner
1730        // now has is_stream_mode=true for connect_stream, and reconnect() will
1731        // transition to CLOSED state instead of creating a new reader that gets dropped.
1732        // This is tested implicitly by the fact that stream users won't get stuck
1733        // in an infinite reconnect loop.
1734
1735        server.abort();
1736    }
1737
1738    #[rstest]
1739    #[tokio::test]
1740    async fn test_message_handler_mode_allows_auto_reconnect() {
1741        // Test that regular clients (with message handler) can auto-reconnect
1742        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1743        let port = listener.local_addr().unwrap().port();
1744
1745        let server = task::spawn(async move {
1746            // Accept first connection and close it
1747            if let Ok((stream, _)) = listener.accept().await
1748                && let Ok(ws) = accept_async(stream).await
1749            {
1750                drop(ws);
1751            }
1752            sleep(Duration::from_millis(50)).await;
1753        });
1754
1755        let (handler, _rx) = channel_message_handler();
1756
1757        let config = WebSocketConfig {
1758            url: format!("ws://127.0.0.1:{port}"),
1759            headers: vec![],
1760            heartbeat: None,
1761            heartbeat_msg: None,
1762            reconnect_timeout_ms: Some(1_000),
1763            reconnect_delay_initial_ms: Some(50),
1764            reconnect_delay_max_ms: Some(100),
1765            reconnect_backoff_factor: Some(1.0),
1766            reconnect_jitter_ms: Some(0),
1767            reconnect_max_attempts: None,
1768        };
1769
1770        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1771            .await
1772            .unwrap();
1773
1774        // Wait for the connection to be dropped and reconnection to be attempted
1775        tokio::time::timeout(Duration::from_secs(2), async {
1776            loop {
1777                if client.is_reconnecting() || client.is_closed() {
1778                    break;
1779                }
1780                tokio::time::sleep(Duration::from_millis(10)).await;
1781            }
1782        })
1783        .await
1784        .expect("client should attempt reconnection or close");
1785
1786        // Should either be reconnecting or closed (depending on timing)
1787        // The important thing is it's not staying active forever
1788        assert!(
1789            client.is_reconnecting() || client.is_closed(),
1790            "Client with message handler should attempt reconnection"
1791        );
1792
1793        client.disconnect().await;
1794        server.abort();
1795    }
1796
1797    #[rstest]
1798    #[tokio::test]
1799    async fn test_handler_mode_reconnect_with_new_connection() {
1800        // Test that handler mode successfully reconnects and messages continue flowing
1801        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1802        let port = listener.local_addr().unwrap().port();
1803
1804        let server = task::spawn(async move {
1805            // First connection - accept and immediately close
1806            if let Ok((stream, _)) = listener.accept().await
1807                && let Ok(ws) = accept_async(stream).await
1808            {
1809                drop(ws);
1810            }
1811
1812            // Small delay to let client detect disconnection
1813            sleep(Duration::from_millis(100)).await;
1814
1815            // Second connection - accept, send a message, then keep alive
1816            if let Ok((stream, _)) = listener.accept().await
1817                && let Ok(mut ws) = accept_async(stream).await
1818            {
1819                use futures_util::SinkExt;
1820                let _ = ws
1821                    .send(Message::Text("reconnected".to_string().into()))
1822                    .await;
1823                sleep(Duration::from_secs(1)).await;
1824            }
1825        });
1826
1827        let (handler, mut rx) = channel_message_handler();
1828
1829        let config = WebSocketConfig {
1830            url: format!("ws://127.0.0.1:{port}"),
1831            headers: vec![],
1832            heartbeat: None,
1833            heartbeat_msg: None,
1834            reconnect_timeout_ms: Some(2_000),
1835            reconnect_delay_initial_ms: Some(50),
1836            reconnect_delay_max_ms: Some(200),
1837            reconnect_backoff_factor: Some(1.5),
1838            reconnect_jitter_ms: Some(10),
1839            reconnect_max_attempts: None,
1840        };
1841
1842        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1843            .await
1844            .unwrap();
1845
1846        // Wait for reconnection to happen and message to arrive
1847        let result = tokio::time::timeout(Duration::from_secs(5), async {
1848            loop {
1849                if let Ok(msg) = rx.try_recv()
1850                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1851                {
1852                    return true;
1853                }
1854                tokio::time::sleep(Duration::from_millis(10)).await;
1855            }
1856        })
1857        .await;
1858
1859        assert!(
1860            result.is_ok(),
1861            "Should receive message after reconnection within timeout"
1862        );
1863
1864        client.disconnect().await;
1865        server.abort();
1866    }
1867
1868    #[rstest]
1869    #[tokio::test]
1870    async fn test_stream_mode_no_auto_reconnect() {
1871        // Test that stream mode does not automatically reconnect when connection is lost
1872        // The caller owns the reader and is responsible for detecting disconnection
1873        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1874        let port = listener.local_addr().unwrap().port();
1875
1876        let server = task::spawn(async move {
1877            // Accept connection and send one message, then close
1878            if let Ok((stream, _)) = listener.accept().await
1879                && let Ok(mut ws) = accept_async(stream).await
1880            {
1881                use futures_util::SinkExt;
1882                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1883                sleep(Duration::from_millis(50)).await;
1884                // Connection closes when ws is dropped
1885            }
1886        });
1887
1888        let config = WebSocketConfig {
1889            url: format!("ws://127.0.0.1:{port}"),
1890            headers: vec![],
1891            heartbeat: None,
1892            heartbeat_msg: None,
1893            reconnect_timeout_ms: Some(1_000),
1894            reconnect_delay_initial_ms: Some(50),
1895            reconnect_delay_max_ms: Some(100),
1896            reconnect_backoff_factor: Some(1.0),
1897            reconnect_jitter_ms: Some(0),
1898            reconnect_max_attempts: None,
1899        };
1900
1901        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1902            .await
1903            .unwrap();
1904
1905        // Initially active
1906        assert!(client.is_active(), "Client should start as active");
1907
1908        // Read the hello message
1909        let msg = reader.next().await;
1910        assert!(
1911            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1912            "Should receive initial message"
1913        );
1914
1915        // Read until connection closes (reader will return None or error)
1916        while let Some(msg) = reader.next().await {
1917            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1918                break;
1919            }
1920        }
1921
1922        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1923        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1924        sleep(Duration::from_millis(200)).await;
1925
1926        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1927        // This is correct behavior - stream mode doesn't auto-detect disconnection
1928        assert!(
1929            client.is_active() || client.is_closed(),
1930            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1931        );
1932        assert!(
1933            !client.is_reconnecting(),
1934            "Stream mode client should never attempt reconnection"
1935        );
1936
1937        client.disconnect().await;
1938        server.abort();
1939    }
1940
1941    #[rstest]
1942    #[tokio::test]
1943    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1944        // Test that send operations respect the configured reconnect_timeout.
1945        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1946        use nautilus_common::testing::wait_until_async;
1947
1948        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1949        let port = listener.local_addr().unwrap().port();
1950
1951        let server = task::spawn(async move {
1952            // Accept first connection and immediately close it
1953            if let Ok((stream, _)) = listener.accept().await
1954                && let Ok(ws) = accept_async(stream).await
1955            {
1956                drop(ws);
1957            }
1958            // Don't accept second connection - client will be stuck in RECONNECT
1959            sleep(Duration::from_secs(60)).await;
1960        });
1961
1962        let (handler, _rx) = channel_message_handler();
1963
1964        // Configure with SHORT 2s reconnect timeout
1965        let config = WebSocketConfig {
1966            url: format!("ws://127.0.0.1:{port}"),
1967            headers: vec![],
1968            heartbeat: None,
1969            heartbeat_msg: None,
1970            reconnect_timeout_ms: Some(2_000), // 2s timeout
1971            reconnect_delay_initial_ms: Some(50),
1972            reconnect_delay_max_ms: Some(100),
1973            reconnect_backoff_factor: Some(1.0),
1974            reconnect_jitter_ms: Some(0),
1975            reconnect_max_attempts: None,
1976        };
1977
1978        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1979            .await
1980            .unwrap();
1981
1982        // Wait for client to enter RECONNECT state
1983        wait_until_async(
1984            || async { client.is_reconnecting() },
1985            Duration::from_secs(3),
1986        )
1987        .await;
1988
1989        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
1990        let start = std::time::Instant::now();
1991        let send_result = client.send_text("test".to_string(), None).await;
1992        let elapsed = start.elapsed();
1993
1994        assert!(
1995            send_result.is_err(),
1996            "Send should fail when client stuck in RECONNECT"
1997        );
1998        assert!(
1999            matches!(send_result, Err(crate::error::SendError::Timeout)),
2000            "Send should return Timeout error, was: {send_result:?}"
2001        );
2002        // Verify timeout respects configured value (2s), but don't check upper bound
2003        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2004        assert!(
2005            elapsed >= Duration::from_millis(1800),
2006            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2007        );
2008
2009        client.disconnect().await;
2010        server.abort();
2011    }
2012
2013    #[rstest]
2014    #[tokio::test]
2015    async fn test_send_waits_during_reconnection() {
2016        // Test that send operations wait for reconnection to complete (up to timeout)
2017        use nautilus_common::testing::wait_until_async;
2018
2019        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2020        let port = listener.local_addr().unwrap().port();
2021
2022        let server = task::spawn(async move {
2023            // First connection - accept and immediately close
2024            if let Ok((stream, _)) = listener.accept().await
2025                && let Ok(ws) = accept_async(stream).await
2026            {
2027                drop(ws);
2028            }
2029
2030            // Wait a bit before accepting second connection
2031            sleep(Duration::from_millis(500)).await;
2032
2033            // Second connection - accept and keep alive
2034            if let Ok((stream, _)) = listener.accept().await
2035                && let Ok(mut ws) = accept_async(stream).await
2036            {
2037                // Echo messages
2038                while let Some(Ok(msg)) = ws.next().await {
2039                    if ws.send(msg).await.is_err() {
2040                        break;
2041                    }
2042                }
2043            }
2044        });
2045
2046        let (handler, _rx) = channel_message_handler();
2047
2048        let config = WebSocketConfig {
2049            url: format!("ws://127.0.0.1:{port}"),
2050            headers: vec![],
2051            heartbeat: None,
2052            heartbeat_msg: None,
2053            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2054            reconnect_delay_initial_ms: Some(100),
2055            reconnect_delay_max_ms: Some(200),
2056            reconnect_backoff_factor: Some(1.0),
2057            reconnect_jitter_ms: Some(0),
2058            reconnect_max_attempts: None,
2059        };
2060
2061        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2062            .await
2063            .unwrap();
2064
2065        // Wait for reconnection to trigger
2066        wait_until_async(
2067            || async { client.is_reconnecting() },
2068            Duration::from_secs(2),
2069        )
2070        .await;
2071
2072        // Try to send while reconnecting - should wait and succeed after reconnect
2073        let send_result = tokio::time::timeout(
2074            Duration::from_secs(3),
2075            client.send_text("test_message".to_string(), None),
2076        )
2077        .await;
2078
2079        assert!(
2080            send_result.is_ok() && send_result.unwrap().is_ok(),
2081            "Send should succeed after waiting for reconnection"
2082        );
2083
2084        client.disconnect().await;
2085        server.abort();
2086    }
2087
2088    #[rstest]
2089    #[tokio::test]
2090    async fn test_rate_limiter_before_active_wait() {
2091        // Test that rate limiting happens BEFORE active state check.
2092        // This prevents race conditions where connection state changes during rate limit wait.
2093        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2094        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2095        use std::{num::NonZeroU32, sync::Arc};
2096
2097        use nautilus_common::testing::wait_until_async;
2098
2099        use crate::ratelimiter::quota::Quota;
2100
2101        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2102        let port = listener.local_addr().unwrap().port();
2103
2104        let server = task::spawn(async move {
2105            // First connection - accept and close after receiving one message
2106            if let Ok((stream, _)) = listener.accept().await
2107                && let Ok(mut ws) = accept_async(stream).await
2108            {
2109                // Receive first message then close
2110                if let Some(Ok(_)) = ws.next().await {
2111                    drop(ws);
2112                }
2113            }
2114
2115            // Wait before accepting reconnection
2116            sleep(Duration::from_millis(500)).await;
2117
2118            // Second connection - accept and keep alive
2119            if let Ok((stream, _)) = listener.accept().await
2120                && let Ok(mut ws) = accept_async(stream).await
2121            {
2122                while let Some(Ok(msg)) = ws.next().await {
2123                    if ws.send(msg).await.is_err() {
2124                        break;
2125                    }
2126                }
2127            }
2128        });
2129
2130        let (handler, _rx) = channel_message_handler();
2131
2132        let config = WebSocketConfig {
2133            url: format!("ws://127.0.0.1:{port}"),
2134            headers: vec![],
2135            heartbeat: None,
2136            heartbeat_msg: None,
2137            reconnect_timeout_ms: Some(5_000),
2138            reconnect_delay_initial_ms: Some(50),
2139            reconnect_delay_max_ms: Some(100),
2140            reconnect_backoff_factor: Some(1.0),
2141            reconnect_jitter_ms: Some(0),
2142            reconnect_max_attempts: None,
2143        };
2144
2145        // Very restrictive rate limit: 1 request per second, burst of 1
2146        let quota =
2147            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2148
2149        let client = Arc::new(
2150            WebSocketClient::connect(
2151                config,
2152                Some(handler),
2153                None,
2154                None,
2155                vec![("test_key".to_string(), quota)],
2156                None,
2157            )
2158            .await
2159            .unwrap(),
2160        );
2161
2162        // First send exhausts burst capacity and triggers connection close
2163        client
2164            .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2165            .await
2166            .unwrap();
2167
2168        // Wait for client to enter RECONNECT state
2169        wait_until_async(
2170            || async { client.is_reconnecting() },
2171            Duration::from_secs(2),
2172        )
2173        .await;
2174
2175        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2176        let start = std::time::Instant::now();
2177        let send_result = client
2178            .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2179            .await;
2180        let elapsed = start.elapsed();
2181
2182        // Should succeed after both rate limit AND reconnection
2183        assert!(
2184            send_result.is_ok(),
2185            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2186        );
2187        // Total wait should be at least rate limit time (~1s)
2188        // The reconnection completes while rate limiting or after
2189        // Use 850ms threshold to account for timing jitter in CI
2190        assert!(
2191            elapsed >= Duration::from_millis(850),
2192            "Should wait for rate limit (~1s), waited {elapsed:?}"
2193        );
2194
2195        client.disconnect().await;
2196        server.abort();
2197    }
2198
2199    #[rstest]
2200    #[tokio::test]
2201    async fn test_disconnect_during_reconnect_exits_cleanly() {
2202        // Test CAS race condition: disconnect called during reconnection
2203        // Should exit cleanly without spawning new tasks
2204        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2205        let port = listener.local_addr().unwrap().port();
2206
2207        let server = task::spawn(async move {
2208            // Accept first connection and immediately close
2209            if let Ok((stream, _)) = listener.accept().await
2210                && let Ok(ws) = accept_async(stream).await
2211            {
2212                drop(ws);
2213            }
2214            // Don't accept second connection - let reconnect hang
2215            sleep(Duration::from_secs(60)).await;
2216        });
2217
2218        let (handler, _rx) = channel_message_handler();
2219
2220        let config = WebSocketConfig {
2221            url: format!("ws://127.0.0.1:{port}"),
2222            headers: vec![],
2223            heartbeat: None,
2224            heartbeat_msg: None,
2225            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2226            reconnect_delay_initial_ms: Some(100),
2227            reconnect_delay_max_ms: Some(200),
2228            reconnect_backoff_factor: Some(1.0),
2229            reconnect_jitter_ms: Some(0),
2230            reconnect_max_attempts: None,
2231        };
2232
2233        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2234            .await
2235            .unwrap();
2236
2237        // Wait for reconnection to start
2238        tokio::time::timeout(Duration::from_secs(2), async {
2239            while !client.is_reconnecting() {
2240                sleep(Duration::from_millis(10)).await;
2241            }
2242        })
2243        .await
2244        .expect("Client should enter RECONNECT state");
2245
2246        // Disconnect while reconnecting
2247        client.disconnect().await;
2248
2249        // Should be cleanly closed
2250        assert!(
2251            client.is_disconnected(),
2252            "Client should be cleanly disconnected"
2253        );
2254
2255        server.abort();
2256    }
2257
2258    #[rstest]
2259    #[tokio::test]
2260    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2261        // Test that send operations check connection state BEFORE rate limiting,
2262        // preventing unnecessary delays when the connection is already closed.
2263        use std::{num::NonZeroU32, sync::Arc};
2264
2265        use nautilus_common::testing::wait_until_async;
2266
2267        use crate::ratelimiter::quota::Quota;
2268
2269        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2270        let port = listener.local_addr().unwrap().port();
2271
2272        let server = task::spawn(async move {
2273            // Accept connection and immediately close
2274            if let Ok((stream, _)) = listener.accept().await
2275                && let Ok(ws) = accept_async(stream).await
2276            {
2277                drop(ws);
2278            }
2279            sleep(Duration::from_secs(60)).await;
2280        });
2281
2282        let (handler, _rx) = channel_message_handler();
2283
2284        let config = WebSocketConfig {
2285            url: format!("ws://127.0.0.1:{port}"),
2286            headers: vec![],
2287            heartbeat: None,
2288            heartbeat_msg: None,
2289            reconnect_timeout_ms: Some(5_000),
2290            reconnect_delay_initial_ms: Some(50),
2291            reconnect_delay_max_ms: Some(100),
2292            reconnect_backoff_factor: Some(1.0),
2293            reconnect_jitter_ms: Some(0),
2294            reconnect_max_attempts: None,
2295        };
2296
2297        // Very restrictive rate limit: 1 request per 10 seconds
2298        // This ensures that if we wait for rate limit, the test will timeout
2299        let quota = Quota::with_period(Duration::from_secs(10))
2300            .unwrap()
2301            .allow_burst(NonZeroU32::new(1).unwrap());
2302
2303        let client = Arc::new(
2304            WebSocketClient::connect(
2305                config,
2306                Some(handler),
2307                None,
2308                None,
2309                vec![("test_key".to_string(), quota)],
2310                None,
2311            )
2312            .await
2313            .unwrap(),
2314        );
2315
2316        // Wait for disconnection
2317        wait_until_async(
2318            || async { client.is_reconnecting() || client.is_closed() },
2319            Duration::from_secs(2),
2320        )
2321        .await;
2322
2323        // Explicitly disconnect to move away from ACTIVE state
2324        client.disconnect().await;
2325        assert!(
2326            !client.is_active(),
2327            "Client should not be active after disconnect"
2328        );
2329
2330        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2331        let start = std::time::Instant::now();
2332        let result = client
2333            .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2334            .await;
2335        let elapsed = start.elapsed();
2336
2337        // Should fail with Closed error
2338        assert!(result.is_err(), "Send should fail when client is closed");
2339        assert!(
2340            matches!(result, Err(crate::error::SendError::Closed)),
2341            "Send should return Closed error, was: {result:?}"
2342        );
2343
2344        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2345        assert!(
2346            elapsed < Duration::from_millis(100),
2347            "Send should fail fast without rate limiting, took {elapsed:?}"
2348        );
2349
2350        server.abort();
2351    }
2352
2353    #[rstest]
2354    #[tokio::test]
2355    async fn test_connect_rejects_none_message_handler() {
2356        // Test that connect() properly rejects None message_handler
2357        // to prevent zombie connections that appear alive but never detect disconnections
2358
2359        let config = WebSocketConfig {
2360            url: "ws://127.0.0.1:9999".to_string(),
2361            headers: vec![],
2362            heartbeat: None,
2363            heartbeat_msg: None,
2364            reconnect_timeout_ms: Some(1_000),
2365            reconnect_delay_initial_ms: Some(100),
2366            reconnect_delay_max_ms: Some(500),
2367            reconnect_backoff_factor: Some(1.5),
2368            reconnect_jitter_ms: Some(0),
2369            reconnect_max_attempts: None,
2370        };
2371
2372        // Pass None for message_handler - should be rejected
2373        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2374
2375        assert!(
2376            result.is_err(),
2377            "connect() should reject None message_handler"
2378        );
2379
2380        let err = result.unwrap_err();
2381        let err_msg = err.to_string();
2382        assert!(
2383            err_msg.contains("Handler mode requires message_handler"),
2384            "Error should mention missing message_handler, was: {err_msg}"
2385        );
2386    }
2387
2388    #[rstest]
2389    #[tokio::test]
2390    async fn test_client_without_handler_sets_stream_mode() {
2391        // Test that if a client is created without a handler via connect_url,
2392        // it properly sets is_stream_mode=true to prevent zombie connections
2393
2394        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2395        let port = listener.local_addr().unwrap().port();
2396
2397        let server = task::spawn(async move {
2398            // Accept and immediately close to simulate server disconnect
2399            if let Ok((stream, _)) = listener.accept().await
2400                && let Ok(ws) = accept_async(stream).await
2401            {
2402                drop(ws); // Drop connection immediately
2403            }
2404        });
2405
2406        let config = WebSocketConfig {
2407            url: format!("ws://127.0.0.1:{port}"),
2408            headers: vec![],
2409            heartbeat: None,
2410            heartbeat_msg: None,
2411            reconnect_timeout_ms: Some(1_000),
2412            reconnect_delay_initial_ms: Some(100),
2413            reconnect_delay_max_ms: Some(500),
2414            reconnect_backoff_factor: Some(1.5),
2415            reconnect_jitter_ms: Some(0),
2416            reconnect_max_attempts: None,
2417        };
2418
2419        // Create client directly via connect_url with no handler (stream mode)
2420        let inner = WebSocketClientInner::connect_url(config, None, None)
2421            .await
2422            .unwrap();
2423
2424        // Verify is_stream_mode is true when no handler
2425        assert!(
2426            inner.is_stream_mode,
2427            "Client without handler should have is_stream_mode=true"
2428        );
2429
2430        // Verify that when stream mode is enabled, reconnection is disabled
2431        // (documented behavior - stream mode clients close instead of reconnecting)
2432
2433        server.abort();
2434    }
2435}