nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! High-performance WebSocket client implementation with automatic reconnection,
//! exponential backoff, and connection state management.
18
19//! **Key features**:
//! - Connection state tracking (`Active`/`Reconnect`/`Disconnect`/`Closed`)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_core::CleanDrop;
46use nautilus_cryptography::providers::install_cryptographic_provider;
47use tokio::net::TcpStream;
48use tokio_tungstenite::{
49    MaybeTlsStream, WebSocketStream, connect_async,
50    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
51};
52
53use crate::{
54    RECONNECTED,
55    backoff::ExponentialBackoff,
56    error::SendError,
57    logging::{log_task_aborted, log_task_started, log_task_stopped},
58    mode::ConnectionMode,
59    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
60};
61
// Connection timing constants

/// How often (milliseconds) polling loops wake up to re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) after handing a new writer over during a reconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on graceful shutdown steps such as closing a writer.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
66
67type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
68pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
69
70/// Function type for handling WebSocket messages.
71pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;
72
73/// Function type for handling WebSocket ping messages.
74pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
75
76/// Creates a channel-based message handler.
77///
78/// Returns a tuple containing the message handler and a receiver for messages.
79#[must_use]
80pub fn channel_message_handler() -> (
81    MessageHandler,
82    tokio::sync::mpsc::UnboundedReceiver<Message>,
83) {
84    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
85    let handler = Arc::new(move |msg: Message| {
86        if let Err(e) = tx.send(msg) {
87            tracing::error!("Failed to send message to channel: {e}");
88        }
89    });
90    (handler, rx)
91}
92
93#[cfg_attr(
94    feature = "python",
95    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
96)]
97pub struct WebSocketConfig {
98    /// The URL to connect to.
99    pub url: String,
100    /// The default headers.
101    pub headers: Vec<(String, String)>,
102    /// The function to handle incoming messages.
103    pub message_handler: Option<MessageHandler>,
104    /// The optional heartbeat interval (seconds).
105    pub heartbeat: Option<u64>,
106    /// The optional heartbeat message.
107    pub heartbeat_msg: Option<String>,
108    /// The handler for incoming pings.
109    pub ping_handler: Option<PingHandler>,
110    /// The timeout (milliseconds) for reconnection attempts.
111    pub reconnect_timeout_ms: Option<u64>,
112    /// The initial reconnection delay (milliseconds) for reconnects.
113    pub reconnect_delay_initial_ms: Option<u64>,
114    /// The maximum reconnect delay (milliseconds) for exponential backoff.
115    pub reconnect_delay_max_ms: Option<u64>,
116    /// The exponential backoff factor for reconnection delays.
117    pub reconnect_backoff_factor: Option<f64>,
118    /// The maximum jitter (milliseconds) added to reconnection delays.
119    pub reconnect_jitter_ms: Option<u64>,
120}
121
122impl Debug for WebSocketConfig {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("WebSocketConfig")
125            .field("url", &self.url)
126            .field("headers", &self.headers)
127            .field(
128                "message_handler",
129                &self.message_handler.as_ref().map(|_| "<function>"),
130            )
131            .field("heartbeat", &self.heartbeat)
132            .field("heartbeat_msg", &self.heartbeat_msg)
133            .field(
134                "ping_handler",
135                &self.ping_handler.as_ref().map(|_| "<function>"),
136            )
137            .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
138            .field(
139                "reconnect_delay_initial_ms",
140                &self.reconnect_delay_initial_ms,
141            )
142            .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
143            .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
144            .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
145            .finish()
146    }
147}
148
149impl Clone for WebSocketConfig {
150    fn clone(&self) -> Self {
151        Self {
152            url: self.url.clone(),
153            headers: self.headers.clone(),
154            message_handler: self.message_handler.clone(),
155            heartbeat: self.heartbeat,
156            heartbeat_msg: self.heartbeat_msg.clone(),
157            ping_handler: self.ping_handler.clone(),
158            reconnect_timeout_ms: self.reconnect_timeout_ms,
159            reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
160            reconnect_delay_max_ms: self.reconnect_delay_max_ms,
161            reconnect_backoff_factor: self.reconnect_backoff_factor,
162            reconnect_jitter_ms: self.reconnect_jitter_ms,
163        }
164    }
165}
166
167/// Represents a command for the writer task.
168#[derive(Debug)]
169pub(crate) enum WriterCommand {
170    /// Update the writer reference with a new one after reconnection.
171    Update(MessageWriter),
172    /// Send message to the server.
173    Send(Message),
174}
175
176/// `WebSocketClient` connects to a websocket server to read and send messages.
177///
178/// The client is opinionated about how messages are read and written. It
179/// assumes that data can only have one reader but multiple writers.
180///
181/// The client splits the connection into read and write halves. It moves
182/// the read half into a tokio task which keeps receiving messages from the
183/// server and calls a handler - a Python function that takes the data
184/// as its parameter. It stores the write half in the struct wrapped
185/// with an Arc Mutex. This way the client struct can be used to write
186/// data to the server from multiple scopes/tasks.
187///
188/// The client also maintains a heartbeat if given a duration in seconds.
189/// It's preferable to set the duration slightly lower - heartbeat more
190/// frequently - than the required amount.
191struct WebSocketClientInner {
192    config: WebSocketConfig,
193    read_task: Option<tokio::task::JoinHandle<()>>,
194    write_task: tokio::task::JoinHandle<()>,
195    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
196    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
197    connection_mode: Arc<AtomicU8>,
198    reconnect_timeout: Duration,
199    backoff: ExponentialBackoff,
200}
201
202impl WebSocketClientInner {
203    /// Create an inner websocket client with an existing writer.
204    pub async fn new_with_writer(
205        config: WebSocketConfig,
206        writer: MessageWriter,
207    ) -> Result<Self, Error> {
208        install_cryptographic_provider();
209
210        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
211
212        // Note: We don't spawn a read task here since the reader is handled externally
213        let read_task = None;
214
215        let backoff = ExponentialBackoff::new(
216            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
217            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
218            config.reconnect_backoff_factor.unwrap_or(1.5),
219            config.reconnect_jitter_ms.unwrap_or(100),
220            true, // immediate-first
221        )
222        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
225        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
226
227        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
228            Some(Self::spawn_heartbeat_task(
229                connection_mode.clone(),
230                heartbeat_interval,
231                config.heartbeat_msg.clone(),
232                writer_tx.clone(),
233            ))
234        } else {
235            None
236        };
237
238        Ok(Self {
239            config: config.clone(),
240            writer_tx,
241            connection_mode,
242            reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
243            heartbeat_task,
244            read_task,
245            write_task,
246            backoff,
247        })
248    }
249
250    /// Create an inner websocket client.
251    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
252        install_cryptographic_provider();
253
254        let WebSocketConfig {
255            url,
256            message_handler,
257            heartbeat,
258            headers,
259            heartbeat_msg,
260            ping_handler,
261            reconnect_timeout_ms,
262            reconnect_delay_initial_ms,
263            reconnect_delay_max_ms,
264            reconnect_backoff_factor,
265            reconnect_jitter_ms,
266        } = &config;
267        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
268
269        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
270
271        let read_task = if message_handler.is_some() {
272            Some(Self::spawn_message_handler_task(
273                connection_mode.clone(),
274                reader,
275                message_handler.as_ref(),
276                ping_handler.as_ref(),
277            ))
278        } else {
279            None
280        };
281
282        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
283        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
284
285        // Optionally spawn a heartbeat task to periodically ping server
286        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
287            Self::spawn_heartbeat_task(
288                connection_mode.clone(),
289                *heartbeat_secs,
290                heartbeat_msg.clone(),
291                writer_tx.clone(),
292            )
293        });
294
295        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
296        let backoff = ExponentialBackoff::new(
297            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
298            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
299            reconnect_backoff_factor.unwrap_or(1.5),
300            reconnect_jitter_ms.unwrap_or(100),
301            true, // immediate-first
302        )
303        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
304
305        Ok(Self {
306            config,
307            read_task,
308            write_task,
309            writer_tx,
310            heartbeat_task,
311            connection_mode,
312            reconnect_timeout,
313            backoff,
314        })
315    }
316
317    /// Connects with the server creating a tokio-tungstenite websocket stream.
318    #[inline]
319    pub async fn connect_with_server(
320        url: &str,
321        headers: Vec<(String, String)>,
322    ) -> Result<(MessageWriter, MessageReader), Error> {
323        let mut request = url.into_client_request()?;
324        let req_headers = request.headers_mut();
325
326        let mut header_names: Vec<HeaderName> = Vec::new();
327        for (key, val) in headers {
328            let header_value = HeaderValue::from_str(&val)?;
329            let header_name: HeaderName = key.parse()?;
330            header_names.push(header_name.clone());
331            req_headers.insert(header_name, header_value);
332        }
333
334        connect_async(request).await.map(|resp| resp.0.split())
335    }
336
337    /// Reconnect with server.
338    ///
339    /// Make a new connection with server. Use the new read and write halves
340    /// to update self writer and read and heartbeat tasks.
341    pub async fn reconnect(&mut self) -> Result<(), Error> {
342        tracing::debug!("Reconnecting");
343
344        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
345            tracing::debug!("Reconnect aborted due to disconnect state");
346            return Ok(());
347        }
348
349        tokio::time::timeout(self.reconnect_timeout, async {
350            // Attempt to connect; abort early if a disconnect was requested
351            let (new_writer, reader) =
352                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
353
354            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
355                tracing::debug!("Reconnect aborted mid-flight (after connect)");
356                return Ok(());
357            }
358
359            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
360                tracing::error!("{e}");
361            }
362
363            // Delay before closing connection
364            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
365
366            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
367                tracing::debug!("Reconnect aborted mid-flight (after delay)");
368                return Ok(());
369            }
370
371            if let Some(ref read_task) = self.read_task.take()
372                && !read_task.is_finished()
373            {
374                read_task.abort();
375                log_task_aborted("read");
376            }
377
378            // If a disconnect was requested during reconnect, do not proceed to reactivate
379            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
380                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
381                return Ok(());
382            }
383
384            // Mark as active only if not disconnecting
385            self.connection_mode
386                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
387
388            self.read_task = if self.config.message_handler.is_some() {
389                Some(Self::spawn_message_handler_task(
390                    self.connection_mode.clone(),
391                    reader,
392                    self.config.message_handler.as_ref(),
393                    self.config.ping_handler.as_ref(),
394                ))
395            } else {
396                None
397            };
398
399            tracing::debug!("Reconnect succeeded");
400            Ok(())
401        })
402        .await
403        .map_err(|_| {
404            Error::Io(std::io::Error::new(
405                std::io::ErrorKind::TimedOut,
406                format!(
407                    "reconnection timed out after {}s",
408                    self.reconnect_timeout.as_secs_f64()
409                ),
410            ))
411        })?
412    }
413
414    /// Check if the client is still connected.
415    ///
416    /// The client is connected if the read task has not finished. It is expected
417    /// that in case of any failure client or server side. The read task will be
418    /// shutdown or will receive a `Close` frame which will finish it. There
419    /// might be some delay between the connection being closed and the client
420    /// detecting.
421    #[inline]
422    #[must_use]
423    pub fn is_alive(&self) -> bool {
424        match &self.read_task {
425            Some(read_task) => !read_task.is_finished(),
426            None => true, // Stream is being used directly
427        }
428    }
429
430    fn spawn_message_handler_task(
431        connection_state: Arc<AtomicU8>,
432        mut reader: MessageReader,
433        message_handler: Option<&MessageHandler>,
434        ping_handler: Option<&PingHandler>,
435    ) -> tokio::task::JoinHandle<()> {
436        tracing::debug!("Started message handler task 'read'");
437
438        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
439
440        // Clone Arc handlers for the async task
441        let message_handler = message_handler.cloned();
442        let ping_handler = ping_handler.cloned();
443
444        tokio::task::spawn(async move {
445            loop {
446                if !ConnectionMode::from_atomic(&connection_state).is_active() {
447                    break;
448                }
449
450                match tokio::time::timeout(check_interval, reader.next()).await {
451                    Ok(Some(Ok(Message::Binary(data)))) => {
452                        tracing::trace!("Received message <binary> {} bytes", data.len());
453                        if let Some(ref handler) = message_handler {
454                            handler(Message::Binary(data));
455                        }
456                    }
457                    Ok(Some(Ok(Message::Text(data)))) => {
458                        tracing::trace!("Received message: {data}");
459                        if let Some(ref handler) = message_handler {
460                            handler(Message::Text(data));
461                        }
462                    }
463                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
464                        tracing::trace!("Received ping: {ping_data:?}");
465                        if let Some(ref handler) = ping_handler {
466                            handler(ping_data.to_vec());
467                        }
468                    }
469                    Ok(Some(Ok(Message::Pong(_)))) => {
470                        tracing::trace!("Received pong");
471                    }
472                    Ok(Some(Ok(Message::Close(_)))) => {
473                        tracing::debug!("Received close message - terminating");
474                        break;
475                    }
476                    Ok(Some(Ok(_))) => (),
477                    Ok(Some(Err(e))) => {
478                        tracing::error!("Received error message - terminating: {e}");
479                        break;
480                    }
481                    Ok(None) => {
482                        tracing::debug!("No message received - terminating");
483                        break;
484                    }
485                    Err(_) => {
486                        // Timeout - continue loop and check connection mode
487                        continue;
488                    }
489                }
490            }
491        })
492    }
493
494    fn spawn_write_task(
495        connection_state: Arc<AtomicU8>,
496        writer: MessageWriter,
497        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
498    ) -> tokio::task::JoinHandle<()> {
499        log_task_started("write");
500
501        // Interval between checking the connection mode
502        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
503
504        tokio::task::spawn(async move {
505            let mut active_writer = writer;
506
507            loop {
508                match ConnectionMode::from_atomic(&connection_state) {
509                    ConnectionMode::Disconnect => {
510                        // Attempt to close the writer gracefully before exiting,
511                        // we ignore any error as the writer may already be closed.
512                        _ = tokio::time::timeout(
513                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
514                            active_writer.close(),
515                        )
516                        .await;
517                        break;
518                    }
519                    ConnectionMode::Closed => break,
520                    _ => {}
521                }
522
523                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
524                    Ok(Some(msg)) => {
525                        // Re-check connection mode after receiving a message
526                        let mode = ConnectionMode::from_atomic(&connection_state);
527                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
528                            break;
529                        }
530
531                        match msg {
532                            WriterCommand::Update(new_writer) => {
533                                tracing::debug!("Received new writer");
534
535                                // Delay before closing connection
536                                tokio::time::sleep(Duration::from_millis(100)).await;
537
538                                // Attempt to close the writer gracefully on update,
539                                // we ignore any error as the writer may already be closed.
540                                _ = tokio::time::timeout(
541                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
542                                    active_writer.close(),
543                                )
544                                .await;
545
546                                active_writer = new_writer;
547                                tracing::debug!("Updated writer");
548                            }
549                            _ if mode.is_reconnect() => {
550                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
551                                continue;
552                            }
553                            WriterCommand::Send(msg) => {
554                                if let Err(e) = active_writer.send(msg).await {
555                                    tracing::error!("Failed to send message: {e}");
556                                    // Mode is active so trigger reconnection
557                                    tracing::warn!("Writer triggering reconnect");
558                                    connection_state
559                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
560                                }
561                            }
562                        }
563                    }
564                    Ok(None) => {
565                        // Channel closed - writer task should terminate
566                        tracing::debug!("Writer channel closed, terminating writer task");
567                        break;
568                    }
569                    Err(_) => {
570                        // Timeout - just continue the loop
571                        continue;
572                    }
573                }
574            }
575
576            // Attempt to close the writer gracefully before exiting,
577            // we ignore any error as the writer may already be closed.
578            _ = tokio::time::timeout(
579                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
580                active_writer.close(),
581            )
582            .await;
583
584            log_task_stopped("write");
585        })
586    }
587
588    fn spawn_heartbeat_task(
589        connection_state: Arc<AtomicU8>,
590        heartbeat_secs: u64,
591        message: Option<String>,
592        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
593    ) -> tokio::task::JoinHandle<()> {
594        log_task_started("heartbeat");
595
596        tokio::task::spawn(async move {
597            let interval = Duration::from_secs(heartbeat_secs);
598
599            loop {
600                tokio::time::sleep(interval).await;
601
602                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
603                    ConnectionMode::Active => {
604                        let msg = match &message {
605                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
606                            None => WriterCommand::Send(Message::Ping(vec![].into())),
607                        };
608
609                        match writer_tx.send(msg) {
610                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
611                            Err(e) => {
612                                tracing::error!("Failed to send heartbeat to writer task: {e}");
613                            }
614                        }
615                    }
616                    ConnectionMode::Reconnect => continue,
617                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
618                }
619            }
620
621            log_task_stopped("heartbeat");
622        })
623    }
624}
625
626impl Drop for WebSocketClientInner {
627    fn drop(&mut self) {
628        // Delegate to explicit cleanup handler
629        self.clean_drop();
630    }
631}
632
633impl CleanDrop for WebSocketClientInner {
634    fn clean_drop(&mut self) {
635        if let Some(ref read_task) = self.read_task.take()
636            && !read_task.is_finished()
637        {
638            read_task.abort();
639            log_task_aborted("read");
640        }
641
642        if !self.write_task.is_finished() {
643            self.write_task.abort();
644            log_task_aborted("write");
645        }
646
647        if let Some(ref handle) = self.heartbeat_task.take()
648            && !handle.is_finished()
649        {
650            handle.abort();
651            log_task_aborted("heartbeat");
652        }
653
654        // Clear handlers to break potential reference cycles
655        self.config.message_handler = None;
656        self.config.ping_handler = None;
657    }
658}
659
660/// WebSocket client with automatic reconnection.
661///
662/// Handles connection state, callbacks, and rate limiting.
663/// See module docs for architecture details.
664#[cfg_attr(
665    feature = "python",
666    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
667)]
668pub struct WebSocketClient {
669    pub(crate) controller_task: tokio::task::JoinHandle<()>,
670    pub(crate) connection_mode: Arc<AtomicU8>,
671    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
672    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
673}
674
675impl Debug for WebSocketClient {
676    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677        f.debug_struct(stringify!(WebSocketClient)).finish()
678    }
679}
680
681impl WebSocketClient {
682    /// Creates a websocket client that returns a stream for reading messages.
683    ///
684    /// # Errors
685    ///
686    /// Returns any error connecting to the server.
687    #[allow(clippy::too_many_arguments)]
688    pub async fn connect_stream(
689        config: WebSocketConfig,
690        keyed_quotas: Vec<(String, Quota)>,
691        default_quota: Option<Quota>,
692        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
693    ) -> Result<(MessageReader, Self), Error> {
694        install_cryptographic_provider();
695
696        // Create a single connection and split it
697        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
698        let (writer, reader) = ws_stream.split();
699
700        // Create inner without connecting (we'll provide the writer)
701        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
702
703        let connection_mode = inner.connection_mode.clone();
704        let writer_tx = inner.writer_tx.clone();
705
706        let controller_task =
707            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
708
709        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
710
711        Ok((
712            reader,
713            Self {
714                controller_task,
715                connection_mode,
716                writer_tx,
717                rate_limiter,
718            },
719        ))
720    }
721
722    /// Creates a websocket client.
723    ///
724    /// Creates an inner client and controller task to reconnect or disconnect
725    /// the client. Also assumes ownership of writer from inner client.
726    ///
727    /// # Errors
728    ///
729    /// Returns any websocket error.
730    pub async fn connect(
731        config: WebSocketConfig,
732        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
733        keyed_quotas: Vec<(String, Quota)>,
734        default_quota: Option<Quota>,
735    ) -> Result<Self, Error> {
736        tracing::debug!("Connecting");
737        let inner = WebSocketClientInner::connect_url(config).await?;
738        let connection_mode = inner.connection_mode.clone();
739        let writer_tx = inner.writer_tx.clone();
740
741        let controller_task =
742            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
743
744        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
745
746        Ok(Self {
747            controller_task,
748            connection_mode,
749            writer_tx,
750            rate_limiter,
751        })
752    }
753
754    /// Returns the current connection mode.
755    #[must_use]
756    pub fn connection_mode(&self) -> ConnectionMode {
757        ConnectionMode::from_atomic(&self.connection_mode)
758    }
759
760    /// Check if the client connection is active.
761    ///
762    /// Returns `true` if the client is connected and has not been signalled to disconnect.
763    /// The client will automatically retry connection based on its configuration.
764    #[inline]
765    #[must_use]
766    pub fn is_active(&self) -> bool {
767        self.connection_mode().is_active()
768    }
769
770    /// Check if the client is disconnected.
771    #[must_use]
772    pub fn is_disconnected(&self) -> bool {
773        self.controller_task.is_finished()
774    }
775
776    /// Check if the client is reconnecting.
777    ///
778    /// Returns `true` if the client lost connection and is attempting to reestablish it.
779    /// The client will automatically retry connection based on its configuration.
780    #[inline]
781    #[must_use]
782    pub fn is_reconnecting(&self) -> bool {
783        self.connection_mode().is_reconnect()
784    }
785
786    /// Check if the client is disconnecting.
787    ///
788    /// Returns `true` if the client is in disconnect mode.
789    #[inline]
790    #[must_use]
791    pub fn is_disconnecting(&self) -> bool {
792        self.connection_mode().is_disconnect()
793    }
794
795    /// Check if the client is closed.
796    ///
797    /// Returns `true` if the client has been explicitly disconnected or reached
798    /// maximum reconnection attempts. In this state, the client cannot be reused
799    /// and a new client must be created for further connections.
800    #[inline]
801    #[must_use]
802    pub fn is_closed(&self) -> bool {
803        self.connection_mode().is_closed()
804    }
805
806    /// Set disconnect mode to true.
807    ///
808    /// Controller task will periodically check the disconnect mode
809    /// and shutdown the client if it is alive
810    pub async fn disconnect(&self) {
811        tracing::debug!("Disconnecting");
812        self.connection_mode
813            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
814
815        match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
816            while !self.is_disconnected() {
817                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
818            }
819
820            if !self.controller_task.is_finished() {
821                self.controller_task.abort();
822                log_task_aborted("controller");
823            }
824        })
825        .await
826        {
827            Ok(()) => {
828                tracing::debug!("Controller task finished");
829            }
830            Err(_) => {
831                tracing::error!("Timeout waiting for controller task to finish");
832            }
833        }
834    }
835
836    /// Sends the given text `data` to the server.
837    ///
838    /// # Errors
839    ///
840    /// Returns a websocket error if unable to send.
841    #[allow(unused_variables)]
842    pub async fn send_text(
843        &self,
844        data: String,
845        keys: Option<Vec<String>>,
846    ) -> std::result::Result<(), SendError> {
847        self.rate_limiter.await_keys_ready(keys).await;
848
849        if !self.is_active() {
850            return Err(SendError::Closed);
851        }
852
853        tracing::trace!("Sending text: {data:?}");
854
855        let msg = Message::Text(data.into());
856        self.writer_tx
857            .send(WriterCommand::Send(msg))
858            .map_err(|e| SendError::BrokenPipe(e.to_string()))
859    }
860
861    /// Sends the given bytes `data` to the server.
862    ///
863    /// # Errors
864    ///
865    /// Returns a websocket error if unable to send.
866    #[allow(unused_variables)]
867    pub async fn send_bytes(
868        &self,
869        data: Vec<u8>,
870        keys: Option<Vec<String>>,
871    ) -> std::result::Result<(), SendError> {
872        self.rate_limiter.await_keys_ready(keys).await;
873
874        if !self.is_active() {
875            return Err(SendError::Closed);
876        }
877
878        tracing::trace!("Sending bytes: {data:?}");
879
880        let msg = Message::Binary(data.into());
881        self.writer_tx
882            .send(WriterCommand::Send(msg))
883            .map_err(|e| SendError::BrokenPipe(e.to_string()))
884    }
885
886    /// Sends a close message to the server.
887    ///
888    /// # Errors
889    ///
890    /// Returns a websocket error if unable to send.
891    pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
892        if !self.is_active() {
893            return Err(SendError::Closed);
894        }
895
896        let msg = Message::Close(None);
897        self.writer_tx
898            .send(WriterCommand::Send(msg))
899            .map_err(|e| SendError::BrokenPipe(e.to_string()))
900    }
901
    /// Spawns the controller task that supervises the connection lifecycle.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task reads `connection_mode` and acts on it:
    /// - On disconnect, it waits `GRACEFUL_SHUTDOWN_DELAY_MS` and then aborts the read and
    ///   heartbeat tasks if they are still running. It then leaves the loop.
    /// - On reconnect, or while active but `inner.is_alive()` is false, it calls
    ///   `inner.reconnect()`:
    ///   - On success it resets the backoff. If the mode is still active, it delivers a
    ///     `RECONNECTED` text message to the configured `message_handler` and invokes
    ///     `post_reconnection`.
    ///   - On failure it sleeps for the next backoff duration.
    ///
    /// After leaving the loop the task stores `ConnectionMode::Closed` in
    /// `inner.connection_mode`. When called from `connect`, this is the same `Arc` as
    /// `connection_mode`.
    ///
    /// NOTE(review): no branch leaves the loop when the mode is `Closed`. If another
    /// component stores `Closed`, this task would keep polling indefinitely — confirm
    /// whether that can happen.
    ///
    /// NOTE(review): the backoff sleep after a failed reconnect is not interrupted by a
    /// disconnect request, so `disconnect()` may wait up to the current backoff duration.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    // The only await inside this block is the grace-period sleep. The aborts
                    // are synchronous, so the timeout only bounds that delay.
                    if tokio::time::timeout(timeout, async {
                        // Grace period before forcibly stopping the child tasks
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");
                    break; // Controller finished
                }

                // Reconnect when explicitly requested, or when the mode still says active
                // but the underlying connection is no longer alive.
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // Re-read the mode: it may have changed (e.g. a disconnect request)
                            // while `reconnect()` was awaiting. Callbacks fire only if still active.
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.config.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    tracing::debug!("Sent reconnected message to handler");
                                }

                                // TODO: Retain this legacy callback for use from Python
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Advance the backoff and wait before the next state check/attempt
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Only reached via the disconnect branch's `break`
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
993}
994
995// Abort controller task on drop to clean up background tasks
996impl Drop for WebSocketClient {
997    fn drop(&mut self) {
998        if !self.controller_task.is_finished() {
999            self.controller_task.abort();
1000            log_task_aborted("controller");
1001        }
1002    }
1003}
1004
1005////////////////////////////////////////////////////////////////////////////////
1006// Tests
1007////////////////////////////////////////////////////////////////////////////////
1008
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server on an ephemeral port; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // Panics if the header is missing or has a different value
            assert_eq!(request.headers().get(&self.key), Some(&self.value));
            Ok(response)
        }
    }

    impl TestServer {
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            message_handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        WebSocketClient::connect(config, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // `setup_test_client` configures no heartbeat, so this only checks that an idle
        // connection keeps its controller running for ~3s and still disconnects cleanly
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port, then release it so nothing is listening there
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"), // <-- No server
            headers: vec![],
            message_handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        let res = WebSocketClient::connect(config, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            message_handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };

        let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
            .await
            .unwrap();

        // First 2 fit within the quota
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // The third may exceed the quota, but `send_text` waits on the rate limiter
        // instead of returning an error, so it must still succeed
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1266
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// After the server drops its side of the connection, `disconnect` must still
    /// bring the client down so that `is_disconnected` reports `true`.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // Listen on an OS-assigned port
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Complete a single WebSocket handshake, drop the stream immediately, then linger
        let server_task = task::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let ws_stream = accept_async(tcp_stream).await.unwrap();
            drop(ws_stream);
            sleep(Duration::from_secs(1)).await;
        });

        // The receiver is held but never read; the client only needs a handler to exist
        let (message_handler, _receiver) = channel_message_handler();

        // Short, jitter-free backoff so reconnect attempts happen quickly
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            message_handler: Some(message_handler),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_backoff_factor: Some(1.0),
            reconnect_delay_max_ms: Some(100),
            reconnect_jitter_ms: Some(0),
        };

        let client = WebSocketClient::connect(config, None, vec![], None)
            .await
            .unwrap();

        // Give the server time to drop the connection and the client time to notice
        sleep(Duration::from_millis(100)).await;

        client.disconnect().await;
        assert!(client.is_disconnected());

        server_task.abort();
    }
}
1323}