Skip to main content

nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48use ustr::Ustr;
49
50use super::{
51    config::WebSocketConfig,
52    consts::{
53        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
54        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
55    },
56    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
57};
58#[cfg(feature = "turmoil")]
59use crate::net::TcpConnector;
60use crate::{
61    RECONNECTED,
62    backoff::ExponentialBackoff,
63    error::SendError,
64    logging::{log_task_aborted, log_task_started, log_task_stopped},
65    mode::ConnectionMode,
66    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
67};
68
/// `WebSocketClient` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. It moves
/// the read half into a tokio task which keeps receiving messages from the
/// server and passes each one to the configured message handler. The write
/// half is moved into a dedicated write task; callers reach it by sending
/// `WriterCommand`s over the unbounded `writer_tx` channel, so data can be
/// written to the server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given an interval in seconds.
/// It's preferable to set the interval slightly lower - heartbeat more
/// frequently - than the required amount.
pub struct WebSocketClientInner {
    /// Connection settings (URL, headers, heartbeat and reconnect options).
    config: WebSocketConfig,
    /// The function to handle incoming messages (stored separately from config).
    message_handler: Option<MessageHandler>,
    /// The handler for incoming pings (stored separately from config).
    ping_handler: Option<PingHandler>,
    /// Handle of the task reading from the server (`None` when no message handler is set).
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender used to pass `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task (`None` when `config.heartbeat` is unset).
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode` stored as a `u8`, shared with the background tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single `reconnect` call.
    reconnect_timeout: Duration,
    /// Backoff state built from the reconnect delay, factor and jitter settings.
    backoff: ExponentialBackoff,
    /// True if this is a stream-based client (created via `connect_stream`).
    /// Stream-based clients disable auto-reconnect because the reader is
    /// owned by the caller and cannot be replaced during reconnection.
    is_stream_mode: bool,
    /// Maximum number of reconnection attempts before giving up (None = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Current count of consecutive reconnection attempts.
    reconnection_attempt_count: u32,
}
106
107impl WebSocketClientInner {
108    /// Create an inner websocket client with an existing writer.
109    ///
110    /// This is used for stream mode where the reader is owned by the caller.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the exponential backoff configuration is invalid.
115    pub async fn new_with_writer(
116        config: WebSocketConfig,
117        writer: MessageWriter,
118    ) -> Result<Self, Error> {
119        install_cryptographic_provider();
120
121        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
122
123        // Note: We don't spawn a read task here since the reader is handled externally
124        let read_task = None;
125
126        let backoff = ExponentialBackoff::new(
127            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
128            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
129            config.reconnect_backoff_factor.unwrap_or(1.5),
130            config.reconnect_jitter_ms.unwrap_or(100),
131            true, // immediate-first
132        )
133        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
134
135        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
136        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
137
138        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
139            Some(Self::spawn_heartbeat_task(
140                connection_mode.clone(),
141                heartbeat_interval,
142                config.heartbeat_msg.clone(),
143                writer_tx.clone(),
144            ))
145        } else {
146            None
147        };
148
149        let reconnect_max_attempts = config.reconnect_max_attempts;
150        let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
151
152        Ok(Self {
153            config,
154            message_handler: None, // Stream mode has no handler
155            ping_handler: None,
156            writer_tx,
157            connection_mode,
158            reconnect_timeout,
159            heartbeat_task,
160            read_task,
161            write_task,
162            backoff,
163            is_stream_mode: true,
164            reconnect_max_attempts,
165            reconnection_attempt_count: 0,
166        })
167    }
168
169    /// Create an inner websocket client.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if:
174    /// - The connection to the server fails.
175    /// - The exponential backoff configuration is invalid.
176    pub async fn connect_url(
177        config: WebSocketConfig,
178        message_handler: Option<MessageHandler>,
179        ping_handler: Option<PingHandler>,
180    ) -> Result<Self, Error> {
181        install_cryptographic_provider();
182
183        // Capture whether we're in stream mode before moving config
184        let is_stream_mode = message_handler.is_none();
185        let reconnect_max_attempts = config.reconnect_max_attempts;
186
187        let (writer, reader) =
188            Self::connect_with_server(&config.url, config.headers.clone()).await?;
189
190        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
191
192        let read_task = if message_handler.is_some() {
193            Some(Self::spawn_message_handler_task(
194                connection_mode.clone(),
195                reader,
196                message_handler.as_ref(),
197                ping_handler.as_ref(),
198            ))
199        } else {
200            None
201        };
202
203        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
204        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
205
206        // Optionally spawn a heartbeat task to periodically ping server
207        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
208            Self::spawn_heartbeat_task(
209                connection_mode.clone(),
210                heartbeat_secs,
211                config.heartbeat_msg.clone(),
212                writer_tx.clone(),
213            )
214        });
215
216        let reconnect_timeout =
217            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
218        let backoff = ExponentialBackoff::new(
219            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
220            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
221            config.reconnect_backoff_factor.unwrap_or(1.5),
222            config.reconnect_jitter_ms.unwrap_or(100),
223            true, // immediate-first
224        )
225        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
226
227        Ok(Self {
228            config,
229            message_handler,
230            ping_handler,
231            read_task,
232            write_task,
233            writer_tx,
234            heartbeat_task,
235            connection_mode,
236            reconnect_timeout,
237            backoff,
238            // Set stream mode when no message handler (reader not managed by client)
239            is_stream_mode,
240            reconnect_max_attempts,
241            reconnection_attempt_count: 0,
242        })
243    }
244
245    /// Connects with the server creating a tokio-tungstenite websocket stream.
246    /// Production version that uses `connect_async_with_config` convenience helper.
247    ///
248    /// # Errors
249    ///
250    /// Returns an error if:
251    /// - The URL cannot be parsed into a valid client request.
252    /// - Header values are invalid.
253    /// - The WebSocket connection fails.
254    #[inline]
255    #[cfg(not(feature = "turmoil"))]
256    pub async fn connect_with_server(
257        url: &str,
258        headers: Vec<(String, String)>,
259    ) -> Result<(MessageWriter, MessageReader), Error> {
260        let mut request = url.into_client_request()?;
261        let req_headers = request.headers_mut();
262
263        let mut header_names: Vec<HeaderName> = Vec::new();
264        for (key, val) in headers {
265            let header_value = HeaderValue::from_str(&val)?;
266            let header_name: HeaderName = key.parse()?;
267            header_names.push(header_name.clone());
268            req_headers.insert(header_name, header_value);
269        }
270
271        connect_async_with_config(request, None, true)
272            .await
273            .map(|resp| resp.0.split())
274    }
275
276    /// Connects with the server creating a tokio-tungstenite websocket stream.
277    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if:
282    /// - The URL cannot be parsed into a valid client request.
283    /// - The URL is missing a hostname.
284    /// - Header values are invalid.
285    /// - The TCP connection fails.
286    /// - TLS setup fails (for wss:// URLs).
287    /// - The WebSocket handshake fails.
288    #[inline]
289    #[cfg(feature = "turmoil")]
290    pub async fn connect_with_server(
291        url: &str,
292        headers: Vec<(String, String)>,
293    ) -> Result<(MessageWriter, MessageReader), Error> {
294        use rustls::ClientConfig;
295        use tokio_rustls::TlsConnector;
296
297        let mut request = url.into_client_request()?;
298        let req_headers = request.headers_mut();
299
300        let mut header_names: Vec<HeaderName> = Vec::new();
301        for (key, val) in headers {
302            let header_value = HeaderValue::from_str(&val)?;
303            let header_name: HeaderName = key.parse()?;
304            header_names.push(header_name.clone());
305            req_headers.insert(header_name, header_value);
306        }
307
308        let uri = request.uri();
309        let scheme = uri.scheme_str().unwrap_or("ws");
310        let host = uri.host().ok_or_else(|| {
311            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
312        })?;
313
314        // Determine port: use explicit port if specified, otherwise default based on scheme
315        let port = uri
316            .port_u16()
317            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
318
319        let addr = format!("{host}:{port}");
320
321        // Use the connector to get a turmoil-compatible stream
322        let connector = crate::net::RealTcpConnector;
323        let tcp_stream = connector.connect(&addr).await?;
324        if let Err(e) = tcp_stream.set_nodelay(true) {
325            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
326        }
327
328        // Wrap stream appropriately based on scheme
329        let maybe_tls_stream = if scheme == "wss" {
330            // Build TLS config with webpki roots
331            let mut root_store = rustls::RootCertStore::empty();
332            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
333
334            let config = ClientConfig::builder()
335                .with_root_certificates(root_store)
336                .with_no_client_auth();
337
338            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
339            let domain =
340                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
341                    Error::Io(std::io::Error::new(
342                        std::io::ErrorKind::InvalidInput,
343                        format!("Invalid DNS name: {e}"),
344                    ))
345                })?;
346
347            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
348            MaybeTlsStream::Rustls(tls_stream)
349        } else {
350            MaybeTlsStream::Plain(tcp_stream)
351        };
352
353        // Use client_async with the stream (plain or TLS)
354        client_async(request, maybe_tls_stream)
355            .await
356            .map(|resp| resp.0.split())
357    }
358
359    /// Reconnect with server.
360    ///
361    /// Make a new connection with server. Use the new read and write halves
362    /// to update self writer and read and heartbeat tasks.
363    ///
364    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
365    /// because the reader is owned by the caller and cannot be replaced. Stream users
366    /// should handle disconnections by creating a new connection.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if:
371    /// - The reconnection attempt times out.
372    /// - The connection to the server fails.
373    pub async fn reconnect(&mut self) -> Result<(), Error> {
374        log::debug!("Reconnecting");
375
376        if self.is_stream_mode {
377            log::warn!(
378                "Auto-reconnect disabled for stream-based WebSocket client; \
379                stream users must manually reconnect by creating a new connection"
380            );
381            // Transition to CLOSED state to stop reconnection attempts
382            self.connection_mode
383                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
384            return Ok(());
385        }
386
387        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
388            log::debug!("Reconnect aborted due to disconnect state");
389            return Ok(());
390        }
391
392        tokio::time::timeout(self.reconnect_timeout, async {
393            // Attempt to connect; abort early if a disconnect was requested
394            let (new_writer, reader) =
395                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
396
397            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
398                log::debug!("Reconnect aborted mid-flight (after connect)");
399                return Ok(());
400            }
401
402            // Use a oneshot channel to synchronize with the writer task.
403            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
404            // to prevent silent message loss if the new connection drops immediately.
405            let (tx, rx) = tokio::sync::oneshot::channel();
406            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
407                log::error!("{e}");
408                return Err(Error::Io(std::io::Error::new(
409                    std::io::ErrorKind::BrokenPipe,
410                    format!("Failed to send update command: {e}"),
411                )));
412            }
413
414            // Wait for writer to confirm it has drained the buffer
415            match rx.await {
416                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
417                Ok(false) => {
418                    log::warn!("Writer failed to drain buffer, aborting reconnect");
419                    // Return error to trigger retry logic in controller
420                    return Err(Error::Io(std::io::Error::other(
421                        "Failed to drain reconnection buffer",
422                    )));
423                }
424                Err(e) => {
425                    log::error!("Writer dropped update channel: {e}");
426                    return Err(Error::Io(std::io::Error::new(
427                        std::io::ErrorKind::BrokenPipe,
428                        "Writer task dropped response channel",
429                    )));
430                }
431            }
432
433            // Delay before closing connection
434            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
435
436            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
437                log::debug!("Reconnect aborted mid-flight (after delay)");
438                return Ok(());
439            }
440
441            if let Some(ref read_task) = self.read_task.take()
442                && !read_task.is_finished()
443            {
444                read_task.abort();
445                log_task_aborted("read");
446            }
447
448            // Atomically transition from Reconnect to Active
449            // This prevents race condition where disconnect could be requested between check and store
450            if self
451                .connection_mode
452                .compare_exchange(
453                    ConnectionMode::Reconnect.as_u8(),
454                    ConnectionMode::Active.as_u8(),
455                    Ordering::SeqCst,
456                    Ordering::SeqCst,
457                )
458                .is_err()
459            {
460                log::debug!("Reconnect aborted (state changed during reconnect)");
461                return Ok(());
462            }
463
464            self.read_task = if self.message_handler.is_some() {
465                Some(Self::spawn_message_handler_task(
466                    self.connection_mode.clone(),
467                    reader,
468                    self.message_handler.as_ref(),
469                    self.ping_handler.as_ref(),
470                ))
471            } else {
472                None
473            };
474
475            log::debug!("Reconnect succeeded");
476            Ok(())
477        })
478        .await
479        .map_err(|_| {
480            Error::Io(std::io::Error::new(
481                std::io::ErrorKind::TimedOut,
482                format!(
483                    "reconnection timed out after {}s",
484                    self.reconnect_timeout.as_secs_f64()
485                ),
486            ))
487        })?
488    }
489
490    /// Check if the client is still connected.
491    ///
492    /// The client is connected if the read task has not finished. It is expected
493    /// that in case of any failure client or server side. The read task will be
494    /// shutdown or will receive a `Close` frame which will finish it. There
495    /// might be some delay between the connection being closed and the client
496    /// detecting.
497    #[inline]
498    #[must_use]
499    pub fn is_alive(&self) -> bool {
500        match &self.read_task {
501            Some(read_task) => !read_task.is_finished(),
502            None => true, // Stream is being used directly
503        }
504    }
505
506    fn spawn_message_handler_task(
507        connection_state: Arc<AtomicU8>,
508        mut reader: MessageReader,
509        message_handler: Option<&MessageHandler>,
510        ping_handler: Option<&PingHandler>,
511    ) -> tokio::task::JoinHandle<()> {
512        log::debug!("Started message handler task 'read'");
513
514        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
515
516        // Clone Arc handlers for the async task
517        let message_handler = message_handler.cloned();
518        let ping_handler = ping_handler.cloned();
519
520        tokio::task::spawn(async move {
521            loop {
522                if !ConnectionMode::from_atomic(&connection_state).is_active() {
523                    break;
524                }
525
526                match tokio::time::timeout(check_interval, reader.next()).await {
527                    Ok(Some(Ok(Message::Binary(data)))) => {
528                        log::trace!("Received message <binary> {} bytes", data.len());
529                        if let Some(ref handler) = message_handler {
530                            handler(Message::Binary(data));
531                        }
532                    }
533                    Ok(Some(Ok(Message::Text(data)))) => {
534                        log::trace!("Received message: {data}");
535                        if let Some(ref handler) = message_handler {
536                            handler(Message::Text(data));
537                        }
538                    }
539                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
540                        log::trace!("Received ping: {ping_data:?}");
541                        if let Some(ref handler) = ping_handler {
542                            handler(ping_data.to_vec());
543                        }
544                    }
545                    Ok(Some(Ok(Message::Pong(_)))) => {
546                        log::trace!("Received pong");
547                    }
548                    Ok(Some(Ok(Message::Close(_)))) => {
549                        log::debug!("Received close message - terminating");
550                        break;
551                    }
552                    Ok(Some(Ok(_))) => (),
553                    Ok(Some(Err(e))) => {
554                        log::error!("Received error message - terminating: {e}");
555                        break;
556                    }
557                    Ok(None) => {
558                        log::debug!("No message received - terminating");
559                        break;
560                    }
561                    Err(_) => {
562                        // Timeout - continue loop and check connection mode
563                        continue;
564                    }
565                }
566            }
567        })
568    }
569
570    /// Attempts to send all buffered messages after reconnection.
571    ///
572    /// Returns `true` if a send error occurred (caller should trigger reconnection).
573    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
574    async fn drain_reconnect_buffer(
575        buffer: &mut VecDeque<Message>,
576        writer: &mut MessageWriter,
577    ) -> bool {
578        if buffer.is_empty() {
579            return false;
580        }
581
582        let initial_buffer_len = buffer.len();
583        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
584
585        let mut send_error_occurred = false;
586
587        while let Some(buffered_msg) = buffer.front() {
588            // Clone message before attempting send (to keep in buffer if send fails)
589            let msg_to_send = buffered_msg.clone();
590
591            if let Err(e) = writer.send(msg_to_send).await {
592                log::error!(
593                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
594                    buffer.len()
595                );
596                send_error_occurred = true;
597                break; // Stop processing buffer, remaining messages preserved for next reconnection
598            }
599
600            // Only remove from buffer after successful send
601            buffer.pop_front();
602        }
603
604        if buffer.is_empty() {
605            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
606        }
607
608        send_error_occurred
609    }
610
611    fn spawn_write_task(
612        connection_state: Arc<AtomicU8>,
613        writer: MessageWriter,
614        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
615    ) -> tokio::task::JoinHandle<()> {
616        log_task_started("write");
617
618        // Interval between checking the connection mode
619        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
620
621        tokio::task::spawn(async move {
622            let mut active_writer = writer;
623            // Buffer for messages received during reconnection
624            // VecDeque for efficient pop_front() operations
625            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
626
627            loop {
628                match ConnectionMode::from_atomic(&connection_state) {
629                    ConnectionMode::Disconnect => {
630                        // Log any buffered messages that will be lost
631                        if !reconnect_buffer.is_empty() {
632                            log::warn!(
633                                "Discarding {} buffered messages due to disconnect",
634                                reconnect_buffer.len()
635                            );
636                            reconnect_buffer.clear();
637                        }
638
639                        // Attempt to close the writer gracefully before exiting,
640                        // we ignore any error as the writer may already be closed.
641                        _ = tokio::time::timeout(
642                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
643                            active_writer.close(),
644                        )
645                        .await;
646                        break;
647                    }
648                    ConnectionMode::Closed => {
649                        // Log any buffered messages that will be lost
650                        if !reconnect_buffer.is_empty() {
651                            log::warn!(
652                                "Discarding {} buffered messages due to closed connection",
653                                reconnect_buffer.len()
654                            );
655                            reconnect_buffer.clear();
656                        }
657                        break;
658                    }
659                    _ => {}
660                }
661
662                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
663                    Ok(Some(msg)) => {
664                        // Re-check connection mode after receiving a message
665                        let mode = ConnectionMode::from_atomic(&connection_state);
666                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
667                            break;
668                        }
669
670                        match msg {
671                            WriterCommand::Update(new_writer, tx) => {
672                                log::debug!("Received new writer");
673
674                                // Delay before closing connection
675                                tokio::time::sleep(Duration::from_millis(100)).await;
676
677                                // Attempt to close the writer gracefully on update,
678                                // we ignore any error as the writer may already be closed.
679                                _ = tokio::time::timeout(
680                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
681                                    active_writer.close(),
682                                )
683                                .await;
684
685                                active_writer = new_writer;
686                                log::debug!("Updated writer");
687
688                                let send_error = Self::drain_reconnect_buffer(
689                                    &mut reconnect_buffer,
690                                    &mut active_writer,
691                                )
692                                .await;
693
694                                if let Err(e) = tx.send(!send_error) {
695                                    log::error!(
696                                        "Failed to report drain status to controller: {e:?}"
697                                    );
698                                }
699                            }
700                            WriterCommand::Send(msg) if mode.is_reconnect() => {
701                                // Buffer messages during reconnection instead of dropping them
702                                log::debug!(
703                                    "Buffering message during reconnection (buffer size: {})",
704                                    reconnect_buffer.len() + 1
705                                );
706                                reconnect_buffer.push_back(msg);
707                            }
708                            WriterCommand::Send(msg) => {
709                                if let Err(e) = active_writer.send(msg.clone()).await {
710                                    log::error!("Failed to send message: {e}");
711                                    log::warn!("Writer triggering reconnect");
712                                    reconnect_buffer.push_back(msg);
713                                    connection_state
714                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
715                                }
716                            }
717                        }
718                    }
719                    Ok(None) => {
720                        // Channel closed - writer task should terminate
721                        log::debug!("Writer channel closed, terminating writer task");
722                        break;
723                    }
724                    Err(_) => {
725                        // Timeout - just continue the loop
726                        continue;
727                    }
728                }
729            }
730
731            // Attempt to close the writer gracefully before exiting,
732            // we ignore any error as the writer may already be closed.
733            _ = tokio::time::timeout(
734                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
735                active_writer.close(),
736            )
737            .await;
738
739            log_task_stopped("write");
740        })
741    }
742
743    fn spawn_heartbeat_task(
744        connection_state: Arc<AtomicU8>,
745        heartbeat_secs: u64,
746        message: Option<String>,
747        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
748    ) -> tokio::task::JoinHandle<()> {
749        log_task_started("heartbeat");
750
751        tokio::task::spawn(async move {
752            let interval = Duration::from_secs(heartbeat_secs);
753
754            loop {
755                tokio::time::sleep(interval).await;
756
757                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
758                    ConnectionMode::Active => {
759                        let msg = match &message {
760                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
761                            None => WriterCommand::Send(Message::Ping(vec![].into())),
762                        };
763
764                        match writer_tx.send(msg) {
765                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
766                            Err(e) => {
767                                log::error!("Failed to send heartbeat to writer task: {e}");
768                            }
769                        }
770                    }
771                    ConnectionMode::Reconnect => continue,
772                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
773                }
774            }
775
776            log_task_stopped("heartbeat");
777        })
778    }
779}
780
781impl Drop for WebSocketClientInner {
782    fn drop(&mut self) {
783        // Delegate to explicit cleanup handler
784        self.clean_drop();
785    }
786}
787
788/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
789impl CleanDrop for WebSocketClientInner {
790    fn clean_drop(&mut self) {
791        if let Some(ref read_task) = self.read_task.take()
792            && !read_task.is_finished()
793        {
794            read_task.abort();
795            log_task_aborted("read");
796        }
797
798        if !self.write_task.is_finished() {
799            self.write_task.abort();
800            log_task_aborted("write");
801        }
802
803        if let Some(ref handle) = self.heartbeat_task.take()
804            && !handle.is_finished()
805        {
806            handle.abort();
807            log_task_aborted("heartbeat");
808        }
809
810        // Clear handlers to break potential reference cycles
811        self.message_handler = None;
812        self.ping_handler = None;
813    }
814}
815
816impl Debug for WebSocketClientInner {
817    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
818        f.debug_struct(stringify!(WebSocketClientInner))
819            .field("config", &self.config)
820            .field(
821                "connection_mode",
822                &ConnectionMode::from_atomic(&self.connection_mode),
823            )
824            .field("reconnect_timeout", &self.reconnect_timeout)
825            .field("is_stream_mode", &self.is_stream_mode)
826            .finish()
827    }
828}
829
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle to the controller task spawned when the client is created.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller task, stored as the mode's `u8` encoding.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Longest time a send operation waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by the text and bytes send methods before queuing a message.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Sender half of the channel used to hand commands to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
845
846impl Debug for WebSocketClient {
847    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848        f.debug_struct(stringify!(WebSocketClient)).finish()
849    }
850}
851
852impl WebSocketClient {
853    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
854    ///
855    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
856    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
857    /// client transitions to CLOSED state and the caller must manually reconnect by calling
858    /// `connect_stream` again.
859    ///
860    /// Use stream mode when you need custom reconnection logic, direct control over message
861    /// reading, or fine-grained backpressure handling.
862    ///
863    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
864    ///
865    /// # Errors
866    ///
867    /// Returns an error if the connection cannot be established.
868    #[allow(clippy::too_many_arguments)]
869    pub async fn connect_stream(
870        config: WebSocketConfig,
871        keyed_quotas: Vec<(String, Quota)>,
872        default_quota: Option<Quota>,
873        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
874    ) -> Result<(MessageReader, Self), Error> {
875        install_cryptographic_provider();
876
877        // Create a single connection and split it, respecting configured headers
878        let (writer, reader) =
879            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
880
881        // Create inner without connecting (we'll provide the writer)
882        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
883
884        let connection_mode = inner.connection_mode.clone();
885        let reconnect_timeout = inner.reconnect_timeout;
886        let keyed_quotas = keyed_quotas
887            .into_iter()
888            .map(|(key, quota)| (Ustr::from(&key), quota))
889            .collect();
890        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
891        let writer_tx = inner.writer_tx.clone();
892
893        let controller_task =
894            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
895
896        Ok((
897            reader,
898            Self {
899                controller_task,
900                connection_mode,
901                reconnect_timeout,
902                rate_limiter,
903                writer_tx,
904            },
905        ))
906    }
907
908    /// Creates a websocket client in **handler mode** with automatic reconnection.
909    ///
910    /// The handler is called for each incoming message on an internal task.
911    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
912    /// the client automatically attempts to reconnect and replaces the internal reader
913    /// (the handler continues working seamlessly).
914    ///
915    /// Use handler mode for simplified connection management, automatic reconnection, Python
916    /// bindings, or callback-based message handling.
917    ///
918    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
919    ///
920    /// # Errors
921    ///
922    /// Returns an error if:
923    /// - The connection cannot be established.
924    /// - `message_handler` is `None` (use `connect_stream` instead).
925    pub async fn connect(
926        config: WebSocketConfig,
927        message_handler: Option<MessageHandler>,
928        ping_handler: Option<PingHandler>,
929        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
930        keyed_quotas: Vec<(String, Quota)>,
931        default_quota: Option<Quota>,
932    ) -> Result<Self, Error> {
933        // Validate that handler mode has a message handler
934        if message_handler.is_none() {
935            return Err(Error::Io(std::io::Error::new(
936                std::io::ErrorKind::InvalidInput,
937                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
938            )));
939        }
940
941        log::debug!("Connecting");
942        let inner =
943            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
944        let connection_mode = inner.connection_mode.clone();
945        let writer_tx = inner.writer_tx.clone();
946        let reconnect_timeout = inner.reconnect_timeout;
947
948        let controller_task =
949            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
950
951        let keyed_quotas = keyed_quotas
952            .into_iter()
953            .map(|(key, quota)| (Ustr::from(&key), quota))
954            .collect();
955        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
956
957        Ok(Self {
958            controller_task,
959            connection_mode,
960            reconnect_timeout,
961            rate_limiter,
962            writer_tx,
963        })
964    }
965
    /// Returns the current connection mode.
    ///
    /// The value is a snapshot decoded from the shared atomic at the time of the call, so
    /// it may already be outdated when the caller inspects it.
    #[must_use]
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
971
972    /// Returns a clone of the connection mode atomic for external state tracking.
973    ///
974    /// This allows adapter clients to track connection state across reconnections
975    /// without message-passing delays.
976    #[must_use]
977    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
978        Arc::clone(&self.connection_mode)
979    }
980
981    /// Check if the client connection is active.
982    ///
983    /// Returns `true` if the client is connected and has not been signalled to disconnect.
984    /// The client will automatically retry connection based on its configuration.
985    #[inline]
986    #[must_use]
987    pub fn is_active(&self) -> bool {
988        self.connection_mode().is_active()
989    }
990
    /// Check if the client is disconnected.
    ///
    /// Returns `true` once the controller task has finished. This looks at the task handle
    /// only and does not read the connection mode.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.controller_task.is_finished()
    }
996
997    /// Check if the client is reconnecting.
998    ///
999    /// Returns `true` if the client lost connection and is attempting to reestablish it.
1000    /// The client will automatically retry connection based on its configuration.
1001    #[inline]
1002    #[must_use]
1003    pub fn is_reconnecting(&self) -> bool {
1004        self.connection_mode().is_reconnect()
1005    }
1006
1007    /// Check if the client is disconnecting.
1008    ///
1009    /// Returns `true` if the client is in disconnect mode.
1010    #[inline]
1011    #[must_use]
1012    pub fn is_disconnecting(&self) -> bool {
1013        self.connection_mode().is_disconnect()
1014    }
1015
1016    /// Check if the client is closed.
1017    ///
1018    /// Returns `true` if the client has been explicitly disconnected or reached
1019    /// maximum reconnection attempts. In this state, the client cannot be reused
1020    /// and a new client must be created for further connections.
1021    #[inline]
1022    #[must_use]
1023    pub fn is_closed(&self) -> bool {
1024        self.connection_mode().is_closed()
1025    }
1026
1027    /// Wait for the client to become active before sending.
1028    ///
1029    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1030    async fn wait_for_active(&self) -> Result<(), SendError> {
1031        if self.is_closed() {
1032            return Err(SendError::Closed);
1033        }
1034
1035        let timeout = self.reconnect_timeout;
1036        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1037
1038        if !self.is_active() {
1039            log::debug!("Waiting for client to become ACTIVE before sending...");
1040
1041            let inner = tokio::time::timeout(timeout, async {
1042                loop {
1043                    if self.is_active() {
1044                        return Ok(());
1045                    }
1046                    if matches!(
1047                        self.connection_mode(),
1048                        ConnectionMode::Disconnect | ConnectionMode::Closed
1049                    ) {
1050                        return Err(());
1051                    }
1052                    tokio::time::sleep(check_interval).await;
1053                }
1054            })
1055            .await
1056            .map_err(|_| SendError::Timeout)?;
1057            inner.map_err(|()| SendError::Closed)?;
1058        }
1059
1060        Ok(())
1061    }
1062
1063    /// Set disconnect mode to true.
1064    ///
1065    /// Controller task will periodically check the disconnect mode
1066    /// and shutdown the client if it is alive
1067    pub async fn disconnect(&self) {
1068        log::debug!("Disconnecting");
1069        self.connection_mode
1070            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1071
1072        if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1073            while !self.is_disconnected() {
1074                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1075            }
1076
1077            if !self.controller_task.is_finished() {
1078                self.controller_task.abort();
1079                log_task_aborted("controller");
1080            }
1081        })
1082        .await
1083            == Ok(())
1084        {
1085            log::debug!("Controller task finished");
1086        } else {
1087            log::error!("Timeout waiting for controller task to finish");
1088            if !self.controller_task.is_finished() {
1089                self.controller_task.abort();
1090                log_task_aborted("controller");
1091            }
1092        }
1093    }
1094
1095    /// Sends the given text `data` to the server.
1096    ///
1097    /// # Errors
1098    ///
1099    /// Returns a websocket error if unable to send.
1100    #[allow(unused_variables)]
1101    pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1102        // Check connection state before rate limiting to fail fast
1103        if self.is_closed() || self.is_disconnecting() {
1104            return Err(SendError::Closed);
1105        }
1106
1107        self.rate_limiter.await_keys_ready(keys).await;
1108        self.wait_for_active().await?;
1109
1110        log::trace!("Sending text: {data:?}");
1111
1112        let msg = Message::Text(data.into());
1113        self.writer_tx
1114            .send(WriterCommand::Send(msg))
1115            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1116    }
1117
1118    /// Sends a pong frame back to the server.
1119    ///
1120    /// # Errors
1121    ///
1122    /// Returns a websocket error if unable to send.
1123    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1124        self.wait_for_active().await?;
1125
1126        log::trace!("Sending pong frame ({} bytes)", data.len());
1127
1128        let msg = Message::Pong(data.into());
1129        self.writer_tx
1130            .send(WriterCommand::Send(msg))
1131            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1132    }
1133
1134    /// Sends the given bytes `data` to the server.
1135    ///
1136    /// # Errors
1137    ///
1138    /// Returns a websocket error if unable to send.
1139    #[allow(unused_variables)]
1140    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1141        // Check connection state before rate limiting to fail fast
1142        if self.is_closed() || self.is_disconnecting() {
1143            return Err(SendError::Closed);
1144        }
1145
1146        self.rate_limiter.await_keys_ready(keys).await;
1147        self.wait_for_active().await?;
1148
1149        log::trace!("Sending bytes: {data:?}");
1150
1151        let msg = Message::Binary(data.into());
1152        self.writer_tx
1153            .send(WriterCommand::Send(msg))
1154            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1155    }
1156
1157    /// Sends a close message to the server.
1158    ///
1159    /// # Errors
1160    ///
1161    /// Returns a websocket error if unable to send.
1162    pub async fn send_close_message(&self) -> Result<(), SendError> {
1163        self.wait_for_active().await?;
1164
1165        let msg = Message::Close(None);
1166        self.writer_tx
1167            .send(WriterCommand::Send(msg))
1168            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1169    }
1170
    /// Spawns the controller task that owns `inner` and drives the connection state machine.
    ///
    /// The task wakes every `CONNECTION_STATE_CHECK_INTERVAL_MS` and reads `connection_mode`:
    /// - Disconnect: aborts the read and heartbeat tasks after a short delay, then exits.
    /// - Closed: exits.
    /// - Active while `inner.is_alive()` is false: switches the mode to reconnect.
    /// - Reconnect: calls `inner.reconnect()`, bounded by `inner.reconnect_max_attempts`
    ///   and paced by `inner.backoff` after failures.
    ///
    /// After a successful reconnect, and only if the mode is active, the message handler
    /// receives a `RECONNECTED` text message and `post_reconnection` is invoked.
    /// The mode is set to closed when the task leaves its loop.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    // The whole shutdown sequence below is bounded by this timeout
                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        // Abort whichever helper tasks are still running after the delay
                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");
                    break; // Controller finished
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Only move Active -> Reconnect; a mode changed elsewhere in the
                    // meantime (e.g. a disconnect request) is left untouched.
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        log::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read so the reconnect branch below sees the mode that is now stored
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    // Check if max reconnection attempts exceeded
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnection_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    // The attempt is counted before it is made
                    inner.reconnection_attempt_count += 1;
                    log::debug!(
                        "Reconnection attempt {} of {}",
                        inner.reconnection_attempt_count,
                        inner
                            .reconnect_max_attempts
                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                    );

                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();
                            inner.reconnection_attempt_count = 0; // Reset counter on success

                            // Only invoke callbacks if not in disconnect state
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    log::debug!("Sent reconnected message to handler");
                                }

                                // TODO: Retain this legacy callback for use from Python
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    log::debug!("Called `post_reconnection` handler");
                                }

                                log::debug!("Reconnected successfully");
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Wait out the backoff delay before the next loop iteration
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnection_attempt_count
                            );
                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Every exit path from the loop ends with the mode recorded as closed
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1306}
1307
1308// Abort controller task on drop to clean up background tasks
1309impl Drop for WebSocketClient {
1310    fn drop(&mut self) {
1311        if !self.controller_task.is_finished() {
1312            self.controller_task.abort();
1313            log_task_aborted("controller");
1314        }
1315    }
1316}
1317
1318#[cfg(test)]
1319#[cfg(not(feature = "turmoil"))]
1320#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1321mod tests {
1322    use std::{num::NonZeroU32, sync::Arc};
1323
1324    use futures_util::{SinkExt, StreamExt};
1325    use tokio::{
1326        net::TcpListener,
1327        task::{self, JoinHandle},
1328    };
1329    use tokio_tungstenite::{
1330        accept_hdr_async,
1331        tungstenite::{
1332            handshake::server::{self, Callback},
1333            http::HeaderValue,
1334        },
1335    };
1336
1337    use crate::{
1338        ratelimiter::quota::Quota,
1339        websocket::{WebSocketClient, WebSocketConfig},
1340    };
1341
    /// Local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Handle to the accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to on `127.0.0.1`.
        port: u16,
    }
1346
    /// Handshake callback asserting that the client sent one specific header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the handshake request.
        key: String,
        /// Value the header is expected to carry.
        value: HeaderValue,
    }
1352
1353    impl Callback for TestCallback {
1354        #[allow(clippy::panic_in_result_fn)]
1355        fn on_request(
1356            self,
1357            request: &server::Request,
1358            response: server::Response,
1359        ) -> Result<server::Response, server::ErrorResponse> {
1360            let _ = response;
1361            let value = request.headers().get(&self.key);
1362            assert!(value.is_some());
1363
1364            if let Some(value) = request.headers().get(&self.key) {
1365                assert_eq!(value, self.value);
1366            }
1367
1368            Ok(response)
1369        }
1370    }
1371
1372    impl TestServer {
1373        async fn setup() -> Self {
1374            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1375            let port = TcpListener::local_addr(&server).unwrap().port();
1376
1377            let header_key = "test".to_string();
1378            let header_value = "test".to_string();
1379
1380            let test_call_back = TestCallback {
1381                key: header_key,
1382                value: HeaderValue::from_str(&header_value).unwrap(),
1383            };
1384
1385            let task = task::spawn(async move {
1386                // Keep accepting connections
1387                loop {
1388                    let (conn, _) = server.accept().await.unwrap();
1389                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1390                        .await
1391                        .unwrap();
1392
1393                    task::spawn(async move {
1394                        while let Some(Ok(msg)) = websocket.next().await {
1395                            match msg {
1396                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1397                                    if txt == "close-now" =>
1398                                {
1399                                    log::debug!("Forcibly closing from server side");
1400                                    // This sends a close frame, then stops reading
1401                                    let _ = websocket.close(None).await;
1402                                    break;
1403                                }
1404                                // Echo text/binary frames
1405                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1406                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1407                                    if websocket.send(msg).await.is_err() {
1408                                        break;
1409                                    }
1410                                }
1411                                // If the client closes, we also break
1412                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1413                                    _frame,
1414                                ) => {
1415                                    let _ = websocket.close(None).await;
1416                                    break;
1417                                }
1418                                // Ignore pings/pongs
1419                                _ => {}
1420                            }
1421                        }
1422                    });
1423                }
1424            });
1425
1426            Self { task, port }
1427        }
1428    }
1429
    impl Drop for TestServer {
        /// Stops the accept loop when the server goes out of scope.
        fn drop(&mut self) {
            self.task.abort();
        }
    }
1435
1436    async fn setup_test_client(port: u16) -> WebSocketClient {
1437        let config = WebSocketConfig {
1438            url: format!("ws://127.0.0.1:{port}"),
1439            headers: vec![("test".into(), "test".into())],
1440            heartbeat: None,
1441            heartbeat_msg: None,
1442            reconnect_timeout_ms: None,
1443            reconnect_delay_initial_ms: None,
1444            reconnect_backoff_factor: None,
1445            reconnect_delay_max_ms: None,
1446            reconnect_jitter_ms: None,
1447            reconnect_max_attempts: None,
1448        };
1449        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1450            .await
1451            .expect("Failed to connect")
1452    }
1453
1454    #[tokio::test]
1455    async fn test_websocket_basic() {
1456        let server = TestServer::setup().await;
1457        let client = setup_test_client(server.port).await;
1458
1459        assert!(!client.is_disconnected());
1460
1461        client.disconnect().await;
1462        assert!(client.is_disconnected());
1463    }
1464
1465    #[tokio::test]
1466    async fn test_websocket_heartbeat() {
1467        let server = TestServer::setup().await;
1468        let client = setup_test_client(server.port).await;
1469
1470        // Wait ~3s => server should see multiple "ping"
1471        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1472
1473        // Cleanup
1474        client.disconnect().await;
1475        assert!(client.is_disconnected());
1476    }
1477
1478    #[tokio::test]
1479    async fn test_websocket_reconnect_exhausted() {
1480        let config = WebSocketConfig {
1481            url: "ws://127.0.0.1:9997".into(), // <-- No server
1482            headers: vec![],
1483            heartbeat: None,
1484            heartbeat_msg: None,
1485            reconnect_timeout_ms: None,
1486            reconnect_delay_initial_ms: None,
1487            reconnect_backoff_factor: None,
1488            reconnect_delay_max_ms: None,
1489            reconnect_jitter_ms: None,
1490            reconnect_max_attempts: None,
1491        };
1492        let res =
1493            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1494                .await;
1495        assert!(res.is_err(), "Should fail quickly with no server");
1496    }
1497
1498    #[tokio::test]
1499    async fn test_websocket_forced_close_reconnect() {
1500        let server = TestServer::setup().await;
1501        let client = setup_test_client(server.port).await;
1502
1503        // 1) Send normal message
1504        client.send_text("Hello".into(), None).await.unwrap();
1505
1506        // 2) Trigger forced close from server
1507        client.send_text("close-now".into(), None).await.unwrap();
1508
1509        // 3) Wait a bit => read loop sees close => reconnect
1510        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1511
1512        // Confirm not disconnected
1513        assert!(!client.is_disconnected());
1514
1515        // Cleanup
1516        client.disconnect().await;
1517        assert!(client.is_disconnected());
1518    }
1519
1520    #[tokio::test]
1521    async fn test_rate_limiter() {
1522        let server = TestServer::setup().await;
1523        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1524
1525        let config = WebSocketConfig {
1526            url: format!("ws://127.0.0.1:{}", server.port),
1527            headers: vec![("test".into(), "test".into())],
1528            heartbeat: None,
1529            heartbeat_msg: None,
1530            reconnect_timeout_ms: None,
1531            reconnect_delay_initial_ms: None,
1532            reconnect_backoff_factor: None,
1533            reconnect_delay_max_ms: None,
1534            reconnect_jitter_ms: None,
1535            reconnect_max_attempts: None,
1536        };
1537
1538        let client = WebSocketClient::connect(
1539            config,
1540            Some(Arc::new(|_| {})),
1541            None,
1542            None,
1543            vec![("default".into(), quota)],
1544            None,
1545        )
1546        .await
1547        .unwrap();
1548
1549        // First 2 should succeed
1550        client.send_text("test1".into(), None).await.unwrap();
1551        client.send_text("test2".into(), None).await.unwrap();
1552
1553        // Third should error
1554        client.send_text("test3".into(), None).await.unwrap();
1555
1556        // Cleanup
1557        client.disconnect().await;
1558        assert!(client.is_disconnected());
1559    }
1560
1561    #[tokio::test]
1562    async fn test_concurrent_writers() {
1563        let server = TestServer::setup().await;
1564        let client = Arc::new(setup_test_client(server.port).await);
1565
1566        let mut handles = vec![];
1567        for i in 0..10 {
1568            let client = client.clone();
1569            handles.push(task::spawn(async move {
1570                client.send_text(format!("test{i}"), None).await.unwrap();
1571            }));
1572        }
1573
1574        for handle in handles {
1575            handle.await.unwrap();
1576        }
1577
1578        // Cleanup
1579        client.disconnect().await;
1580        assert!(client.is_disconnected());
1581    }
1582}
1583
1584#[cfg(test)]
1585#[cfg(not(feature = "turmoil"))]
1586mod rust_tests {
1587    use futures_util::StreamExt;
1588    use rstest::rstest;
1589    use tokio::{
1590        net::TcpListener,
1591        task,
1592        time::{Duration, sleep},
1593    };
1594    use tokio_tungstenite::accept_async;
1595
1596    use super::*;
1597    use crate::websocket::types::channel_message_handler;
1598
1599    #[rstest]
1600    #[tokio::test]
1601    async fn test_reconnect_then_disconnect() {
1602        // Bind an ephemeral port
1603        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1604        let port = listener.local_addr().unwrap().port();
1605
1606        // Server task: accept one ws connection then close it
1607        let server = task::spawn(async move {
1608            let (stream, _) = listener.accept().await.unwrap();
1609            let ws = accept_async(stream).await.unwrap();
1610            drop(ws);
1611            // Keep alive briefly
1612            sleep(Duration::from_secs(1)).await;
1613        });
1614
1615        // Build a channel-based message handler for incoming messages (unused here)
1616        let (handler, _rx) = channel_message_handler();
1617
1618        // Configure client with short reconnect backoff
1619        let config = WebSocketConfig {
1620            url: format!("ws://127.0.0.1:{port}"),
1621            headers: vec![],
1622            heartbeat: None,
1623            heartbeat_msg: None,
1624            reconnect_timeout_ms: Some(1_000),
1625            reconnect_delay_initial_ms: Some(50),
1626            reconnect_delay_max_ms: Some(100),
1627            reconnect_backoff_factor: Some(1.0),
1628            reconnect_jitter_ms: Some(0),
1629            reconnect_max_attempts: None,
1630        };
1631
1632        // Connect the client
1633        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1634            .await
1635            .unwrap();
1636
1637        // Allow server to drop connection and client to detect
1638        sleep(Duration::from_millis(100)).await;
1639        // Now immediately disconnect the client
1640        client.disconnect().await;
1641        assert!(client.is_disconnected());
1642        server.abort();
1643    }
1644
1645    #[rstest]
1646    #[tokio::test]
1647    async fn test_reconnect_state_flips_when_reader_stops() {
1648        // Bind an ephemeral port and accept a single websocket connection which we drop.
1649        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1650        let port = listener.local_addr().unwrap().port();
1651
1652        let server = task::spawn(async move {
1653            if let Ok((stream, _)) = listener.accept().await
1654                && let Ok(ws) = accept_async(stream).await
1655            {
1656                drop(ws);
1657            }
1658            sleep(Duration::from_millis(50)).await;
1659        });
1660
1661        let (handler, _rx) = channel_message_handler();
1662
1663        let config = WebSocketConfig {
1664            url: format!("ws://127.0.0.1:{port}"),
1665            headers: vec![],
1666            heartbeat: None,
1667            heartbeat_msg: None,
1668            reconnect_timeout_ms: Some(1_000),
1669            reconnect_delay_initial_ms: Some(50),
1670            reconnect_delay_max_ms: Some(100),
1671            reconnect_backoff_factor: Some(1.0),
1672            reconnect_jitter_ms: Some(0),
1673            reconnect_max_attempts: None,
1674        };
1675
1676        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1677            .await
1678            .unwrap();
1679
1680        tokio::time::timeout(Duration::from_secs(2), async {
1681            loop {
1682                if client.is_reconnecting() {
1683                    break;
1684                }
1685                tokio::time::sleep(Duration::from_millis(10)).await;
1686            }
1687        })
1688        .await
1689        .expect("client did not enter RECONNECT state");
1690
1691        client.disconnect().await;
1692        server.abort();
1693    }
1694
1695    #[rstest]
1696    #[tokio::test]
1697    async fn test_stream_mode_disables_auto_reconnect() {
1698        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1699        // and that reconnect() transitions to CLOSED state for stream mode
1700        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1701        let port = listener.local_addr().unwrap().port();
1702
1703        let server = task::spawn(async move {
1704            if let Ok((stream, _)) = listener.accept().await
1705                && let Ok(_ws) = accept_async(stream).await
1706            {
1707                // Keep connection alive briefly
1708                sleep(Duration::from_millis(100)).await;
1709            }
1710        });
1711
1712        let config = WebSocketConfig {
1713            url: format!("ws://127.0.0.1:{port}"),
1714            headers: vec![],
1715            heartbeat: None,
1716            heartbeat_msg: None,
1717            reconnect_timeout_ms: Some(1_000),
1718            reconnect_delay_initial_ms: Some(50),
1719            reconnect_delay_max_ms: Some(100),
1720            reconnect_backoff_factor: Some(1.0),
1721            reconnect_jitter_ms: Some(0),
1722            reconnect_max_attempts: None,
1723        };
1724
1725        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1726            .await
1727            .unwrap();
1728
1729        // Note: We can't easily test the reconnect behavior from the outside since
1730        // the inner client is private. The key fix is that WebSocketClientInner
1731        // now has is_stream_mode=true for connect_stream, and reconnect() will
1732        // transition to CLOSED state instead of creating a new reader that gets dropped.
1733        // This is tested implicitly by the fact that stream users won't get stuck
1734        // in an infinite reconnect loop.
1735
1736        server.abort();
1737    }
1738
1739    #[rstest]
1740    #[tokio::test]
1741    async fn test_message_handler_mode_allows_auto_reconnect() {
1742        // Test that regular clients (with message handler) can auto-reconnect
1743        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1744        let port = listener.local_addr().unwrap().port();
1745
1746        let server = task::spawn(async move {
1747            // Accept first connection and close it
1748            if let Ok((stream, _)) = listener.accept().await
1749                && let Ok(ws) = accept_async(stream).await
1750            {
1751                drop(ws);
1752            }
1753            sleep(Duration::from_millis(50)).await;
1754        });
1755
1756        let (handler, _rx) = channel_message_handler();
1757
1758        let config = WebSocketConfig {
1759            url: format!("ws://127.0.0.1:{port}"),
1760            headers: vec![],
1761            heartbeat: None,
1762            heartbeat_msg: None,
1763            reconnect_timeout_ms: Some(1_000),
1764            reconnect_delay_initial_ms: Some(50),
1765            reconnect_delay_max_ms: Some(100),
1766            reconnect_backoff_factor: Some(1.0),
1767            reconnect_jitter_ms: Some(0),
1768            reconnect_max_attempts: None,
1769        };
1770
1771        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1772            .await
1773            .unwrap();
1774
1775        // Wait for the connection to be dropped and reconnection to be attempted
1776        tokio::time::timeout(Duration::from_secs(2), async {
1777            loop {
1778                if client.is_reconnecting() || client.is_closed() {
1779                    break;
1780                }
1781                tokio::time::sleep(Duration::from_millis(10)).await;
1782            }
1783        })
1784        .await
1785        .expect("client should attempt reconnection or close");
1786
1787        // Should either be reconnecting or closed (depending on timing)
1788        // The important thing is it's not staying active forever
1789        assert!(
1790            client.is_reconnecting() || client.is_closed(),
1791            "Client with message handler should attempt reconnection"
1792        );
1793
1794        client.disconnect().await;
1795        server.abort();
1796    }
1797
1798    #[rstest]
1799    #[tokio::test]
1800    async fn test_handler_mode_reconnect_with_new_connection() {
1801        // Test that handler mode successfully reconnects and messages continue flowing
1802        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1803        let port = listener.local_addr().unwrap().port();
1804
1805        let server = task::spawn(async move {
1806            // First connection - accept and immediately close
1807            if let Ok((stream, _)) = listener.accept().await
1808                && let Ok(ws) = accept_async(stream).await
1809            {
1810                drop(ws);
1811            }
1812
1813            // Small delay to let client detect disconnection
1814            sleep(Duration::from_millis(100)).await;
1815
1816            // Second connection - accept, send a message, then keep alive
1817            if let Ok((stream, _)) = listener.accept().await
1818                && let Ok(mut ws) = accept_async(stream).await
1819            {
1820                use futures_util::SinkExt;
1821                let _ = ws
1822                    .send(Message::Text("reconnected".to_string().into()))
1823                    .await;
1824                sleep(Duration::from_secs(1)).await;
1825            }
1826        });
1827
1828        let (handler, mut rx) = channel_message_handler();
1829
1830        let config = WebSocketConfig {
1831            url: format!("ws://127.0.0.1:{port}"),
1832            headers: vec![],
1833            heartbeat: None,
1834            heartbeat_msg: None,
1835            reconnect_timeout_ms: Some(2_000),
1836            reconnect_delay_initial_ms: Some(50),
1837            reconnect_delay_max_ms: Some(200),
1838            reconnect_backoff_factor: Some(1.5),
1839            reconnect_jitter_ms: Some(10),
1840            reconnect_max_attempts: None,
1841        };
1842
1843        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1844            .await
1845            .unwrap();
1846
1847        // Wait for reconnection to happen and message to arrive
1848        let result = tokio::time::timeout(Duration::from_secs(5), async {
1849            loop {
1850                if let Ok(msg) = rx.try_recv()
1851                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1852                {
1853                    return true;
1854                }
1855                tokio::time::sleep(Duration::from_millis(10)).await;
1856            }
1857        })
1858        .await;
1859
1860        assert!(
1861            result.is_ok(),
1862            "Should receive message after reconnection within timeout"
1863        );
1864
1865        client.disconnect().await;
1866        server.abort();
1867    }
1868
1869    #[rstest]
1870    #[tokio::test]
1871    async fn test_stream_mode_no_auto_reconnect() {
1872        // Test that stream mode does not automatically reconnect when connection is lost
1873        // The caller owns the reader and is responsible for detecting disconnection
1874        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1875        let port = listener.local_addr().unwrap().port();
1876
1877        let server = task::spawn(async move {
1878            // Accept connection and send one message, then close
1879            if let Ok((stream, _)) = listener.accept().await
1880                && let Ok(mut ws) = accept_async(stream).await
1881            {
1882                use futures_util::SinkExt;
1883                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1884                sleep(Duration::from_millis(50)).await;
1885                // Connection closes when ws is dropped
1886            }
1887        });
1888
1889        let config = WebSocketConfig {
1890            url: format!("ws://127.0.0.1:{port}"),
1891            headers: vec![],
1892            heartbeat: None,
1893            heartbeat_msg: None,
1894            reconnect_timeout_ms: Some(1_000),
1895            reconnect_delay_initial_ms: Some(50),
1896            reconnect_delay_max_ms: Some(100),
1897            reconnect_backoff_factor: Some(1.0),
1898            reconnect_jitter_ms: Some(0),
1899            reconnect_max_attempts: None,
1900        };
1901
1902        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1903            .await
1904            .unwrap();
1905
1906        // Initially active
1907        assert!(client.is_active(), "Client should start as active");
1908
1909        // Read the hello message
1910        let msg = reader.next().await;
1911        assert!(
1912            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1913            "Should receive initial message"
1914        );
1915
1916        // Read until connection closes (reader will return None or error)
1917        while let Some(msg) = reader.next().await {
1918            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1919                break;
1920            }
1921        }
1922
1923        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1924        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1925        sleep(Duration::from_millis(200)).await;
1926
1927        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1928        // This is correct behavior - stream mode doesn't auto-detect disconnection
1929        assert!(
1930            client.is_active() || client.is_closed(),
1931            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1932        );
1933        assert!(
1934            !client.is_reconnecting(),
1935            "Stream mode client should never attempt reconnection"
1936        );
1937
1938        client.disconnect().await;
1939        server.abort();
1940    }
1941
1942    #[rstest]
1943    #[tokio::test]
1944    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1945        // Test that send operations respect the configured reconnect_timeout.
1946        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1947        use nautilus_common::testing::wait_until_async;
1948
1949        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1950        let port = listener.local_addr().unwrap().port();
1951
1952        let server = task::spawn(async move {
1953            // Accept first connection and immediately close it
1954            if let Ok((stream, _)) = listener.accept().await
1955                && let Ok(ws) = accept_async(stream).await
1956            {
1957                drop(ws);
1958            }
1959            // Don't accept second connection - client will be stuck in RECONNECT
1960            sleep(Duration::from_secs(60)).await;
1961        });
1962
1963        let (handler, _rx) = channel_message_handler();
1964
1965        // Configure with SHORT 2s reconnect timeout
1966        let config = WebSocketConfig {
1967            url: format!("ws://127.0.0.1:{port}"),
1968            headers: vec![],
1969            heartbeat: None,
1970            heartbeat_msg: None,
1971            reconnect_timeout_ms: Some(2_000), // 2s timeout
1972            reconnect_delay_initial_ms: Some(50),
1973            reconnect_delay_max_ms: Some(100),
1974            reconnect_backoff_factor: Some(1.0),
1975            reconnect_jitter_ms: Some(0),
1976            reconnect_max_attempts: None,
1977        };
1978
1979        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1980            .await
1981            .unwrap();
1982
1983        // Wait for client to enter RECONNECT state
1984        wait_until_async(
1985            || async { client.is_reconnecting() },
1986            Duration::from_secs(3),
1987        )
1988        .await;
1989
1990        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
1991        let start = std::time::Instant::now();
1992        let send_result = client.send_text("test".to_string(), None).await;
1993        let elapsed = start.elapsed();
1994
1995        assert!(
1996            send_result.is_err(),
1997            "Send should fail when client stuck in RECONNECT"
1998        );
1999        assert!(
2000            matches!(send_result, Err(crate::error::SendError::Timeout)),
2001            "Send should return Timeout error, was: {send_result:?}"
2002        );
2003        // Verify timeout respects configured value (2s), but don't check upper bound
2004        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2005        assert!(
2006            elapsed >= Duration::from_millis(1800),
2007            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2008        );
2009
2010        client.disconnect().await;
2011        server.abort();
2012    }
2013
2014    #[rstest]
2015    #[tokio::test]
2016    async fn test_send_waits_during_reconnection() {
2017        // Test that send operations wait for reconnection to complete (up to timeout)
2018        use nautilus_common::testing::wait_until_async;
2019
2020        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2021        let port = listener.local_addr().unwrap().port();
2022
2023        let server = task::spawn(async move {
2024            // First connection - accept and immediately close
2025            if let Ok((stream, _)) = listener.accept().await
2026                && let Ok(ws) = accept_async(stream).await
2027            {
2028                drop(ws);
2029            }
2030
2031            // Wait a bit before accepting second connection
2032            sleep(Duration::from_millis(500)).await;
2033
2034            // Second connection - accept and keep alive
2035            if let Ok((stream, _)) = listener.accept().await
2036                && let Ok(mut ws) = accept_async(stream).await
2037            {
2038                // Echo messages
2039                while let Some(Ok(msg)) = ws.next().await {
2040                    if ws.send(msg).await.is_err() {
2041                        break;
2042                    }
2043                }
2044            }
2045        });
2046
2047        let (handler, _rx) = channel_message_handler();
2048
2049        let config = WebSocketConfig {
2050            url: format!("ws://127.0.0.1:{port}"),
2051            headers: vec![],
2052            heartbeat: None,
2053            heartbeat_msg: None,
2054            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2055            reconnect_delay_initial_ms: Some(100),
2056            reconnect_delay_max_ms: Some(200),
2057            reconnect_backoff_factor: Some(1.0),
2058            reconnect_jitter_ms: Some(0),
2059            reconnect_max_attempts: None,
2060        };
2061
2062        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2063            .await
2064            .unwrap();
2065
2066        // Wait for reconnection to trigger
2067        wait_until_async(
2068            || async { client.is_reconnecting() },
2069            Duration::from_secs(2),
2070        )
2071        .await;
2072
2073        // Try to send while reconnecting - should wait and succeed after reconnect
2074        let send_result = tokio::time::timeout(
2075            Duration::from_secs(3),
2076            client.send_text("test_message".to_string(), None),
2077        )
2078        .await;
2079
2080        assert!(
2081            send_result.is_ok() && send_result.unwrap().is_ok(),
2082            "Send should succeed after waiting for reconnection"
2083        );
2084
2085        client.disconnect().await;
2086        server.abort();
2087    }
2088
2089    #[rstest]
2090    #[tokio::test]
2091    async fn test_rate_limiter_before_active_wait() {
2092        // Test that rate limiting happens BEFORE active state check.
2093        // This prevents race conditions where connection state changes during rate limit wait.
2094        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2095        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2096        use std::{num::NonZeroU32, sync::Arc};
2097
2098        use nautilus_common::testing::wait_until_async;
2099
2100        use crate::ratelimiter::quota::Quota;
2101
2102        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2103        let port = listener.local_addr().unwrap().port();
2104
2105        let server = task::spawn(async move {
2106            // First connection - accept and close after receiving one message
2107            if let Ok((stream, _)) = listener.accept().await
2108                && let Ok(mut ws) = accept_async(stream).await
2109            {
2110                // Receive first message then close
2111                if let Some(Ok(_)) = ws.next().await {
2112                    drop(ws);
2113                }
2114            }
2115
2116            // Wait before accepting reconnection
2117            sleep(Duration::from_millis(500)).await;
2118
2119            // Second connection - accept and keep alive
2120            if let Ok((stream, _)) = listener.accept().await
2121                && let Ok(mut ws) = accept_async(stream).await
2122            {
2123                while let Some(Ok(msg)) = ws.next().await {
2124                    if ws.send(msg).await.is_err() {
2125                        break;
2126                    }
2127                }
2128            }
2129        });
2130
2131        let (handler, _rx) = channel_message_handler();
2132
2133        let config = WebSocketConfig {
2134            url: format!("ws://127.0.0.1:{port}"),
2135            headers: vec![],
2136            heartbeat: None,
2137            heartbeat_msg: None,
2138            reconnect_timeout_ms: Some(5_000),
2139            reconnect_delay_initial_ms: Some(50),
2140            reconnect_delay_max_ms: Some(100),
2141            reconnect_backoff_factor: Some(1.0),
2142            reconnect_jitter_ms: Some(0),
2143            reconnect_max_attempts: None,
2144        };
2145
2146        // Very restrictive rate limit: 1 request per second, burst of 1
2147        let quota =
2148            Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2149
2150        let client = Arc::new(
2151            WebSocketClient::connect(
2152                config,
2153                Some(handler),
2154                None,
2155                None,
2156                vec![("test_key".to_string(), quota)],
2157                None,
2158            )
2159            .await
2160            .unwrap(),
2161        );
2162
2163        // First send exhausts burst capacity and triggers connection close
2164        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2165        client
2166            .send_text("msg1".to_string(), Some(test_key.as_slice()))
2167            .await
2168            .unwrap();
2169
2170        // Wait for client to enter RECONNECT state
2171        wait_until_async(
2172            || async { client.is_reconnecting() },
2173            Duration::from_secs(2),
2174        )
2175        .await;
2176
2177        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2178        let start = std::time::Instant::now();
2179        let send_result = client
2180            .send_text("msg2".to_string(), Some(test_key.as_slice()))
2181            .await;
2182        let elapsed = start.elapsed();
2183
2184        // Should succeed after both rate limit AND reconnection
2185        assert!(
2186            send_result.is_ok(),
2187            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2188        );
2189        // Total wait should be at least rate limit time (~1s)
2190        // The reconnection completes while rate limiting or after
2191        // Use 850ms threshold to account for timing jitter in CI
2192        assert!(
2193            elapsed >= Duration::from_millis(850),
2194            "Should wait for rate limit (~1s), waited {elapsed:?}"
2195        );
2196
2197        client.disconnect().await;
2198        server.abort();
2199    }
2200
2201    #[rstest]
2202    #[tokio::test]
2203    async fn test_disconnect_during_reconnect_exits_cleanly() {
2204        // Test CAS race condition: disconnect called during reconnection
2205        // Should exit cleanly without spawning new tasks
2206        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2207        let port = listener.local_addr().unwrap().port();
2208
2209        let server = task::spawn(async move {
2210            // Accept first connection and immediately close
2211            if let Ok((stream, _)) = listener.accept().await
2212                && let Ok(ws) = accept_async(stream).await
2213            {
2214                drop(ws);
2215            }
2216            // Don't accept second connection - let reconnect hang
2217            sleep(Duration::from_secs(60)).await;
2218        });
2219
2220        let (handler, _rx) = channel_message_handler();
2221
2222        let config = WebSocketConfig {
2223            url: format!("ws://127.0.0.1:{port}"),
2224            headers: vec![],
2225            heartbeat: None,
2226            heartbeat_msg: None,
2227            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2228            reconnect_delay_initial_ms: Some(100),
2229            reconnect_delay_max_ms: Some(200),
2230            reconnect_backoff_factor: Some(1.0),
2231            reconnect_jitter_ms: Some(0),
2232            reconnect_max_attempts: None,
2233        };
2234
2235        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2236            .await
2237            .unwrap();
2238
2239        // Wait for reconnection to start
2240        tokio::time::timeout(Duration::from_secs(2), async {
2241            while !client.is_reconnecting() {
2242                sleep(Duration::from_millis(10)).await;
2243            }
2244        })
2245        .await
2246        .expect("Client should enter RECONNECT state");
2247
2248        // Disconnect while reconnecting
2249        client.disconnect().await;
2250
2251        // Should be cleanly closed
2252        assert!(
2253            client.is_disconnected(),
2254            "Client should be cleanly disconnected"
2255        );
2256
2257        server.abort();
2258    }
2259
2260    #[rstest]
2261    #[tokio::test]
2262    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2263        // Test that send operations check connection state BEFORE rate limiting,
2264        // preventing unnecessary delays when the connection is already closed.
2265        use std::{num::NonZeroU32, sync::Arc};
2266
2267        use nautilus_common::testing::wait_until_async;
2268
2269        use crate::ratelimiter::quota::Quota;
2270
2271        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2272        let port = listener.local_addr().unwrap().port();
2273
2274        let server = task::spawn(async move {
2275            // Accept connection and immediately close
2276            if let Ok((stream, _)) = listener.accept().await
2277                && let Ok(ws) = accept_async(stream).await
2278            {
2279                drop(ws);
2280            }
2281            sleep(Duration::from_secs(60)).await;
2282        });
2283
2284        let (handler, _rx) = channel_message_handler();
2285
2286        let config = WebSocketConfig {
2287            url: format!("ws://127.0.0.1:{port}"),
2288            headers: vec![],
2289            heartbeat: None,
2290            heartbeat_msg: None,
2291            reconnect_timeout_ms: Some(5_000),
2292            reconnect_delay_initial_ms: Some(50),
2293            reconnect_delay_max_ms: Some(100),
2294            reconnect_backoff_factor: Some(1.0),
2295            reconnect_jitter_ms: Some(0),
2296            reconnect_max_attempts: None,
2297        };
2298
2299        // Very restrictive rate limit: 1 request per 10 seconds
2300        // This ensures that if we wait for rate limit, the test will timeout
2301        let quota = Quota::with_period(Duration::from_secs(10))
2302            .unwrap()
2303            .allow_burst(NonZeroU32::new(1).unwrap());
2304
2305        let client = Arc::new(
2306            WebSocketClient::connect(
2307                config,
2308                Some(handler),
2309                None,
2310                None,
2311                vec![("test_key".to_string(), quota)],
2312                None,
2313            )
2314            .await
2315            .unwrap(),
2316        );
2317
2318        // Wait for disconnection
2319        wait_until_async(
2320            || async { client.is_reconnecting() || client.is_closed() },
2321            Duration::from_secs(2),
2322        )
2323        .await;
2324
2325        // Explicitly disconnect to move away from ACTIVE state
2326        client.disconnect().await;
2327        assert!(
2328            !client.is_active(),
2329            "Client should not be active after disconnect"
2330        );
2331
2332        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2333        let start = std::time::Instant::now();
2334        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2335        let result = client
2336            .send_text("test".to_string(), Some(test_key.as_slice()))
2337            .await;
2338        let elapsed = start.elapsed();
2339
2340        // Should fail with Closed error
2341        assert!(result.is_err(), "Send should fail when client is closed");
2342        assert!(
2343            matches!(result, Err(crate::error::SendError::Closed)),
2344            "Send should return Closed error, was: {result:?}"
2345        );
2346
2347        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2348        assert!(
2349            elapsed < Duration::from_millis(100),
2350            "Send should fail fast without rate limiting, took {elapsed:?}"
2351        );
2352
2353        server.abort();
2354    }
2355
2356    #[rstest]
2357    #[tokio::test]
2358    async fn test_connect_rejects_none_message_handler() {
2359        // Test that connect() properly rejects None message_handler
2360        // to prevent zombie connections that appear alive but never detect disconnections
2361
2362        let config = WebSocketConfig {
2363            url: "ws://127.0.0.1:9999".to_string(),
2364            headers: vec![],
2365            heartbeat: None,
2366            heartbeat_msg: None,
2367            reconnect_timeout_ms: Some(1_000),
2368            reconnect_delay_initial_ms: Some(100),
2369            reconnect_delay_max_ms: Some(500),
2370            reconnect_backoff_factor: Some(1.5),
2371            reconnect_jitter_ms: Some(0),
2372            reconnect_max_attempts: None,
2373        };
2374
2375        // Pass None for message_handler - should be rejected
2376        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2377
2378        assert!(
2379            result.is_err(),
2380            "connect() should reject None message_handler"
2381        );
2382
2383        let err = result.unwrap_err();
2384        let err_msg = err.to_string();
2385        assert!(
2386            err_msg.contains("Handler mode requires message_handler"),
2387            "Error should mention missing message_handler, was: {err_msg}"
2388        );
2389    }
2390
2391    #[rstest]
2392    #[tokio::test]
2393    async fn test_client_without_handler_sets_stream_mode() {
2394        // Test that if a client is created without a handler via connect_url,
2395        // it properly sets is_stream_mode=true to prevent zombie connections
2396
2397        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2398        let port = listener.local_addr().unwrap().port();
2399
2400        let server = task::spawn(async move {
2401            // Accept and immediately close to simulate server disconnect
2402            if let Ok((stream, _)) = listener.accept().await
2403                && let Ok(ws) = accept_async(stream).await
2404            {
2405                drop(ws); // Drop connection immediately
2406            }
2407        });
2408
2409        let config = WebSocketConfig {
2410            url: format!("ws://127.0.0.1:{port}"),
2411            headers: vec![],
2412            heartbeat: None,
2413            heartbeat_msg: None,
2414            reconnect_timeout_ms: Some(1_000),
2415            reconnect_delay_initial_ms: Some(100),
2416            reconnect_delay_max_ms: Some(500),
2417            reconnect_backoff_factor: Some(1.5),
2418            reconnect_jitter_ms: Some(0),
2419            reconnect_max_attempts: None,
2420        };
2421
2422        // Create client directly via connect_url with no handler (stream mode)
2423        let inner = WebSocketClientInner::connect_url(config, None, None)
2424            .await
2425            .unwrap();
2426
2427        // Verify is_stream_mode is true when no handler
2428        assert!(
2429            inner.is_stream_mode,
2430            "Client without handler should have is_stream_mode=true"
2431        );
2432
2433        // Verify that when stream mode is enabled, reconnection is disabled
2434        // (documented behavior - stream mode clients close instead of reconnecting)
2435
2436        server.abort();
2437    }
2438}