nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance WebSocket client implementation with automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Clean shutdown sequence
23//! - Split read/write architecture
24//! - Python callback integration
25//!
26//! **Design**:
27//! - Single reader, multiple writer model
28//! - Read half runs in dedicated task
29//! - Write half protected by `Arc<Mutex>`
30//! - Controller task manages lifecycle
31
32use std::{
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use pyo3::{prelude::*, types::PyBytes};
47use tokio::{net::TcpStream, sync::Mutex};
48use tokio_tungstenite::{
49    MaybeTlsStream, WebSocketStream, connect_async,
50    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
51};
52
53use crate::{
54    backoff::ExponentialBackoff,
55    mode::ConnectionMode,
56    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
57};
58type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
59type SharedMessageWriter =
60    Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
61pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
62
/// Connection settings for a WebSocket client.
///
/// All `reconnect_*` fields are optional; when a field is `None` the client falls back
/// to the default noted on that field.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// The default headers, as `(name, value)` pairs added to the handshake request.
    pub headers: Vec<(String, String)>,
    /// The Python function to handle incoming messages.
    ///
    /// It is called with each text or binary message as `bytes`. When `None`, no read
    /// task is spawned.
    pub handler: Option<Arc<PyObject>>,
    /// The optional heartbeat interval (seconds).
    pub heartbeat: Option<u64>,
    /// The optional heartbeat message.
    ///
    /// When set it is sent as a text frame on every heartbeat; when `None` an empty
    /// ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// The handler for incoming pings, called with the ping payload as `bytes`.
    pub ping_handler: Option<Arc<PyObject>>,
    /// The timeout (milliseconds) for reconnection attempts (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
92
/// `WebSocketClientInner` holds the state of a single websocket connection: the
/// shared write half plus the background tasks that service the connection.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. It moves
/// the read half into a tokio task which keeps receiving messages from the
/// server and calls a handler - a Python function that takes the data
/// as its parameter. It stores the write half in the struct wrapped
/// with an Arc Mutex. This way the client struct can be used to write
/// data to the server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// Settings used for the initial connection and reused on every reconnect.
    config: WebSocketConfig,
    /// Task forwarding incoming messages to the handler; `None` when no handler is set.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task sending periodic heartbeats; `None` when no heartbeat interval is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Write half of the current connection, shared with everything that sends.
    writer: SharedMessageWriter,
    /// Current `ConnectionMode`, stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff state used to space out failed reconnect attempts.
    backoff: ExponentialBackoff,
}
117
118impl WebSocketClientInner {
119    /// Create an inner websocket client.
120    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
121        install_cryptographic_provider();
122
123        #[allow(unused_variables)]
124        let WebSocketConfig {
125            url,
126            handler,
127            heartbeat,
128            headers,
129            heartbeat_msg,
130            ping_handler,
131            reconnect_timeout_ms,
132            reconnect_delay_initial_ms,
133            reconnect_delay_max_ms,
134            reconnect_backoff_factor,
135            reconnect_jitter_ms,
136        } = &config;
137        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
138        let writer = Arc::new(Mutex::new(writer));
139
140        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
141
142        // Only spawn read task if handler is provided
143        let read_task = handler
144            .as_ref()
145            .map(|handler| Self::spawn_read_task(reader, handler.clone(), ping_handler.clone()));
146
147        // Optionally spawn a heartbeat task to periodically ping server
148        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
149            Self::spawn_heartbeat_task(
150                connection_mode.clone(),
151                *heartbeat_secs,
152                heartbeat_msg.clone(),
153                writer.clone(),
154            )
155        });
156
157        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
158        let backoff = ExponentialBackoff::new(
159            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
160            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
161            reconnect_backoff_factor.unwrap_or(1.5),
162            reconnect_jitter_ms.unwrap_or(100),
163            true, // immediate-first
164        );
165
166        Ok(Self {
167            config,
168            read_task,
169            heartbeat_task,
170            writer,
171            connection_mode,
172            reconnect_timeout,
173            backoff,
174        })
175    }
176
177    /// Connects with the server creating a tokio-tungstenite websocket stream.
178    #[inline]
179    pub async fn connect_with_server(
180        url: &str,
181        headers: Vec<(String, String)>,
182    ) -> Result<(MessageWriter, MessageReader), Error> {
183        let mut request = url.into_client_request()?;
184        let req_headers = request.headers_mut();
185
186        let mut header_names: Vec<HeaderName> = Vec::new();
187        for (key, val) in headers {
188            let header_value = HeaderValue::from_str(&val)?;
189            let header_name: HeaderName = key.parse()?;
190            header_names.push(header_name.clone());
191            req_headers.insert(header_name, header_value);
192        }
193
194        connect_async(request).await.map(|resp| resp.0.split())
195    }
196
197    /// Reconnect with server.
198    ///
199    /// Make a new connection with server. Use the new read and write halves
200    /// to update self writer and read and heartbeat tasks.
201    pub async fn reconnect(&mut self) -> Result<(), Error> {
202        tracing::debug!("Reconnecting");
203
204        tokio::time::timeout(self.reconnect_timeout, async {
205            shutdown(
206                self.read_task.take(),
207                self.heartbeat_task.take(),
208                self.writer.clone(),
209            )
210            .await;
211
212            let (new_writer, reader) =
213                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
214
215            {
216                let mut guard = self.writer.lock().await;
217                *guard = new_writer;
218                drop(guard);
219            }
220
221            // Spawn new read task
222            if let Some(ref handler) = self.config.handler {
223                self.read_task = Some(Self::spawn_read_task(
224                    reader,
225                    handler.clone(),
226                    self.config.ping_handler.clone(),
227                ));
228            }
229
230            // Optionally spawn new heartbeat task
231            self.heartbeat_task = self.config.heartbeat.as_ref().map(|heartbeat_secs| {
232                Self::spawn_heartbeat_task(
233                    self.connection_mode.clone(),
234                    *heartbeat_secs,
235                    self.config.heartbeat_msg.clone(),
236                    self.writer.clone(),
237                )
238            });
239
240            self.connection_mode
241                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
242
243            tracing::debug!("Reconnect succeeded");
244            Ok(())
245        })
246        .await
247        .map_err(|_| {
248            Error::Io(std::io::Error::new(
249                std::io::ErrorKind::TimedOut,
250                format!(
251                    "reconnection timed out after {}s",
252                    self.reconnect_timeout.as_secs_f64()
253                ),
254            ))
255        })?
256    }
257
258    /// Check if the client is still connected.
259    ///
260    /// The client is connected if the read task has not finished. It is expected
261    /// that in case of any failure client or server side. The read task will be
262    /// shutdown or will receive a `Close` frame which will finish it. There
263    /// might be some delay between the connection being closed and the client
264    /// detecting.
265    #[inline]
266    #[must_use]
267    pub fn is_alive(&self) -> bool {
268        match &self.read_task {
269            Some(read_task) => !read_task.is_finished(),
270            None => true, // Stream is being used directly
271        }
272    }
273
274    fn spawn_read_task(
275        mut reader: MessageReader,
276        handler: Arc<PyObject>,
277        ping_handler: Option<Arc<PyObject>>,
278    ) -> tokio::task::JoinHandle<()> {
279        tracing::debug!("Started task 'read'");
280
281        tokio::task::spawn(async move {
282            loop {
283                match reader.next().await {
284                    Some(Ok(Message::Binary(data))) => {
285                        tracing::trace!("Received message <binary> {} bytes", data.len());
286                        if let Err(e) =
287                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
288                        {
289                            tracing::error!("Error calling handler: {e}");
290                            break;
291                        }
292                        continue;
293                    }
294                    Some(Ok(Message::Text(data))) => {
295                        tracing::trace!("Received message: {data}");
296                        if let Err(e) = Python::with_gil(|py| {
297                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
298                        }) {
299                            tracing::error!("Error calling handler: {e}");
300                            break;
301                        }
302                        continue;
303                    }
304                    Some(Ok(Message::Ping(ping))) => {
305                        tracing::trace!("Received ping: {ping:?}",);
306                        if let Some(ref handler) = ping_handler {
307                            if let Err(e) =
308                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
309                            {
310                                tracing::error!("Error calling handler: {e}");
311                                break;
312                            }
313                        }
314                        continue;
315                    }
316                    Some(Ok(Message::Pong(_))) => {
317                        tracing::trace!("Received pong");
318                    }
319                    Some(Ok(Message::Close(_))) => {
320                        tracing::debug!("Received close message - terminating");
321                        break;
322                    }
323                    Some(Ok(_)) => (),
324                    Some(Err(e)) => {
325                        tracing::error!("Received error message - terminating: {e}");
326                        break;
327                    }
328                    // Internally tungstenite considers the connection closed when polling
329                    // for the next message in the stream returns None.
330                    None => {
331                        tracing::debug!("No message received - terminating");
332                        break;
333                    }
334                }
335            }
336        })
337    }
338
339    fn spawn_heartbeat_task(
340        connection_state: Arc<AtomicU8>,
341        heartbeat_secs: u64,
342        message: Option<String>,
343        writer: SharedMessageWriter,
344    ) -> tokio::task::JoinHandle<()> {
345        tracing::debug!("Started task 'heartbeat'");
346
347        tokio::task::spawn(async move {
348            let interval = Duration::from_secs(heartbeat_secs);
349            loop {
350                tokio::time::sleep(interval).await;
351
352                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
353                    ConnectionMode::Active => {
354                        let mut guard = writer.lock().await;
355                        let guard_send_response = match message.clone() {
356                            Some(msg) => guard.send(Message::Text(msg.into())).await,
357                            None => guard.send(Message::Ping(vec![].into())).await,
358                        };
359
360                        match guard_send_response {
361                            Ok(()) => tracing::trace!("Sent ping"),
362                            Err(e) => tracing::error!("Error sending ping: {e}"),
363                        }
364                    }
365                    ConnectionMode::Reconnect => continue,
366                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
367                }
368            }
369
370            tracing::debug!("Completed task 'heartbeat'");
371        })
372    }
373}
374
375/// Shutdown websocket connection.
376///
377/// Performs orderly WebSocket shutdown:
378/// 1. Sends close frame to server
379/// 2. Waits briefly for frame delivery
380/// 3. Aborts read/heartbeat tasks
381/// 4. Closes underlying connection
382///
383/// This sequence ensures proper protocol compliance and clean resource cleanup.
384async fn shutdown(
385    read_task: Option<tokio::task::JoinHandle<()>>,
386    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
387    writer: SharedMessageWriter,
388) {
389    tracing::debug!("Closing");
390
391    let timeout = Duration::from_secs(5);
392    if tokio::time::timeout(timeout, async {
393        // Send close frame first
394        let mut write_half = writer.lock().await;
395        if let Err(e) = write_half.send(Message::Close(None)).await {
396            // Close frame errors during shutdown are expected
397            tracing::debug!("Error sending close frame: {e}");
398        }
399        drop(write_half);
400
401        tokio::time::sleep(Duration::from_millis(100)).await;
402
403        // Abort tasks
404        if let Some(task) = read_task {
405            if !task.is_finished() {
406                task.abort();
407                tracing::debug!("Aborted read task");
408            }
409        }
410        if let Some(task) = heartbeat_task {
411            if !task.is_finished() {
412                task.abort();
413                tracing::debug!("Aborted heartbeat task");
414            }
415        }
416
417        // Final close of writer
418        let mut write_half = writer.lock().await;
419        if let Err(e) = write_half.close().await {
420            tracing::error!("Error closing writer: {e}");
421        }
422    })
423    .await
424    .is_err()
425    {
426        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
427    }
428
429    tracing::debug!("Closed");
430}
431
432impl Drop for WebSocketClientInner {
433    fn drop(&mut self) {
434        if let Some(ref read_task) = self.read_task.take() {
435            if !read_task.is_finished() {
436                read_task.abort();
437            }
438        }
439
440        // Cancel heart beat task
441        if let Some(ref handle) = self.heartbeat_task.take() {
442            if !handle.is_finished() {
443                handle.abort();
444            }
445        }
446    }
447}
448
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Shared write half used by the `send_*` methods.
    pub(crate) writer: SharedMessageWriter,
    /// Background task that reconnects and shuts down the connection.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Current `ConnectionMode`, stored as its `u8` representation.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Rate limiter awaited before text and binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
463
464impl WebSocketClient {
465    /// Creates a websocket client that returns a stream for reading messages.
466    ///
467    /// # Errors
468    ///
469    /// Returns any error connecting to the server.
470    #[allow(clippy::too_many_arguments)]
471    pub async fn connect_stream(
472        config: WebSocketConfig,
473        keyed_quotas: Vec<(String, Quota)>,
474        default_quota: Option<Quota>,
475    ) -> Result<(MessageReader, Self), Error> {
476        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
477        let (writer, reader) = ws_stream.split();
478        let writer = Arc::new(Mutex::new(writer));
479
480        let inner = WebSocketClientInner::connect_url(config).await?;
481        let connection_mode = inner.connection_mode.clone();
482        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
483
484        let controller_task = Self::spawn_controller_task(
485            inner,
486            connection_mode.clone(),
487            None, // no post_reconnection
488            None, // no post_disconnection
489        );
490
491        Ok((
492            reader,
493            Self {
494                writer: writer.clone(),
495                controller_task,
496                connection_mode,
497                rate_limiter,
498            },
499        ))
500    }
501
502    /// Creates a websocket client.
503    ///
504    /// Creates an inner client and controller task to reconnect or disconnect
505    /// the client. Also assumes ownership of writer from inner client.
506    ///
507    /// # Errors
508    ///
509    /// Returns any websocket error.
510    pub async fn connect(
511        config: WebSocketConfig,
512        post_connection: Option<PyObject>,
513        post_reconnection: Option<PyObject>,
514        post_disconnection: Option<PyObject>,
515        keyed_quotas: Vec<(String, Quota)>,
516        default_quota: Option<Quota>,
517    ) -> Result<Self, Error> {
518        tracing::debug!("Connecting");
519        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
520        let writer = inner.writer.clone();
521        let connection_mode = inner.connection_mode.clone();
522
523        let controller_task = Self::spawn_controller_task(
524            inner,
525            connection_mode.clone(),
526            post_reconnection,
527            post_disconnection,
528        );
529        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
530
531        if let Some(handler) = post_connection {
532            Python::with_gil(|py| match handler.call0(py) {
533                Ok(_) => tracing::debug!("Called `post_connection` handler"),
534                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
535            });
536        };
537
538        Ok(Self {
539            writer,
540            controller_task,
541            connection_mode,
542            rate_limiter,
543        })
544    }
545
    /// Returns the current connection mode.
    ///
    /// The mode is decoded from the shared atomic on every call, so the value is a
    /// snapshot and may change immediately afterwards.
    #[must_use]
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
551
552    /// Check if the client connection is active.
553    ///
554    /// Returns `true` if the client is connected and has not been signalled to disconnect.
555    /// The client will automatically retry connection based on its configuration.
556    #[inline]
557    #[must_use]
558    pub fn is_active(&self) -> bool {
559        self.connection_mode().is_active()
560    }
561
    /// Check if the client is disconnected.
    ///
    /// Returns `true` once the controller task has finished; this does not consult
    /// the connection mode.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.controller_task.is_finished()
    }
567
568    /// Check if the client is reconnecting.
569    ///
570    /// Returns `true` if the client lost connection and is attempting to reestablish it.
571    /// The client will automatically retry connection based on its configuration.
572    #[inline]
573    #[must_use]
574    pub fn is_reconnecting(&self) -> bool {
575        self.connection_mode().is_reconnect()
576    }
577
578    /// Check if the client is disconnecting.
579    ///
580    /// Returns `true` if the client is in disconnect mode.
581    #[inline]
582    #[must_use]
583    pub fn is_disconnecting(&self) -> bool {
584        self.connection_mode().is_disconnect()
585    }
586
587    /// Check if the client is closed.
588    ///
589    /// Returns `true` if the client has been explicitly disconnected or reached
590    /// maximum reconnection attempts. In this state, the client cannot be reused
591    /// and a new client must be created for further connections.
592    #[inline]
593    #[must_use]
594    pub fn is_closed(&self) -> bool {
595        self.connection_mode().is_closed()
596    }
597
598    /// Set disconnect mode to true.
599    ///
600    /// Controller task will periodically check the disconnect mode
601    /// and shutdown the client if it is alive
602    pub async fn disconnect(&self) {
603        tracing::debug!("Disconnecting");
604        self.connection_mode
605            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
606
607        match tokio::time::timeout(Duration::from_secs(5), async {
608            while !self.is_disconnected() {
609                tokio::time::sleep(Duration::from_millis(10)).await;
610            }
611
612            if !self.controller_task.is_finished() {
613                self.controller_task.abort();
614                tracing::debug!("Aborted controller task");
615            }
616        })
617        .await
618        {
619            Ok(()) => {
620                tracing::debug!("Controller task finished");
621            }
622            Err(_) => {
623                tracing::error!("Timeout waiting for controller task to finish");
624            }
625        }
626    }
627
628    /// Sends the given text `data` to the server.
629    pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) {
630        self.rate_limiter.await_keys_ready(keys).await;
631        tracing::trace!("Sending text: {data:?}");
632        let mut guard = self.writer.lock().await;
633        if let Err(e) = guard.send(Message::Text(data.into())).await {
634            tracing::error!("Error sending message: {e}");
635        }
636    }
637
638    /// Sends the given bytes `data` to the server.
639    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) {
640        self.rate_limiter.await_keys_ready(keys).await;
641        tracing::trace!("Sending bytes: {data:?}");
642        let mut guard = self.writer.lock().await;
643        if let Err(e) = guard.send(Message::Binary(data.into())).await {
644            tracing::error!("Error sending message: {e}");
645        }
646    }
647
648    /// Sends a close message to the server.
649    pub async fn send_close_message(&self) {
650        let mut guard = self.writer.lock().await;
651        match guard.send(Message::Close(None)).await {
652            Ok(()) => tracing::debug!("Sent close message"),
653            Err(e) => tracing::error!("Error sending close message: {e}"),
654        }
655    }
656
657    fn spawn_controller_task(
658        mut inner: WebSocketClientInner,
659        connection_mode: Arc<AtomicU8>,
660        post_reconnection: Option<PyObject>,
661        post_disconnection: Option<PyObject>,
662    ) -> tokio::task::JoinHandle<()> {
663        tokio::task::spawn(async move {
664            tracing::debug!("Started task 'controller'");
665
666            let check_interval = Duration::from_millis(10);
667
668            loop {
669                tokio::time::sleep(check_interval).await;
670                let mode = ConnectionMode::from_atomic(&connection_mode);
671
672                if mode.is_disconnect() {
673                    tracing::debug!("Disconnecting");
674                    shutdown(
675                        inner.read_task.take(),
676                        inner.heartbeat_task.take(),
677                        inner.writer.clone(),
678                    )
679                    .await;
680
681                    if let Some(ref handler) = post_disconnection {
682                        Python::with_gil(|py| match handler.call0(py) {
683                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
684                            Err(e) => {
685                                tracing::error!("Error calling `post_disconnection` handler: {e}");
686                            }
687                        });
688                    }
689                    break; // Controller finished
690                }
691
692                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
693                    match inner.reconnect().await {
694                        Ok(()) => {
695                            tracing::debug!("Reconnected successfully");
696                            inner.backoff.reset();
697
698                            if let Some(ref handler) = post_reconnection {
699                                Python::with_gil(|py| match handler.call0(py) {
700                                    Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
701                                    Err(e) => tracing::error!(
702                                        "Error calling `post_reconnection` handler: {e}"
703                                    ),
704                                });
705                            }
706                        }
707                        Err(e) => {
708                            let duration = inner.backoff.next_duration();
709                            tracing::warn!("Reconnect attempt failed: {e}");
710                            if !duration.is_zero() {
711                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
712                            }
713                            tokio::time::sleep(duration).await;
714                        }
715                    }
716                }
717            }
718            inner
719                .connection_mode
720                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
721        })
722    }
723}
724
725////////////////////////////////////////////////////////////////////////////////
726// Tests
727////////////////////////////////////////////////////////////////////////////////
728#[cfg(test)]
729#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
730mod tests {
731    use std::{num::NonZeroU32, sync::Arc};
732
733    use futures_util::{SinkExt, StreamExt};
734    use tokio::{
735        net::TcpListener,
736        task::{self, JoinHandle},
737    };
738    use tokio_tungstenite::{
739        accept_hdr_async,
740        tungstenite::{
741            handshake::server::{self, Callback},
742            http::HeaderValue,
743        },
744    };
745
746    use crate::{
747        ratelimiter::quota::Quota,
748        websocket::{WebSocketClient, WebSocketConfig},
749    };
750
    /// Local websocket server used by the tests.
    struct TestServer {
        /// Task running the accept loop; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to on 127.0.0.1.
        port: u16,
    }
755
    /// Handshake callback that checks for one expected request header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header must carry.
        value: HeaderValue,
    }
761
762    impl Callback for TestCallback {
763        fn on_request(
764            self,
765            request: &server::Request,
766            response: server::Response,
767        ) -> Result<server::Response, server::ErrorResponse> {
768            let _ = response;
769            let value = request.headers().get(&self.key);
770            assert!(value.is_some());
771
772            if let Some(value) = request.headers().get(&self.key) {
773                assert_eq!(value, self.value);
774            }
775
776            Ok(response)
777        }
778    }
779
780    impl TestServer {
781        async fn setup() -> Self {
782            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
783            let port = TcpListener::local_addr(&server).unwrap().port();
784
785            let header_key = "test".to_string();
786            let header_value = "test".to_string();
787
788            let test_call_back = TestCallback {
789                key: header_key,
790                value: HeaderValue::from_str(&header_value).unwrap(),
791            };
792
793            let task = task::spawn(async move {
794                // Keep accepting connections
795                loop {
796                    let (conn, _) = server.accept().await.unwrap();
797                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
798                        .await
799                        .unwrap();
800
801                    task::spawn(async move {
802                        while let Some(Ok(msg)) = websocket.next().await {
803                            match msg {
804                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
805                                    if txt == "close-now" =>
806                                {
807                                    tracing::debug!("Forcibly closing from server side");
808                                    // This sends a close frame, then stops reading
809                                    let _ = websocket.close(None).await;
810                                    break;
811                                }
812                                // Echo text/binary frames
813                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
814                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
815                                    if websocket.send(msg).await.is_err() {
816                                        break;
817                                    }
818                                }
819                                // If the client closes, we also break
820                                tokio_tungstenite::tungstenite::protocol::Message::Close(
821                                    _frame,
822                                ) => {
823                                    let _ = websocket.close(None).await;
824                                    break;
825                                }
826                                // Ignore pings/pongs
827                                _ => {}
828                            }
829                        }
830                    });
831                }
832            });
833
834            Self { task, port }
835        }
836    }
837
838    impl Drop for TestServer {
839        fn drop(&mut self) {
840            self.task.abort();
841        }
842    }
843
844    async fn setup_test_client(port: u16) -> WebSocketClient {
845        let config = WebSocketConfig {
846            url: format!("ws://127.0.0.1:{port}"),
847            headers: vec![("test".into(), "test".into())],
848            handler: None,
849            heartbeat: None,
850            heartbeat_msg: None,
851            ping_handler: None,
852            reconnect_timeout_ms: None,
853            reconnect_delay_initial_ms: None,
854            reconnect_backoff_factor: None,
855            reconnect_delay_max_ms: None,
856            reconnect_jitter_ms: None,
857        };
858        WebSocketClient::connect(config, None, None, None, vec![], None)
859            .await
860            .expect("Failed to connect")
861    }
862
863    #[tokio::test]
864    async fn test_websocket_basic() {
865        let server = TestServer::setup().await;
866        let client = setup_test_client(server.port).await;
867
868        assert!(!client.is_disconnected());
869
870        client.disconnect().await;
871        assert!(client.is_disconnected());
872    }
873
874    #[tokio::test]
875    async fn test_websocket_heartbeat() {
876        let server = TestServer::setup().await;
877        let client = setup_test_client(server.port).await;
878
879        // Wait ~3s => server should see multiple "ping"
880        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
881
882        // Cleanup
883        client.disconnect().await;
884        assert!(client.is_disconnected());
885    }
886
887    #[tokio::test]
888    async fn test_websocket_reconnect_exhausted() {
889        let config = WebSocketConfig {
890            url: "ws://127.0.0.1:9997".into(), // <-- No server
891            headers: vec![],
892            handler: None,
893            heartbeat: None,
894            heartbeat_msg: None,
895            ping_handler: None,
896            reconnect_timeout_ms: None,
897            reconnect_delay_initial_ms: None,
898            reconnect_backoff_factor: None,
899            reconnect_delay_max_ms: None,
900            reconnect_jitter_ms: None,
901        };
902        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
903        assert!(res.is_err(), "Should fail quickly with no server");
904    }
905
906    #[tokio::test]
907    async fn test_websocket_forced_close_reconnect() {
908        let server = TestServer::setup().await;
909        let client = setup_test_client(server.port).await;
910
911        // 1) Send normal message
912        client.send_text("Hello".into(), None).await;
913
914        // 2) Trigger forced close from server
915        client.send_text("close-now".into(), None).await;
916
917        // 3) Wait a bit => read loop sees close => reconnect
918        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
919
920        // Confirm not disconnected
921        assert!(!client.is_disconnected());
922
923        // Cleanup
924        client.disconnect().await;
925        assert!(client.is_disconnected());
926    }
927
928    #[tokio::test]
929    async fn test_rate_limiter() {
930        let server = TestServer::setup().await;
931        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
932
933        let config = WebSocketConfig {
934            url: format!("ws://127.0.0.1:{}", server.port),
935            headers: vec![("test".into(), "test".into())],
936            handler: None,
937            heartbeat: None,
938            heartbeat_msg: None,
939            ping_handler: None,
940            reconnect_timeout_ms: None,
941            reconnect_delay_initial_ms: None,
942            reconnect_backoff_factor: None,
943            reconnect_delay_max_ms: None,
944            reconnect_jitter_ms: None,
945        };
946
947        let client = WebSocketClient::connect(
948            config,
949            None,
950            None,
951            None,
952            vec![("default".into(), quota)],
953            None,
954        )
955        .await
956        .unwrap();
957
958        // First 2 should succeed
959        client.send_text("test1".into(), None).await;
960        client.send_text("test2".into(), None).await;
961
962        // Third should error
963        client.send_text("test3".into(), None).await;
964
965        // Cleanup
966        client.disconnect().await;
967        assert!(client.is_disconnected());
968    }
969
970    #[tokio::test]
971    async fn test_concurrent_writers() {
972        let server = TestServer::setup().await;
973        let client = Arc::new(setup_test_client(server.port).await);
974
975        let mut handles = vec![];
976        for i in 0..10 {
977            let client = client.clone();
978            handles.push(task::spawn(async move {
979                client.send_text(format!("test{i}"), None).await;
980            }));
981        }
982
983        for handle in handles {
984            handle.await.unwrap();
985        }
986
987        // Cleanup
988        client.disconnect().await;
989        assert!(client.is_disconnected());
990    }
991}