nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance WebSocket client implementation with automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    sync::{
33        Arc,
34        atomic::{AtomicU8, Ordering},
35    },
36    time::Duration,
37};
38
39use futures_util::{
40    SinkExt, StreamExt,
41    stream::{SplitSink, SplitStream},
42};
43use http::HeaderName;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use pyo3::{prelude::*, types::PyBytes};
46use tokio::{
47    net::TcpStream,
48    sync::mpsc::{self, Receiver, Sender},
49};
50use tokio_tungstenite::{
51    MaybeTlsStream, WebSocketStream, connect_async,
52    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
53};
54
55use crate::{
56    backoff::ExponentialBackoff,
57    mode::ConnectionMode,
58    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
59};
60type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
61pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
62
/// Destination for messages received from the server.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// An optional Python callable which the read task invokes with each
    /// binary or text payload as bytes. When `None`, no read task is spawned.
    Python(Option<Arc<PyObject>>),
    /// A channel sender onto which every received [`Message`] is forwarded.
    Rust(Sender<Message>),
}
68
69impl Consumer {
70    #[must_use]
71    pub fn rust_consumer() -> (Self, Receiver<Message>) {
72        let (tx, rx) = mpsc::channel(100);
73        (Self::Rust(tx), rx)
74    }
75}
76
/// Configuration for a websocket client connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// The headers added to the connection request, as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// The consumer of incoming messages (a Python callable or a Rust channel).
    pub handler: Consumer,
    /// The optional heartbeat interval (seconds).
    pub heartbeat: Option<u64>,
    /// The optional heartbeat message, sent as text. When `None` an empty ping
    /// frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// The optional Python handler called with the payload of incoming pings.
    pub ping_handler: Option<Arc<PyObject>>,
    /// The timeout (milliseconds) for reconnection attempts (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
106
/// Represents a command for the writer task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer with a new one after reconnection. The writer
    /// task closes the previous writer before switching to the new one.
    Update(MessageWriter),
    /// Send message to the server.
    Send(Message),
}
115
/// `WebSocketClientInner` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. It moves
/// the read half into a tokio task which keeps receiving messages from the
/// server and passes them to the configured consumer - either a Python
/// callable that takes the data as its parameter, or a Rust channel. The
/// write half is moved into a dedicated writer task which is driven through
/// a command channel; cloning the channel sender allows data to be written
/// to the server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// Connection configuration; reused when reconnecting.
    config: WebSocketConfig,
    /// Handle of the read task (`None` when a Python consumer has no handler).
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task owning the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender for commands consumed by the writer task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task (`None` when no heartbeat is configured).
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared between the client and its tasks, stored as a `u8`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff state for reconnection delays.
    // NOTE(review): not read by any method visible in this section; presumably
    // consumed by the controller task - confirm.
    backoff: ExponentialBackoff,
}
141
142impl WebSocketClientInner {
143    /// Create an inner websocket client.
144    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
145        install_cryptographic_provider();
146
147        #[allow(unused_variables)]
148        let WebSocketConfig {
149            url,
150            handler,
151            heartbeat,
152            headers,
153            heartbeat_msg,
154            ping_handler,
155            reconnect_timeout_ms,
156            reconnect_delay_initial_ms,
157            reconnect_delay_max_ms,
158            reconnect_backoff_factor,
159            reconnect_jitter_ms,
160        } = &config;
161        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
162
163        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
164
165        let read_task = match &handler {
166            Consumer::Python(handler) => handler.as_ref().map(|handler| {
167                Self::spawn_python_callback_task(
168                    connection_mode.clone(),
169                    reader,
170                    handler.clone(),
171                    ping_handler.clone(),
172                )
173            }),
174            Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
175                connection_mode.clone(),
176                reader,
177                sender.clone(),
178            )),
179        };
180
181        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
182        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
183
184        // Optionally spawn a heartbeat task to periodically ping server
185        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
186            Self::spawn_heartbeat_task(
187                connection_mode.clone(),
188                *heartbeat_secs,
189                heartbeat_msg.clone(),
190                writer_tx.clone(),
191            )
192        });
193
194        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
195        let backoff = ExponentialBackoff::new(
196            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
197            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
198            reconnect_backoff_factor.unwrap_or(1.5),
199            reconnect_jitter_ms.unwrap_or(100),
200            true, // immediate-first
201        );
202
203        Ok(Self {
204            config,
205            read_task,
206            write_task,
207            writer_tx,
208            heartbeat_task,
209            connection_mode,
210            reconnect_timeout,
211            backoff,
212        })
213    }
214
215    /// Connects with the server creating a tokio-tungstenite websocket stream.
216    #[inline]
217    pub async fn connect_with_server(
218        url: &str,
219        headers: Vec<(String, String)>,
220    ) -> Result<(MessageWriter, MessageReader), Error> {
221        let mut request = url.into_client_request()?;
222        let req_headers = request.headers_mut();
223
224        let mut header_names: Vec<HeaderName> = Vec::new();
225        for (key, val) in headers {
226            let header_value = HeaderValue::from_str(&val)?;
227            let header_name: HeaderName = key.parse()?;
228            header_names.push(header_name.clone());
229            req_headers.insert(header_name, header_value);
230        }
231
232        connect_async(request).await.map(|resp| resp.0.split())
233    }
234
235    /// Reconnect with server.
236    ///
237    /// Make a new connection with server. Use the new read and write halves
238    /// to update self writer and read and heartbeat tasks.
239    pub async fn reconnect(&mut self) -> Result<(), Error> {
240        tracing::debug!("Reconnecting");
241
242        tokio::time::timeout(self.reconnect_timeout, async {
243            let (new_writer, reader) =
244                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
245
246            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
247                tracing::error!("{e}");
248            }
249
250            // Delay before closing connection
251            tokio::time::sleep(Duration::from_millis(100)).await;
252
253            if let Some(ref read_task) = self.read_task.take() {
254                if !read_task.is_finished() {
255                    read_task.abort();
256                    tracing::debug!("Aborted task 'read'");
257                }
258            }
259
260            self.connection_mode
261                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
262
263            self.read_task = match &self.config.handler {
264                Consumer::Python(handler) => handler.as_ref().map(|handler| {
265                    Self::spawn_python_callback_task(
266                        self.connection_mode.clone(),
267                        reader,
268                        handler.clone(),
269                        self.config.ping_handler.clone(),
270                    )
271                }),
272                Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
273                    self.connection_mode.clone(),
274                    reader,
275                    sender.clone(),
276                )),
277            };
278
279            tracing::debug!("Reconnect succeeded");
280            Ok(())
281        })
282        .await
283        .map_err(|_| {
284            Error::Io(std::io::Error::new(
285                std::io::ErrorKind::TimedOut,
286                format!(
287                    "reconnection timed out after {}s",
288                    self.reconnect_timeout.as_secs_f64()
289                ),
290            ))
291        })?
292    }
293
294    /// Check if the client is still connected.
295    ///
296    /// The client is connected if the read task has not finished. It is expected
297    /// that in case of any failure client or server side. The read task will be
298    /// shutdown or will receive a `Close` frame which will finish it. There
299    /// might be some delay between the connection being closed and the client
300    /// detecting.
301    #[inline]
302    #[must_use]
303    pub fn is_alive(&self) -> bool {
304        match &self.read_task {
305            Some(read_task) => !read_task.is_finished(),
306            None => true, // Stream is being used directly
307        }
308    }
309
310    fn spawn_rust_streaming_task(
311        connection_state: Arc<AtomicU8>,
312        mut reader: MessageReader,
313        sender: Sender<Message>,
314    ) -> tokio::task::JoinHandle<()> {
315        tracing::debug!("Started streaming task 'read'");
316
317        let check_interval = Duration::from_millis(10);
318
319        tokio::task::spawn(async move {
320            loop {
321                if !ConnectionMode::from_atomic(&connection_state).is_active() {
322                    break;
323                }
324
325                match tokio::time::timeout(check_interval, reader.next()).await {
326                    Ok(Some(Ok(message))) => {
327                        if let Err(e) = sender.send(message).await {
328                            tracing::error!("Failed to send message: {e}");
329                        }
330                    }
331                    Ok(Some(Err(e))) => {
332                        tracing::error!("Received error message - terminating: {e}");
333                        break;
334                    }
335                    Ok(None) => {
336                        tracing::debug!("No message received - terminating");
337                        break;
338                    }
339                    Err(_) => {
340                        // Timeout - continue loop and check connection mode
341                        continue;
342                    }
343                }
344            }
345        })
346    }
347
348    fn spawn_python_callback_task(
349        connection_state: Arc<AtomicU8>,
350        mut reader: MessageReader,
351        handler: Arc<PyObject>,
352        ping_handler: Option<Arc<PyObject>>,
353    ) -> tokio::task::JoinHandle<()> {
354        tracing::debug!("Started task 'read'");
355
356        // Interval between checking the connection mode
357        let check_interval = Duration::from_millis(10);
358
359        tokio::task::spawn(async move {
360            loop {
361                if !ConnectionMode::from_atomic(&connection_state).is_active() {
362                    break;
363                }
364
365                match tokio::time::timeout(check_interval, reader.next()).await {
366                    Ok(Some(Ok(Message::Binary(data)))) => {
367                        tracing::trace!("Received message <binary> {} bytes", data.len());
368                        if let Err(e) =
369                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
370                        {
371                            tracing::error!("Error calling handler: {e}");
372                            break;
373                        }
374                        continue;
375                    }
376                    Ok(Some(Ok(Message::Text(data)))) => {
377                        tracing::trace!("Received message: {data}");
378                        if let Err(e) = Python::with_gil(|py| {
379                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
380                        }) {
381                            tracing::error!("Error calling handler: {e}");
382                            break;
383                        }
384                        continue;
385                    }
386                    Ok(Some(Ok(Message::Ping(ping)))) => {
387                        tracing::trace!("Received ping: {ping:?}");
388                        if let Some(ref handler) = ping_handler {
389                            if let Err(e) =
390                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
391                            {
392                                tracing::error!("Error calling handler: {e}");
393                                break;
394                            }
395                        }
396                        continue;
397                    }
398                    Ok(Some(Ok(Message::Pong(_)))) => {
399                        tracing::trace!("Received pong");
400                    }
401                    Ok(Some(Ok(Message::Close(_)))) => {
402                        tracing::debug!("Received close message - terminating");
403                        break;
404                    }
405                    Ok(Some(Ok(_))) => (),
406                    Ok(Some(Err(e))) => {
407                        tracing::error!("Received error message - terminating: {e}");
408                        break;
409                    }
410                    // Internally tungstenite considers the connection closed when polling
411                    // for the next message in the stream returns None.
412                    Ok(None) => {
413                        tracing::debug!("No message received - terminating");
414                        break;
415                    }
416                    Err(_) => {
417                        // Timeout - continue loop and check connection mode
418                        continue;
419                    }
420                }
421            }
422        })
423    }
424
    /// Spawns the write task which owns the write half of the connection.
    ///
    /// The task consumes [`WriterCommand`]s from `writer_rx`: `Update` swaps in
    /// a new writer (closing the previous one), `Send` writes a message to the
    /// server. Commands other than `Update` are skipped while the connection
    /// mode is reconnect, and a failed send switches the mode to reconnect.
    /// The task ends when the mode becomes disconnect or closed, or when the
    /// command channel is closed; the active writer is closed on the way out.
    fn spawn_write_task(
        connection_state: Arc<AtomicU8>,
        writer: MessageWriter,
        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
    ) -> tokio::task::JoinHandle<()> {
        tracing::debug!("Started task 'write'");

        // Interval between checking the connection mode
        let check_interval = Duration::from_millis(10);

        tokio::task::spawn(async move {
            let mut active_writer = writer;

            loop {
                match ConnectionMode::from_atomic(&connection_state) {
                    ConnectionMode::Disconnect => {
                        // Attempt to close the writer gracefully before exiting,
                        // we ignore any error as the writer may already be closed.
                        _ = active_writer.close().await;
                        break;
                    }
                    ConnectionMode::Closed => break,
                    _ => {}
                }

                // Wait for the next command, but no longer than the check interval
                // so that a mode change is noticed promptly
                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
                    Ok(Some(msg)) => {
                        // Re-check connection mode after receiving a message
                        // (the received command is dropped when shutting down)
                        let mode = ConnectionMode::from_atomic(&connection_state);
                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                            break;
                        }

                        match msg {
                            // Handled first so a writer update is applied even while reconnecting
                            WriterCommand::Update(new_writer) => {
                                tracing::debug!("Received new writer");

                                // Delay before closing connection
                                tokio::time::sleep(Duration::from_millis(100)).await;

                                // Attempt to close the writer gracefully on update,
                                // we ignore any error as the writer may already be closed.
                                _ = active_writer.close().await;

                                active_writer = new_writer;
                                tracing::debug!("Updated writer");
                            }
                            // Any other command is discarded while reconnecting
                            _ if mode.is_reconnect() => {
                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
                                continue;
                            }
                            WriterCommand::Send(msg) => {
                                if let Err(e) = active_writer.send(msg).await {
                                    tracing::error!("Failed to send message: {e}");
                                    // Mode is active so trigger reconnection
                                    tracing::warn!("Writer triggering reconnect");
                                    connection_state
                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                }
                            }
                        }
                    }
                    Ok(None) => {
                        // Channel closed - writer task should terminate
                        tracing::debug!("Writer channel closed, terminating writer task");
                        break;
                    }
                    Err(_) => {
                        // Timeout - just continue the loop
                        continue;
                    }
                }
            }

            // Attempt to close the writer gracefully before exiting,
            // we ignore any error as the writer may already be closed.
            // (On the disconnect path this is a second close attempt.)
            _ = active_writer.close().await;

            tracing::debug!("Completed task 'write'");
        })
    }
506
507    fn spawn_heartbeat_task(
508        connection_state: Arc<AtomicU8>,
509        heartbeat_secs: u64,
510        message: Option<String>,
511        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
512    ) -> tokio::task::JoinHandle<()> {
513        tracing::debug!("Started task 'heartbeat'");
514
515        tokio::task::spawn(async move {
516            let interval = Duration::from_secs(heartbeat_secs);
517
518            loop {
519                tokio::time::sleep(interval).await;
520
521                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
522                    ConnectionMode::Active => {
523                        let msg = match &message {
524                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
525                            None => WriterCommand::Send(Message::Ping(vec![].into())),
526                        };
527
528                        match writer_tx.send(msg) {
529                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
530                            Err(e) => {
531                                tracing::error!("Failed to send heartbeat to writer task: {e}");
532                            }
533                        }
534                    }
535                    ConnectionMode::Reconnect => continue,
536                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
537                }
538            }
539
540            tracing::debug!("Completed task 'heartbeat'");
541        })
542    }
543}
544
545impl Drop for WebSocketClientInner {
546    fn drop(&mut self) {
547        if let Some(ref read_task) = self.read_task.take() {
548            if !read_task.is_finished() {
549                read_task.abort();
550                tracing::debug!("Aborted task 'read'");
551            }
552        }
553
554        if !self.write_task.is_finished() {
555            self.write_task.abort();
556            tracing::debug!("Aborted task 'write'");
557        }
558
559        if let Some(ref handle) = self.heartbeat_task.take() {
560            if !handle.is_finished() {
561                handle.abort();
562                tracing::debug!("Aborted task 'heartbeat'");
563            }
564        }
565    }
566}
567
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task which owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the inner client and its tasks, stored as a `u8`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender for commands consumed by the inner client's writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited before text and binary messages are queued for sending.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
582
583impl WebSocketClient {
584    /// Creates a websocket client that returns a stream for reading messages.
585    ///
586    /// # Errors
587    ///
588    /// Returns any error connecting to the server.
589    #[allow(clippy::too_many_arguments)]
590    pub async fn connect_stream(
591        config: WebSocketConfig,
592        keyed_quotas: Vec<(String, Quota)>,
593        default_quota: Option<Quota>,
594        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
595    ) -> Result<(MessageReader, Self), Error> {
596        install_cryptographic_provider();
597        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
598        let (writer, reader) = ws_stream.split();
599        let inner = WebSocketClientInner::connect_url(config).await?;
600
601        let connection_mode = inner.connection_mode.clone();
602        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
603        let writer_tx = inner.writer_tx.clone();
604        if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
605            tracing::error!("{e}");
606        }
607
608        let controller_task = Self::spawn_controller_task(
609            inner,
610            connection_mode.clone(),
611            post_reconnect,
612            None, // no post_reconnection
613            None, // no post_disconnection
614        );
615
616        Ok((
617            reader,
618            Self {
619                controller_task,
620                connection_mode,
621                writer_tx,
622                rate_limiter,
623            },
624        ))
625    }
626
627    /// Creates a websocket client.
628    ///
629    /// Creates an inner client and controller task to reconnect or disconnect
630    /// the client. Also assumes ownership of writer from inner client.
631    ///
632    /// # Errors
633    ///
634    /// Returns any websocket error.
635    pub async fn connect(
636        config: WebSocketConfig,
637        post_connection: Option<PyObject>,
638        post_reconnection: Option<PyObject>,
639        post_disconnection: Option<PyObject>,
640        keyed_quotas: Vec<(String, Quota)>,
641        default_quota: Option<Quota>,
642    ) -> Result<Self, Error> {
643        tracing::debug!("Connecting");
644        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
645        let connection_mode = inner.connection_mode.clone();
646        let writer_tx = inner.writer_tx.clone();
647
648        let controller_task = Self::spawn_controller_task(
649            inner,
650            connection_mode.clone(),
651            None,               // Rust handler
652            post_reconnection,  // TODO: Deprecated
653            post_disconnection, // TODO: Deprecated
654        );
655        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
656
657        if let Some(handler) = post_connection {
658            Python::with_gil(|py| match handler.call0(py) {
659                Ok(_) => tracing::debug!("Called `post_connection` handler"),
660                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
661            });
662        }
663
664        Ok(Self {
665            controller_task,
666            connection_mode,
667            writer_tx,
668            rate_limiter,
669        })
670    }
671
672    /// Returns the current connection mode.
673    #[must_use]
674    pub fn connection_mode(&self) -> ConnectionMode {
675        ConnectionMode::from_atomic(&self.connection_mode)
676    }
677
678    /// Check if the client connection is active.
679    ///
680    /// Returns `true` if the client is connected and has not been signalled to disconnect.
681    /// The client will automatically retry connection based on its configuration.
682    #[inline]
683    #[must_use]
684    pub fn is_active(&self) -> bool {
685        self.connection_mode().is_active()
686    }
687
688    /// Check if the client is disconnected.
689    #[must_use]
690    pub fn is_disconnected(&self) -> bool {
691        self.controller_task.is_finished()
692    }
693
694    /// Check if the client is reconnecting.
695    ///
696    /// Returns `true` if the client lost connection and is attempting to reestablish it.
697    /// The client will automatically retry connection based on its configuration.
698    #[inline]
699    #[must_use]
700    pub fn is_reconnecting(&self) -> bool {
701        self.connection_mode().is_reconnect()
702    }
703
704    /// Check if the client is disconnecting.
705    ///
706    /// Returns `true` if the client is in disconnect mode.
707    #[inline]
708    #[must_use]
709    pub fn is_disconnecting(&self) -> bool {
710        self.connection_mode().is_disconnect()
711    }
712
713    /// Check if the client is closed.
714    ///
715    /// Returns `true` if the client has been explicitly disconnected or reached
716    /// maximum reconnection attempts. In this state, the client cannot be reused
717    /// and a new client must be created for further connections.
718    #[inline]
719    #[must_use]
720    pub fn is_closed(&self) -> bool {
721        self.connection_mode().is_closed()
722    }
723
724    /// Set disconnect mode to true.
725    ///
726    /// Controller task will periodically check the disconnect mode
727    /// and shutdown the client if it is alive
728    pub async fn disconnect(&self) {
729        tracing::debug!("Disconnecting");
730        self.connection_mode
731            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
732
733        match tokio::time::timeout(Duration::from_secs(5), async {
734            while !self.is_disconnected() {
735                tokio::time::sleep(Duration::from_millis(10)).await;
736            }
737
738            if !self.controller_task.is_finished() {
739                self.controller_task.abort();
740                tracing::debug!("Aborted task 'controller'");
741            }
742        })
743        .await
744        {
745            Ok(()) => {
746                tracing::debug!("Controller task finished");
747            }
748            Err(_) => {
749                tracing::error!("Timeout waiting for controller task to finish");
750            }
751        }
752    }
753
754    /// Sends the given text `data` to the server.
755    pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) {
756        self.rate_limiter.await_keys_ready(keys).await;
757
758        if !self.is_active() {
759            tracing::error!("Cannot send data - connection not active");
760            return;
761        }
762
763        tracing::trace!("Sending text: {data:?}");
764
765        let msg = Message::Text(data.into());
766        if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
767            tracing::error!("Error sending message: {e}");
768        }
769    }
770
771    /// Sends the given bytes `data` to the server.
772    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) {
773        self.rate_limiter.await_keys_ready(keys).await;
774
775        if !self.is_active() {
776            tracing::error!("Cannot send data - connection not active");
777            return;
778        }
779
780        tracing::trace!("Sending bytes: {data:?}");
781
782        let msg = Message::Binary(data.into());
783        if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
784            tracing::error!("Error sending message: {e}");
785        }
786    }
787
788    /// Sends a close message to the server.
789    pub async fn send_close_message(&self) {
790        if !self.is_active() {
791            tracing::error!("Cannot send close message - connection not active");
792            return;
793        }
794
795        let msg = Message::Close(None);
796        if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
797            tracing::error!("Error sending close message: {e}");
798        }
799    }
800
801    fn spawn_controller_task(
802        mut inner: WebSocketClientInner,
803        connection_mode: Arc<AtomicU8>,
804        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
805        py_post_reconnection: Option<PyObject>, // TODO: Deprecated
806        py_post_disconnection: Option<PyObject>, // TODO: Deprecated
807    ) -> tokio::task::JoinHandle<()> {
808        tokio::task::spawn(async move {
809            tracing::debug!("Started task 'controller'");
810
811            let check_interval = Duration::from_millis(10);
812
813            loop {
814                tokio::time::sleep(check_interval).await;
815                let mode = ConnectionMode::from_atomic(&connection_mode);
816
817                if mode.is_disconnect() {
818                    tracing::debug!("Disconnecting");
819
820                    let timeout = Duration::from_secs(5);
821                    if tokio::time::timeout(timeout, async {
822                        // Delay awaiting graceful shutdown
823                        tokio::time::sleep(Duration::from_millis(100)).await;
824
825                        if let Some(task) = &inner.read_task {
826                            if !task.is_finished() {
827                                task.abort();
828                                tracing::debug!("Aborted task 'read'");
829                            }
830                        }
831
832                        if let Some(task) = &inner.heartbeat_task {
833                            if !task.is_finished() {
834                                task.abort();
835                                tracing::debug!("Aborted task 'heartbeat'");
836                            }
837                        }
838                    })
839                    .await
840                    .is_err()
841                    {
842                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
843                    }
844
845                    tracing::debug!("Closed");
846
847                    if let Some(ref handler) = py_post_disconnection {
848                        Python::with_gil(|py| match handler.call0(py) {
849                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
850                            Err(e) => {
851                                tracing::error!("Error calling `post_disconnection` handler: {e}");
852                            }
853                        });
854                    }
855                    break; // Controller finished
856                }
857
858                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
859                    match inner.reconnect().await {
860                        Ok(()) => {
861                            tracing::debug!("Reconnected successfully");
862                            inner.backoff.reset();
863
864                            if let Some(ref callback) = post_reconnection {
865                                callback();
866                            }
867
868                            // TODO: Python based websocket handlers deprecated (will be removed)
869                            if let Some(ref handler) = py_post_reconnection {
870                                Python::with_gil(|py| match handler.call0(py) {
871                                    Ok(_) => {
872                                        tracing::debug!("Called `post_reconnection` handler");
873                                    }
874                                    Err(e) => tracing::error!(
875                                        "Error calling `post_reconnection` handler: {e}"
876                                    ),
877                                });
878                            }
879                        }
880                        Err(e) => {
881                            let duration = inner.backoff.next_duration();
882                            tracing::warn!("Reconnect attempt failed: {e}");
883                            if !duration.is_zero() {
884                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
885                            }
886                            tokio::time::sleep(duration).await;
887                        }
888                    }
889                }
890            }
891            inner
892                .connection_mode
893                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
894
895            tracing::debug!("Completed task 'controller'");
896        })
897    }
898}
899
900////////////////////////////////////////////////////////////////////////////////
901// Tests
902////////////////////////////////////////////////////////////////////////////////
903#[cfg(test)]
904#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
905mod tests {
906    use std::{num::NonZeroU32, sync::Arc};
907
908    use futures_util::{SinkExt, StreamExt};
909    use tokio::{
910        net::TcpListener,
911        task::{self, JoinHandle},
912    };
913    use tokio_tungstenite::{
914        accept_hdr_async,
915        tungstenite::{
916            handshake::server::{self, Callback},
917            http::HeaderValue,
918        },
919    };
920
921    use crate::{
922        ratelimiter::quota::Quota,
923        websocket::{Consumer, WebSocketClient, WebSocketConfig},
924    };
925
    /// Handle to the local WebSocket echo server used by the tests in this module.
    struct TestServer {
        /// Task running the accept loop; aborted when the handle is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
    }
930
    /// Handshake callback which asserts that a request carries an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the handshake request.
        key: String,
        /// Value the header is expected to have.
        value: HeaderValue,
    }
936
937    impl Callback for TestCallback {
938        fn on_request(
939            self,
940            request: &server::Request,
941            response: server::Response,
942        ) -> Result<server::Response, server::ErrorResponse> {
943            let _ = response;
944            let value = request.headers().get(&self.key);
945            assert!(value.is_some());
946
947            if let Some(value) = request.headers().get(&self.key) {
948                assert_eq!(value, self.value);
949            }
950
951            Ok(response)
952        }
953    }
954
955    impl TestServer {
956        async fn setup() -> Self {
957            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
958            let port = TcpListener::local_addr(&server).unwrap().port();
959
960            let header_key = "test".to_string();
961            let header_value = "test".to_string();
962
963            let test_call_back = TestCallback {
964                key: header_key,
965                value: HeaderValue::from_str(&header_value).unwrap(),
966            };
967
968            let task = task::spawn(async move {
969                // Keep accepting connections
970                loop {
971                    let (conn, _) = server.accept().await.unwrap();
972                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
973                        .await
974                        .unwrap();
975
976                    task::spawn(async move {
977                        while let Some(Ok(msg)) = websocket.next().await {
978                            match msg {
979                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
980                                    if txt == "close-now" =>
981                                {
982                                    tracing::debug!("Forcibly closing from server side");
983                                    // This sends a close frame, then stops reading
984                                    let _ = websocket.close(None).await;
985                                    break;
986                                }
987                                // Echo text/binary frames
988                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
989                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
990                                    if websocket.send(msg).await.is_err() {
991                                        break;
992                                    }
993                                }
994                                // If the client closes, we also break
995                                tokio_tungstenite::tungstenite::protocol::Message::Close(
996                                    _frame,
997                                ) => {
998                                    let _ = websocket.close(None).await;
999                                    break;
1000                                }
1001                                // Ignore pings/pongs
1002                                _ => {}
1003                            }
1004                        }
1005                    });
1006                }
1007            });
1008
1009            Self { task, port }
1010        }
1011    }
1012
    impl Drop for TestServer {
        /// Aborts the server task when the handle goes out of scope.
        fn drop(&mut self) {
            // Stops the accept loop so the server does not outlive the test using it
            self.task.abort();
        }
    }
1018
1019    async fn setup_test_client(port: u16) -> WebSocketClient {
1020        let config = WebSocketConfig {
1021            url: format!("ws://127.0.0.1:{port}"),
1022            headers: vec![("test".into(), "test".into())],
1023            handler: Consumer::Python(None),
1024            heartbeat: None,
1025            heartbeat_msg: None,
1026            ping_handler: None,
1027            reconnect_timeout_ms: None,
1028            reconnect_delay_initial_ms: None,
1029            reconnect_backoff_factor: None,
1030            reconnect_delay_max_ms: None,
1031            reconnect_jitter_ms: None,
1032        };
1033        WebSocketClient::connect(config, None, None, None, vec![], None)
1034            .await
1035            .expect("Failed to connect")
1036    }
1037
1038    #[tokio::test]
1039    async fn test_websocket_basic() {
1040        let server = TestServer::setup().await;
1041        let client = setup_test_client(server.port).await;
1042
1043        assert!(!client.is_disconnected());
1044
1045        client.disconnect().await;
1046        assert!(client.is_disconnected());
1047    }
1048
1049    #[tokio::test]
1050    async fn test_websocket_heartbeat() {
1051        let server = TestServer::setup().await;
1052        let client = setup_test_client(server.port).await;
1053
1054        // Wait ~3s => server should see multiple "ping"
1055        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1056
1057        // Cleanup
1058        client.disconnect().await;
1059        assert!(client.is_disconnected());
1060    }
1061
1062    #[tokio::test]
1063    async fn test_websocket_reconnect_exhausted() {
1064        let config = WebSocketConfig {
1065            url: "ws://127.0.0.1:9997".into(), // <-- No server
1066            headers: vec![],
1067            handler: Consumer::Python(None),
1068            heartbeat: None,
1069            heartbeat_msg: None,
1070            ping_handler: None,
1071            reconnect_timeout_ms: None,
1072            reconnect_delay_initial_ms: None,
1073            reconnect_backoff_factor: None,
1074            reconnect_delay_max_ms: None,
1075            reconnect_jitter_ms: None,
1076        };
1077        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
1078        assert!(res.is_err(), "Should fail quickly with no server");
1079    }
1080
1081    #[tokio::test]
1082    async fn test_websocket_forced_close_reconnect() {
1083        let server = TestServer::setup().await;
1084        let client = setup_test_client(server.port).await;
1085
1086        // 1) Send normal message
1087        client.send_text("Hello".into(), None).await;
1088
1089        // 2) Trigger forced close from server
1090        client.send_text("close-now".into(), None).await;
1091
1092        // 3) Wait a bit => read loop sees close => reconnect
1093        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1094
1095        // Confirm not disconnected
1096        assert!(!client.is_disconnected());
1097
1098        // Cleanup
1099        client.disconnect().await;
1100        assert!(client.is_disconnected());
1101    }
1102
1103    #[tokio::test]
1104    async fn test_rate_limiter() {
1105        let server = TestServer::setup().await;
1106        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1107
1108        let config = WebSocketConfig {
1109            url: format!("ws://127.0.0.1:{}", server.port),
1110            headers: vec![("test".into(), "test".into())],
1111            handler: Consumer::Python(None),
1112            heartbeat: None,
1113            heartbeat_msg: None,
1114            ping_handler: None,
1115            reconnect_timeout_ms: None,
1116            reconnect_delay_initial_ms: None,
1117            reconnect_backoff_factor: None,
1118            reconnect_delay_max_ms: None,
1119            reconnect_jitter_ms: None,
1120        };
1121
1122        let client = WebSocketClient::connect(
1123            config,
1124            None,
1125            None,
1126            None,
1127            vec![("default".into(), quota)],
1128            None,
1129        )
1130        .await
1131        .unwrap();
1132
1133        // First 2 should succeed
1134        client.send_text("test1".into(), None).await;
1135        client.send_text("test2".into(), None).await;
1136
1137        // Third should error
1138        client.send_text("test3".into(), None).await;
1139
1140        // Cleanup
1141        client.disconnect().await;
1142        assert!(client.is_disconnected());
1143    }
1144
1145    #[tokio::test]
1146    async fn test_concurrent_writers() {
1147        let server = TestServer::setup().await;
1148        let client = Arc::new(setup_test_client(server.port).await);
1149
1150        let mut handles = vec![];
1151        for i in 0..10 {
1152            let client = client.clone();
1153            handles.push(task::spawn(async move {
1154                client.send_text(format!("test{i}"), None).await;
1155            }));
1156        }
1157
1158        for handle in handles {
1159            handle.await.unwrap();
1160        }
1161
1162        // Cleanup
1163        client.disconnect().await;
1164        assert!(client.is_disconnected());
1165    }
1166}