nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance WebSocket client implementation with automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49    net::TcpStream,
50    sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53    MaybeTlsStream, WebSocketStream, connect_async,
54    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58    backoff::ExponentialBackoff,
59    error::SendError,
60    logging::{log_task_aborted, log_task_started, log_task_stopped},
61    mode::ConnectionMode,
62    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
65type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
66pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
/// Defines how WebSocket messages are consumed.
///
/// The selected variant decides which kind of read task the client spawns
/// for the read half of the connection.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// Python-based message handler.
    ///
    /// When the inner value is `None` the client spawns no read task.
    #[cfg(feature = "python")]
    Python(Option<Arc<PyObject>>),
    /// Rust-based message handler using a channel sender.
    ///
    /// Every message read from the socket is forwarded over this sender.
    Rust(Sender<Message>),
}
77
78impl Consumer {
79    /// Creates a Rust-based consumer with a channel for receiving messages.
80    ///
81    /// Returns a tuple containing the consumer and a receiver for messages.
82    #[must_use]
83    pub fn rust_consumer() -> (Self, Receiver<Message>) {
84        let (tx, rx) = mpsc::channel(100);
85        (Self::Rust(tx), rx)
86    }
87}
88
/// Configuration for a WebSocket client connection.
///
/// The optional reconnection fields fall back to built-in defaults when `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// The default headers, as `(name, value)` pairs added to the connection request.
    pub headers: Vec<(String, String)>,
    /// The consumer (Python callback or Rust channel) that receives incoming messages.
    pub handler: Consumer,
    /// The optional heartbeat interval (seconds).
    ///
    /// When `None`, no heartbeat task is spawned.
    pub heartbeat: Option<u64>,
    /// The optional heartbeat message.
    ///
    /// When `None`, the heartbeat sends an empty ping frame instead of text.
    pub heartbeat_msg: Option<String>,
    /// The handler for incoming pings.
    #[cfg(feature = "python")]
    pub ping_handler: Option<Arc<PyObject>>,
    /// The timeout (milliseconds) for reconnection attempts (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
119
/// Represents a command for the writer task.
///
/// Commands are delivered to the writer task over an unbounded channel.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    ///
    /// The writer task closes its current writer before switching to the new one.
    Update(MessageWriter),
    /// Send message to the server.
    Send(Message),
}
128
/// `WebSocketClientInner` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. It moves
/// the read half into a tokio task which keeps receiving messages from the
/// server and hands them to the configured [`Consumer`] (a Python callback
/// or a Rust channel). The write half is moved into a dedicated write task
/// which receives [`WriterCommand`]s over an unbounded channel; cloning the
/// channel sender lets multiple scopes/tasks write to the server.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// The configuration the client was created with (reused when reconnecting).
    config: WebSocketConfig,
    /// Handle of the read task; `None` when the consumer spawns no read task.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender used to pass commands to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; `None` when no heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared with the spawned tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff settings built from the reconnection fields of the configuration.
    backoff: ExponentialBackoff,
}
154
155impl WebSocketClientInner {
156    /// Create an inner websocket client.
157    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
158        install_cryptographic_provider();
159
160        #[allow(unused_variables)]
161        let WebSocketConfig {
162            url,
163            handler,
164            heartbeat,
165            headers,
166            heartbeat_msg,
167            #[cfg(feature = "python")]
168            ping_handler,
169            reconnect_timeout_ms,
170            reconnect_delay_initial_ms,
171            reconnect_delay_max_ms,
172            reconnect_backoff_factor,
173            reconnect_jitter_ms,
174        } = &config;
175        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
176
177        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
178
179        let read_task = match &handler {
180            #[cfg(feature = "python")]
181            Consumer::Python(handler) => handler.as_ref().map(|handler| {
182                Self::spawn_python_callback_task(
183                    connection_mode.clone(),
184                    reader,
185                    handler.clone(),
186                    ping_handler.clone(),
187                )
188            }),
189            Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
190                connection_mode.clone(),
191                reader,
192                sender.clone(),
193            )),
194        };
195
196        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
197        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
198
199        // Optionally spawn a heartbeat task to periodically ping server
200        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
201            Self::spawn_heartbeat_task(
202                connection_mode.clone(),
203                *heartbeat_secs,
204                heartbeat_msg.clone(),
205                writer_tx.clone(),
206            )
207        });
208
209        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
210        let backoff = ExponentialBackoff::new(
211            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
212            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
213            reconnect_backoff_factor.unwrap_or(1.5),
214            reconnect_jitter_ms.unwrap_or(100),
215            true, // immediate-first
216        )
217        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
218
219        Ok(Self {
220            config,
221            read_task,
222            write_task,
223            writer_tx,
224            heartbeat_task,
225            connection_mode,
226            reconnect_timeout,
227            backoff,
228        })
229    }
230
231    /// Connects with the server creating a tokio-tungstenite websocket stream.
232    #[inline]
233    pub async fn connect_with_server(
234        url: &str,
235        headers: Vec<(String, String)>,
236    ) -> Result<(MessageWriter, MessageReader), Error> {
237        let mut request = url.into_client_request()?;
238        let req_headers = request.headers_mut();
239
240        let mut header_names: Vec<HeaderName> = Vec::new();
241        for (key, val) in headers {
242            let header_value = HeaderValue::from_str(&val)?;
243            let header_name: HeaderName = key.parse()?;
244            header_names.push(header_name.clone());
245            req_headers.insert(header_name, header_value);
246        }
247
248        connect_async(request).await.map(|resp| resp.0.split())
249    }
250
    /// Reconnect with server.
    ///
    /// Makes a new connection with the server, hands the new write half to the
    /// write task, and replaces the read task with one reading from the new
    /// read half. The heartbeat task is not touched here.
    ///
    /// The whole attempt is bounded by `self.reconnect_timeout`. If the
    /// connection mode is (or becomes) disconnect at any of the checkpoints
    /// below, the attempt is abandoned and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns the connection error if the new connection cannot be made, or
    /// a `TimedOut` I/O error if the attempt exceeds the reconnect timeout.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Attempt to connect; abort early if a disconnect was requested
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Hand the new write half to the write task; a failure here is
            // only logged and the reconnect carries on.
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
                tracing::error!("{e}");
            }

            // Delay before closing connection
            tokio::time::sleep(Duration::from_millis(100)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the previous read task (if any) before starting a new one
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // If a disconnect was requested during reconnect, do not proceed to reactivate
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
                return Ok(());
            }

            // Mark as active only if not disconnecting
            self.connection_mode
                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);

            // Spawn a read task for the new read half, mirroring initial connect
            self.read_task = match &self.config.handler {
                #[cfg(feature = "python")]
                Consumer::Python(handler) => handler.as_ref().map(|handler| {
                    Self::spawn_python_callback_task(
                        self.connection_mode.clone(),
                        reader,
                        handler.clone(),
                        self.config.ping_handler.clone(),
                    )
                }),
                Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
                    self.connection_mode.clone(),
                    reader,
                    sender.clone(),
                )),
            };

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Outer error: the timeout elapsed; the inner result is unwrapped by `?`
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
333
334    /// Check if the client is still connected.
335    ///
336    /// The client is connected if the read task has not finished. It is expected
337    /// that in case of any failure client or server side. The read task will be
338    /// shutdown or will receive a `Close` frame which will finish it. There
339    /// might be some delay between the connection being closed and the client
340    /// detecting.
341    #[inline]
342    #[must_use]
343    pub fn is_alive(&self) -> bool {
344        match &self.read_task {
345            Some(read_task) => !read_task.is_finished(),
346            None => true, // Stream is being used directly
347        }
348    }
349
350    fn spawn_rust_streaming_task(
351        connection_state: Arc<AtomicU8>,
352        mut reader: MessageReader,
353        sender: Sender<Message>,
354    ) -> tokio::task::JoinHandle<()> {
355        tracing::debug!("Started streaming task 'read'");
356
357        let check_interval = Duration::from_millis(10);
358
359        tokio::task::spawn(async move {
360            loop {
361                if !ConnectionMode::from_atomic(&connection_state).is_active() {
362                    break;
363                }
364
365                match tokio::time::timeout(check_interval, reader.next()).await {
366                    Ok(Some(Ok(message))) => {
367                        if let Err(e) = sender.send(message).await {
368                            tracing::error!("Failed to send message: {e}");
369                        }
370                    }
371                    Ok(Some(Err(e))) => {
372                        tracing::error!("Received error message - terminating: {e}");
373                        break;
374                    }
375                    Ok(None) => {
376                        tracing::debug!("No message received - terminating");
377                        break;
378                    }
379                    Err(_) => {
380                        // Timeout - continue loop and check connection mode
381                        continue;
382                    }
383                }
384            }
385        })
386    }
387
388    #[cfg(feature = "python")]
389    fn spawn_python_callback_task(
390        connection_state: Arc<AtomicU8>,
391        mut reader: MessageReader,
392        handler: Arc<PyObject>,
393        ping_handler: Option<Arc<PyObject>>,
394    ) -> tokio::task::JoinHandle<()> {
395        log_task_started("read");
396
397        // Interval between checking the connection mode
398        let check_interval = Duration::from_millis(10);
399
400        tokio::task::spawn(async move {
401            loop {
402                if !ConnectionMode::from_atomic(&connection_state).is_active() {
403                    break;
404                }
405
406                match tokio::time::timeout(check_interval, reader.next()).await {
407                    Ok(Some(Ok(Message::Binary(data)))) => {
408                        tracing::trace!("Received message <binary> {} bytes", data.len());
409                        if let Err(e) =
410                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
411                        {
412                            tracing::error!("Error calling handler: {e}");
413                            break;
414                        }
415                        continue;
416                    }
417                    Ok(Some(Ok(Message::Text(data)))) => {
418                        tracing::trace!("Received message: {data}");
419                        if let Err(e) = Python::with_gil(|py| {
420                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
421                        }) {
422                            tracing::error!("Error calling handler: {e}");
423                            break;
424                        }
425                        continue;
426                    }
427                    Ok(Some(Ok(Message::Ping(ping)))) => {
428                        tracing::trace!("Received ping: {ping:?}");
429                        if let Some(ref handler) = ping_handler
430                            && let Err(e) =
431                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
432                        {
433                            tracing::error!("Error calling handler: {e}");
434                            break;
435                        }
436                        continue;
437                    }
438                    Ok(Some(Ok(Message::Pong(_)))) => {
439                        tracing::trace!("Received pong");
440                    }
441                    Ok(Some(Ok(Message::Close(_)))) => {
442                        tracing::debug!("Received close message - terminating");
443                        break;
444                    }
445                    Ok(Some(Ok(_))) => (),
446                    Ok(Some(Err(e))) => {
447                        tracing::error!("Received error message - terminating: {e}");
448                        break;
449                    }
450                    // Internally tungstenite considers the connection closed when polling
451                    // for the next message in the stream returns None.
452                    Ok(None) => {
453                        tracing::debug!("No message received - terminating");
454                        break;
455                    }
456                    Err(_) => {
457                        // Timeout - continue loop and check connection mode
458                        continue;
459                    }
460                }
461            }
462        })
463    }
464
465    fn spawn_write_task(
466        connection_state: Arc<AtomicU8>,
467        writer: MessageWriter,
468        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
469    ) -> tokio::task::JoinHandle<()> {
470        log_task_started("write");
471
472        // Interval between checking the connection mode
473        let check_interval = Duration::from_millis(10);
474
475        tokio::task::spawn(async move {
476            let mut active_writer = writer;
477
478            loop {
479                match ConnectionMode::from_atomic(&connection_state) {
480                    ConnectionMode::Disconnect => {
481                        // Attempt to close the writer gracefully before exiting,
482                        // we ignore any error as the writer may already be closed.
483                        _ = active_writer.close().await;
484                        break;
485                    }
486                    ConnectionMode::Closed => break,
487                    _ => {}
488                }
489
490                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
491                    Ok(Some(msg)) => {
492                        // Re-check connection mode after receiving a message
493                        let mode = ConnectionMode::from_atomic(&connection_state);
494                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
495                            break;
496                        }
497
498                        match msg {
499                            WriterCommand::Update(new_writer) => {
500                                tracing::debug!("Received new writer");
501
502                                // Delay before closing connection
503                                tokio::time::sleep(Duration::from_millis(100)).await;
504
505                                // Attempt to close the writer gracefully on update,
506                                // we ignore any error as the writer may already be closed.
507                                _ = active_writer.close().await;
508
509                                active_writer = new_writer;
510                                tracing::debug!("Updated writer");
511                            }
512                            _ if mode.is_reconnect() => {
513                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
514                                continue;
515                            }
516                            WriterCommand::Send(msg) => {
517                                if let Err(e) = active_writer.send(msg).await {
518                                    tracing::error!("Failed to send message: {e}");
519                                    // Mode is active so trigger reconnection
520                                    tracing::warn!("Writer triggering reconnect");
521                                    connection_state
522                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
523                                }
524                            }
525                        }
526                    }
527                    Ok(None) => {
528                        // Channel closed - writer task should terminate
529                        tracing::debug!("Writer channel closed, terminating writer task");
530                        break;
531                    }
532                    Err(_) => {
533                        // Timeout - just continue the loop
534                        continue;
535                    }
536                }
537            }
538
539            // Attempt to close the writer gracefully before exiting,
540            // we ignore any error as the writer may already be closed.
541            _ = active_writer.close().await;
542
543            log_task_stopped("write");
544        })
545    }
546
547    fn spawn_heartbeat_task(
548        connection_state: Arc<AtomicU8>,
549        heartbeat_secs: u64,
550        message: Option<String>,
551        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
552    ) -> tokio::task::JoinHandle<()> {
553        log_task_started("heartbeat");
554
555        tokio::task::spawn(async move {
556            let interval = Duration::from_secs(heartbeat_secs);
557
558            loop {
559                tokio::time::sleep(interval).await;
560
561                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
562                    ConnectionMode::Active => {
563                        let msg = match &message {
564                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
565                            None => WriterCommand::Send(Message::Ping(vec![].into())),
566                        };
567
568                        match writer_tx.send(msg) {
569                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
570                            Err(e) => {
571                                tracing::error!("Failed to send heartbeat to writer task: {e}");
572                            }
573                        }
574                    }
575                    ConnectionMode::Reconnect => continue,
576                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
577                }
578            }
579
580            log_task_stopped("heartbeat");
581        })
582    }
583}
584
585impl Drop for WebSocketClientInner {
586    fn drop(&mut self) {
587        if let Some(ref read_task) = self.read_task.take()
588            && !read_task.is_finished()
589        {
590            read_task.abort();
591            log_task_aborted("read");
592        }
593
594        if !self.write_task.is_finished() {
595            self.write_task.abort();
596            log_task_aborted("write");
597        }
598
599        if let Some(ref handle) = self.heartbeat_task.take()
600            && !handle.is_finished()
601        {
602            handle.abort();
603            log_task_aborted("heartbeat");
604        }
605    }
606}
607
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task that owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the inner client and its tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender used to queue commands for the inner client's write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited on the given keys before messages are sent.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
622
623impl Debug for WebSocketClient {
624    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
625        f.debug_struct(stringify!(WebSocketClient)).finish()
626    }
627}
628
629impl WebSocketClient {
630    /// Creates a websocket client that returns a stream for reading messages.
631    ///
632    /// # Errors
633    ///
634    /// Returns any error connecting to the server.
635    #[allow(clippy::too_many_arguments)]
636    pub async fn connect_stream(
637        config: WebSocketConfig,
638        keyed_quotas: Vec<(String, Quota)>,
639        default_quota: Option<Quota>,
640        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
641    ) -> Result<(MessageReader, Self), Error> {
642        install_cryptographic_provider();
643        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
644        let (writer, reader) = ws_stream.split();
645        let inner = WebSocketClientInner::connect_url(config).await?;
646
647        let connection_mode = inner.connection_mode.clone();
648
649        let writer_tx = inner.writer_tx.clone();
650        if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
651            tracing::error!("{e}");
652        }
653
654        let controller_task = Self::spawn_controller_task(
655            inner,
656            connection_mode.clone(),
657            post_reconnect,
658            #[cfg(feature = "python")]
659            None, // no post_reconnection
660            #[cfg(feature = "python")]
661            None, // no post_disconnection
662        );
663
664        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
665
666        Ok((
667            reader,
668            Self {
669                controller_task,
670                connection_mode,
671                writer_tx,
672                rate_limiter,
673            },
674        ))
675    }
676
677    /// Creates a websocket client.
678    ///
679    /// Creates an inner client and controller task to reconnect or disconnect
680    /// the client. Also assumes ownership of writer from inner client.
681    ///
682    /// # Errors
683    ///
684    /// Returns any websocket error.
685    pub async fn connect(
686        config: WebSocketConfig,
687        #[cfg(feature = "python")] post_connection: Option<PyObject>,
688        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
689        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
690        keyed_quotas: Vec<(String, Quota)>,
691        default_quota: Option<Quota>,
692    ) -> Result<Self, Error> {
693        tracing::debug!("Connecting");
694        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
695        let connection_mode = inner.connection_mode.clone();
696        let writer_tx = inner.writer_tx.clone();
697
698        let controller_task = Self::spawn_controller_task(
699            inner,
700            connection_mode.clone(),
701            None, // Rust handler
702            #[cfg(feature = "python")]
703            post_reconnection, // TODO: Deprecated
704            #[cfg(feature = "python")]
705            post_disconnection, // TODO: Deprecated
706        );
707
708        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
709
710        #[cfg(feature = "python")]
711        if let Some(handler) = post_connection {
712            Python::with_gil(|py| match handler.call0(py) {
713                Ok(_) => tracing::debug!("Called `post_connection` handler"),
714                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
715            });
716        }
717
718        Ok(Self {
719            controller_task,
720            connection_mode,
721            writer_tx,
722            rate_limiter,
723        })
724    }
725
726    /// Returns the current connection mode.
727    #[must_use]
728    pub fn connection_mode(&self) -> ConnectionMode {
729        ConnectionMode::from_atomic(&self.connection_mode)
730    }
731
732    /// Check if the client connection is active.
733    ///
734    /// Returns `true` if the client is connected and has not been signalled to disconnect.
735    /// The client will automatically retry connection based on its configuration.
736    #[inline]
737    #[must_use]
738    pub fn is_active(&self) -> bool {
739        self.connection_mode().is_active()
740    }
741
742    /// Check if the client is disconnected.
743    #[must_use]
744    pub fn is_disconnected(&self) -> bool {
745        self.controller_task.is_finished()
746    }
747
748    /// Check if the client is reconnecting.
749    ///
750    /// Returns `true` if the client lost connection and is attempting to reestablish it.
751    /// The client will automatically retry connection based on its configuration.
752    #[inline]
753    #[must_use]
754    pub fn is_reconnecting(&self) -> bool {
755        self.connection_mode().is_reconnect()
756    }
757
758    /// Check if the client is disconnecting.
759    ///
760    /// Returns `true` if the client is in disconnect mode.
761    #[inline]
762    #[must_use]
763    pub fn is_disconnecting(&self) -> bool {
764        self.connection_mode().is_disconnect()
765    }
766
767    /// Check if the client is closed.
768    ///
769    /// Returns `true` if the client has been explicitly disconnected or reached
770    /// maximum reconnection attempts. In this state, the client cannot be reused
771    /// and a new client must be created for further connections.
772    #[inline]
773    #[must_use]
774    pub fn is_closed(&self) -> bool {
775        self.connection_mode().is_closed()
776    }
777
778    /// Set disconnect mode to true.
779    ///
780    /// Controller task will periodically check the disconnect mode
781    /// and shutdown the client if it is alive
782    pub async fn disconnect(&self) {
783        tracing::debug!("Disconnecting");
784        self.connection_mode
785            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
786
787        match tokio::time::timeout(Duration::from_secs(5), async {
788            while !self.is_disconnected() {
789                tokio::time::sleep(Duration::from_millis(10)).await;
790            }
791
792            if !self.controller_task.is_finished() {
793                self.controller_task.abort();
794                log_task_aborted("controller");
795            }
796        })
797        .await
798        {
799            Ok(()) => {
800                tracing::debug!("Controller task finished");
801            }
802            Err(_) => {
803                tracing::error!("Timeout waiting for controller task to finish");
804            }
805        }
806    }
807
808    /// Sends the given text `data` to the server.
809    ///
810    /// # Errors
811    ///
812    /// Returns a websocket error if unable to send.
813    #[allow(unused_variables)]
814    pub async fn send_text(
815        &self,
816        data: String,
817        keys: Option<Vec<String>>,
818    ) -> std::result::Result<(), SendError> {
819        self.rate_limiter.await_keys_ready(keys).await;
820
821        if !self.is_active() {
822            return Err(SendError::Closed);
823        }
824
825        tracing::trace!("Sending text: {data:?}");
826
827        let msg = Message::Text(data.into());
828        self.writer_tx
829            .send(WriterCommand::Send(msg))
830            .map_err(|e| SendError::BrokenPipe(e.to_string()))
831    }
832
833    /// Sends the given bytes `data` to the server.
834    ///
835    /// # Errors
836    ///
837    /// Returns a websocket error if unable to send.
838    #[allow(unused_variables)]
839    pub async fn send_bytes(
840        &self,
841        data: Vec<u8>,
842        keys: Option<Vec<String>>,
843    ) -> std::result::Result<(), SendError> {
844        self.rate_limiter.await_keys_ready(keys).await;
845
846        if !self.is_active() {
847            return Err(SendError::Closed);
848        }
849
850        tracing::trace!("Sending bytes: {data:?}");
851
852        let msg = Message::Binary(data.into());
853        self.writer_tx
854            .send(WriterCommand::Send(msg))
855            .map_err(|e| SendError::BrokenPipe(e.to_string()))
856    }
857
858    /// Sends a close message to the server.
859    ///
860    /// # Errors
861    ///
862    /// Returns a websocket error if unable to send.
863    pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
864        if !self.is_active() {
865            return Err(SendError::Closed);
866        }
867
868        let msg = Message::Close(None);
869        self.writer_tx
870            .send(WriterCommand::Send(msg))
871            .map_err(|e| SendError::BrokenPipe(e.to_string()))
872    }
873
    /// Spawns the controller task that supervises the connection lifecycle.
    ///
    /// Every 10 ms the task reads `connection_mode` and acts on it:
    ///
    /// - Disconnect mode: waits 100 ms, aborts the read and heartbeat tasks if they
    ///   are still running, calls the Python `post_disconnection` handler (only with
    ///   the `python` feature and when a handler was supplied), then leaves the loop.
    /// - Reconnect mode, or active mode while `inner` reports it is not alive: calls
    ///   `inner.reconnect()`. On success the backoff is reset and the
    ///   post-reconnection callbacks run, provided the mode is active at that point.
    ///   On failure the task sleeps for the next backoff duration before looping.
    ///
    /// After the loop exits, the connection mode held by `inner` is set to closed.
    ///
    /// Returns the join handle of the spawned task.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, // TODO: Deprecated
        #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, // TODO: Deprecated
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            // Polling period for the connection mode
            let check_interval = Duration::from_millis(10);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    // Upper bound for the shutdown sequence below
                    let timeout = Duration::from_secs(5);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(100)).await;

                        // Abort the read task only if it is still running
                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        // Abort the heartbeat task only if it is still running
                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    // Handler errors are logged and do not prevent the controller from finishing
                    #[cfg(feature = "python")]
                    if let Some(ref handler) = py_post_disconnection {
                        Python::with_gil(|py| match handler.call0(py) {
                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
                            Err(e) => {
                                tracing::error!("Error calling `post_disconnection` handler: {e}");
                            }
                        });
                    }
                    break; // Controller finished
                }

                // Reconnect when requested via the mode, or when the mode is active
                // but `inner` reports that it is no longer alive
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // Only invoke callbacks if not in disconnect state
                            // (the mode is re-read here because it may have changed
                            // while the reconnect was in progress)
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                }

                                // TODO: Python based websocket handlers deprecated (will be removed)
                                #[cfg(feature = "python")]
                                if let Some(ref callback) = py_post_reconnection {
                                    Python::with_gil(|py| match callback.call0(py) {
                                        Ok(_) => {
                                            tracing::debug!("Called `post_reconnection` handler");
                                        }
                                        Err(e) => tracing::error!(
                                            "Error calling `post_reconnection` handler: {e}"
                                        ),
                                    });
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Wait out the next backoff delay before the loop tries again
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Reached only via the `break` above: mark the client as closed
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
981}
982
983// Abort controller task on drop to clean up background tasks
984impl Drop for WebSocketClient {
985    fn drop(&mut self) {
986        if !self.controller_task.is_finished() {
987            self.controller_task.abort();
988            log_task_aborted("controller");
989        }
990    }
991}
992
993////////////////////////////////////////////////////////////////////////////////
994// Tests
995////////////////////////////////////////////////////////////////////////////////
996#[cfg(feature = "python")]
997#[cfg(test)]
998#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
999mod tests {
1000    use std::{num::NonZeroU32, sync::Arc};
1001
1002    use futures_util::{SinkExt, StreamExt};
1003    use tokio::{
1004        net::TcpListener,
1005        task::{self, JoinHandle},
1006    };
1007    use tokio_tungstenite::{
1008        accept_hdr_async,
1009        tungstenite::{
1010            handshake::server::{self, Callback},
1011            http::HeaderValue,
1012        },
1013    };
1014
1015    use crate::{
1016        ratelimiter::quota::Quota,
1017        websocket::{Consumer, WebSocketClient, WebSocketConfig},
1018    };
1019
    /// Local WebSocket echo server used by the tests in this module.
    struct TestServer {
        /// Handle of the spawned accept loop; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to.
        port: u16,
    }
1024
    /// Handshake callback that asserts each incoming request carries an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header must have.
        value: HeaderValue,
    }
1030
1031    impl Callback for TestCallback {
1032        fn on_request(
1033            self,
1034            request: &server::Request,
1035            response: server::Response,
1036        ) -> Result<server::Response, server::ErrorResponse> {
1037            let _ = response;
1038            let value = request.headers().get(&self.key);
1039            assert!(value.is_some());
1040
1041            if let Some(value) = request.headers().get(&self.key) {
1042                assert_eq!(value, self.value);
1043            }
1044
1045            Ok(response)
1046        }
1047    }
1048
1049    impl TestServer {
1050        async fn setup() -> Self {
1051            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1052            let port = TcpListener::local_addr(&server).unwrap().port();
1053
1054            let header_key = "test".to_string();
1055            let header_value = "test".to_string();
1056
1057            let test_call_back = TestCallback {
1058                key: header_key,
1059                value: HeaderValue::from_str(&header_value).unwrap(),
1060            };
1061
1062            let task = task::spawn(async move {
1063                // Keep accepting connections
1064                loop {
1065                    let (conn, _) = server.accept().await.unwrap();
1066                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1067                        .await
1068                        .unwrap();
1069
1070                    task::spawn(async move {
1071                        while let Some(Ok(msg)) = websocket.next().await {
1072                            match msg {
1073                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1074                                    if txt == "close-now" =>
1075                                {
1076                                    tracing::debug!("Forcibly closing from server side");
1077                                    // This sends a close frame, then stops reading
1078                                    let _ = websocket.close(None).await;
1079                                    break;
1080                                }
1081                                // Echo text/binary frames
1082                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1083                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1084                                    if websocket.send(msg).await.is_err() {
1085                                        break;
1086                                    }
1087                                }
1088                                // If the client closes, we also break
1089                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1090                                    _frame,
1091                                ) => {
1092                                    let _ = websocket.close(None).await;
1093                                    break;
1094                                }
1095                                // Ignore pings/pongs
1096                                _ => {}
1097                            }
1098                        }
1099                    });
1100                }
1101            });
1102
1103            Self { task, port }
1104        }
1105    }
1106
    impl Drop for TestServer {
        /// Aborts the accept-loop task when the server goes out of scope.
        fn drop(&mut self) {
            self.task.abort();
        }
    }
1112
1113    async fn setup_test_client(port: u16) -> WebSocketClient {
1114        let config = WebSocketConfig {
1115            url: format!("ws://127.0.0.1:{port}"),
1116            headers: vec![("test".into(), "test".into())],
1117            handler: Consumer::Python(None),
1118            heartbeat: None,
1119            heartbeat_msg: None,
1120            ping_handler: None,
1121            reconnect_timeout_ms: None,
1122            reconnect_delay_initial_ms: None,
1123            reconnect_backoff_factor: None,
1124            reconnect_delay_max_ms: None,
1125            reconnect_jitter_ms: None,
1126        };
1127        WebSocketClient::connect(config, None, None, None, vec![], None)
1128            .await
1129            .expect("Failed to connect")
1130    }
1131
    /// A freshly connected client is not disconnected, and is disconnected after `disconnect()`.
    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // The controller task is still running right after connecting
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1142
1143    #[tokio::test]
1144    async fn test_websocket_heartbeat() {
1145        let server = TestServer::setup().await;
1146        let client = setup_test_client(server.port).await;
1147
1148        // Wait ~3s => server should see multiple "ping"
1149        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1150
1151        // Cleanup
1152        client.disconnect().await;
1153        assert!(client.is_disconnected());
1154    }
1155
1156    #[tokio::test]
1157    async fn test_websocket_reconnect_exhausted() {
1158        let config = WebSocketConfig {
1159            url: "ws://127.0.0.1:9997".into(), // <-- No server
1160            headers: vec![],
1161            handler: Consumer::Python(None),
1162            heartbeat: None,
1163            heartbeat_msg: None,
1164            ping_handler: None,
1165            reconnect_timeout_ms: None,
1166            reconnect_delay_initial_ms: None,
1167            reconnect_backoff_factor: None,
1168            reconnect_delay_max_ms: None,
1169            reconnect_jitter_ms: None,
1170        };
1171        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
1172        assert!(res.is_err(), "Should fail quickly with no server");
1173    }
1174
    /// After the server closes the connection, the client must not report itself
    /// as disconnected, and must still disconnect cleanly on request.
    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        //    (the test server closes the connection when it receives this text)
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1196
    /// Sends three messages through a client configured with a 2-per-second quota
    /// under the key `default`; all three sends must succeed.
    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };

        let client = WebSocketClient::connect(
            config,
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // NOTE(review): every send below passes `keys: None`, so it is unclear from
        // this test whether the `default` quota is consulted at all; confirm
        // against the rate limiter and pass the key explicitly if it is not.

        // First 2 should succeed
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // Third must succeed as well: `send_text` awaits the rate limiter instead
        // of returning an error, and the `unwrap` would panic on failure
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
1238
1239    #[tokio::test]
1240    async fn test_concurrent_writers() {
1241        let server = TestServer::setup().await;
1242        let client = Arc::new(setup_test_client(server.port).await);
1243
1244        let mut handles = vec![];
1245        for i in 0..10 {
1246            let client = client.clone();
1247            handles.push(task::spawn(async move {
1248                client.send_text(format!("test{i}"), None).await.unwrap();
1249            }));
1250        }
1251
1252        for handle in handles {
1253            handle.await.unwrap();
1254        }
1255
1256        // Cleanup
1257        client.disconnect().await;
1258        assert!(client.is_disconnected());
1259    }
1260}
1261
1262#[cfg(test)]
1263mod rust_tests {
1264    use tokio::{
1265        net::TcpListener,
1266        task,
1267        time::{Duration, sleep},
1268    };
1269    use tokio_tungstenite::accept_async;
1270
1271    use super::*;
1272
1273    #[tokio::test]
1274    async fn test_reconnect_then_disconnect() {
1275        // Bind an ephemeral port
1276        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1277        let port = listener.local_addr().unwrap().port();
1278
1279        // Server task: accept one ws connection then close it
1280        let server = task::spawn(async move {
1281            let (stream, _) = listener.accept().await.unwrap();
1282            let ws = accept_async(stream).await.unwrap();
1283            drop(ws);
1284            // Keep alive briefly
1285            sleep(Duration::from_secs(1)).await;
1286        });
1287
1288        // Build a rust consumer for incoming messages (unused here)
1289        let (consumer, _rx) = Consumer::rust_consumer();
1290
1291        // Configure client with short reconnect backoff
1292        let config = WebSocketConfig {
1293            url: format!("ws://127.0.0.1:{port}"),
1294            headers: vec![],
1295            handler: consumer,
1296            heartbeat: None,
1297            heartbeat_msg: None,
1298            #[cfg(feature = "python")]
1299            ping_handler: None,
1300            reconnect_timeout_ms: Some(1_000),
1301            reconnect_delay_initial_ms: Some(50),
1302            reconnect_delay_max_ms: Some(100),
1303            reconnect_backoff_factor: Some(1.0),
1304            reconnect_jitter_ms: Some(0),
1305        };
1306
1307        // Connect the client
1308        let client = {
1309            #[cfg(feature = "python")]
1310            {
1311                WebSocketClient::connect(config.clone(), None, None, None, vec![], None)
1312                    .await
1313                    .unwrap()
1314            }
1315            #[cfg(not(feature = "python"))]
1316            {
1317                WebSocketClient::connect(config.clone(), vec![], None)
1318                    .await
1319                    .unwrap()
1320            }
1321        };
1322
1323        // Allow server to drop connection and client to detect
1324        sleep(Duration::from_millis(100)).await;
1325        // Now immediately disconnect the client
1326        client.disconnect().await;
1327        assert!(client.is_disconnected());
1328        server.abort();
1329    }
1330}