nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance WebSocket client implementation with automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Clean shutdown sequence
23//! - Split read/write architecture
24//! - Python callback integration
25//!
26//! **Design**:
27//! - Single reader, multiple writer model
28//! - Read half runs in dedicated task
29//! - Write half protected by `Arc<Mutex>`
30//! - Controller task manages lifecycle
31
32use std::{
33    sync::{
34        atomic::{AtomicU8, Ordering},
35        Arc,
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    stream::{SplitSink, SplitStream},
42    SinkExt, StreamExt,
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use pyo3::{prelude::*, types::PyBytes};
47use tokio::{net::TcpStream, sync::Mutex};
48use tokio_tungstenite::{
49    connect_async,
50    tungstenite::{client::IntoClientRequest, http::HeaderValue, Error, Message},
51    MaybeTlsStream, WebSocketStream,
52};
53
54use crate::{
55    backoff::ExponentialBackoff,
56    mode::ConnectionMode,
57    ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter},
58};
59type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
60type SharedMessageWriter =
61    Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
62pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
63
/// Configuration for a WebSocket client connection.
///
/// Every `reconnect_*` field is optional; when `None`, `connect_url` falls back
/// to the default noted on the field.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// The headers added to the connection request, as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// The Python callable invoked with the payload (as `bytes`) of each incoming
    /// text or binary message. The read task is only spawned when this is set.
    pub handler: Option<Arc<PyObject>>,
    /// The optional heartbeat interval (seconds). No heartbeat task is spawned when `None`.
    pub heartbeat: Option<u64>,
    /// The optional heartbeat message, sent as a text frame. When `None` the
    /// heartbeat task sends an empty ping frame instead.
    pub heartbeat_msg: Option<String>,
    /// The Python callable invoked with the payload (as `bytes`) of each incoming ping.
    pub ping_handler: Option<Arc<PyObject>>,
    /// The timeout (milliseconds) for a single reconnection attempt (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
93
/// `WebSocketClientInner` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. When a handler
/// is configured, it moves the read half into a tokio task which keeps receiving
/// messages from the server and calls the handler - a Python function that takes
/// the data as its parameter. It stores the write half in the struct wrapped
/// with an Arc Mutex. This way the write half can be used to send data to the
/// server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given an interval in seconds.
/// It's preferable to set the interval slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// The configuration this client was created with; reused on reconnect.
    config: WebSocketConfig,
    /// Handle of the read task, if a message handler was configured.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the heartbeat task, if a heartbeat interval was configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared write half of the connection; its contents are swapped on reconnect.
    writer: SharedMessageWriter,
    /// Current `ConnectionMode`, stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff state used to delay successive failed reconnection attempts.
    backoff: ExponentialBackoff,
}
118
119impl WebSocketClientInner {
120    /// Create an inner websocket client.
121    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
122        install_cryptographic_provider();
123
124        #[allow(unused_variables)]
125        let WebSocketConfig {
126            url,
127            handler,
128            heartbeat,
129            headers,
130            heartbeat_msg,
131            ping_handler,
132            reconnect_timeout_ms,
133            reconnect_delay_initial_ms,
134            reconnect_delay_max_ms,
135            reconnect_backoff_factor,
136            reconnect_jitter_ms,
137        } = &config;
138        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
139        let writer = Arc::new(Mutex::new(writer));
140
141        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
142
143        // Only spawn read task if handler is provided
144        let read_task = handler
145            .as_ref()
146            .map(|handler| Self::spawn_read_task(reader, handler.clone(), ping_handler.clone()));
147
148        // Optionally spawn a heartbeat task to periodically ping server
149        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
150            Self::spawn_heartbeat_task(
151                connection_mode.clone(),
152                *heartbeat_secs,
153                heartbeat_msg.clone(),
154                writer.clone(),
155            )
156        });
157
158        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
159        let backoff = ExponentialBackoff::new(
160            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
161            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
162            reconnect_backoff_factor.unwrap_or(1.5),
163            reconnect_jitter_ms.unwrap_or(100),
164            true, // immediate-first
165        );
166
167        Ok(Self {
168            config,
169            read_task,
170            heartbeat_task,
171            writer,
172            connection_mode,
173            reconnect_timeout,
174            backoff,
175        })
176    }
177
178    /// Connects with the server creating a tokio-tungstenite websocket stream.
179    #[inline]
180    pub async fn connect_with_server(
181        url: &str,
182        headers: Vec<(String, String)>,
183    ) -> Result<(MessageWriter, MessageReader), Error> {
184        let mut request = url.into_client_request()?;
185        let req_headers = request.headers_mut();
186
187        let mut header_names: Vec<HeaderName> = Vec::new();
188        for (key, val) in headers {
189            let header_value = HeaderValue::from_str(&val)?;
190            let header_name: HeaderName = key.parse()?;
191            header_names.push(header_name.clone());
192            req_headers.insert(header_name, header_value);
193        }
194
195        connect_async(request).await.map(|resp| resp.0.split())
196    }
197
198    /// Reconnect with server.
199    ///
200    /// Make a new connection with server. Use the new read and write halves
201    /// to update self writer and read and heartbeat tasks.
202    pub async fn reconnect(&mut self) -> Result<(), Error> {
203        tracing::debug!("Reconnecting");
204
205        tokio::time::timeout(self.reconnect_timeout, async {
206            shutdown(
207                self.read_task.take(),
208                self.heartbeat_task.take(),
209                self.writer.clone(),
210            )
211            .await;
212
213            let (new_writer, reader) =
214                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
215
216            {
217                let mut guard = self.writer.lock().await;
218                *guard = new_writer;
219                drop(guard);
220            }
221
222            // Spawn new read task
223            if let Some(ref handler) = self.config.handler {
224                self.read_task = Some(Self::spawn_read_task(
225                    reader,
226                    handler.clone(),
227                    self.config.ping_handler.clone(),
228                ));
229            }
230
231            // Optionally spawn new heartbeat task
232            self.heartbeat_task = self.config.heartbeat.as_ref().map(|heartbeat_secs| {
233                Self::spawn_heartbeat_task(
234                    self.connection_mode.clone(),
235                    *heartbeat_secs,
236                    self.config.heartbeat_msg.clone(),
237                    self.writer.clone(),
238                )
239            });
240
241            self.connection_mode
242                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244            tracing::debug!("Reconnect succeeded");
245            Ok(())
246        })
247        .await
248        .map_err(|_| {
249            Error::Io(std::io::Error::new(
250                std::io::ErrorKind::TimedOut,
251                format!(
252                    "reconnection timed out after {}s",
253                    self.reconnect_timeout.as_secs_f64()
254                ),
255            ))
256        })?
257    }
258
259    /// Check if the client is still connected.
260    ///
261    /// The client is connected if the read task has not finished. It is expected
262    /// that in case of any failure client or server side. The read task will be
263    /// shutdown or will receive a `Close` frame which will finish it. There
264    /// might be some delay between the connection being closed and the client
265    /// detecting.
266    #[inline]
267    #[must_use]
268    pub fn is_alive(&self) -> bool {
269        match &self.read_task {
270            Some(read_task) => !read_task.is_finished(),
271            None => true, // Stream is being used directly
272        }
273    }
274
275    fn spawn_read_task(
276        mut reader: MessageReader,
277        handler: Arc<PyObject>,
278        ping_handler: Option<Arc<PyObject>>,
279    ) -> tokio::task::JoinHandle<()> {
280        tracing::debug!("Started task 'read'");
281
282        tokio::task::spawn(async move {
283            loop {
284                match reader.next().await {
285                    Some(Ok(Message::Binary(data))) => {
286                        tracing::trace!("Received message <binary> {} bytes", data.len());
287                        if let Err(e) =
288                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
289                        {
290                            tracing::error!("Error calling handler: {e}");
291                            break;
292                        }
293                        continue;
294                    }
295                    Some(Ok(Message::Text(data))) => {
296                        tracing::trace!("Received message: {data}");
297                        if let Err(e) = Python::with_gil(|py| {
298                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
299                        }) {
300                            tracing::error!("Error calling handler: {e}");
301                            break;
302                        }
303                        continue;
304                    }
305                    Some(Ok(Message::Ping(ping))) => {
306                        tracing::trace!("Received ping: {ping:?}",);
307                        if let Some(ref handler) = ping_handler {
308                            if let Err(e) =
309                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
310                            {
311                                tracing::error!("Error calling handler: {e}");
312                                break;
313                            }
314                        }
315                        continue;
316                    }
317                    Some(Ok(Message::Pong(_))) => {
318                        tracing::trace!("Received pong");
319                    }
320                    Some(Ok(Message::Close(_))) => {
321                        tracing::debug!("Received close message - terminating");
322                        break;
323                    }
324                    Some(Ok(_)) => (),
325                    Some(Err(e)) => {
326                        tracing::error!("Received error message - terminating: {e}");
327                        break;
328                    }
329                    // Internally tungstenite considers the connection closed when polling
330                    // for the next message in the stream returns None.
331                    None => {
332                        tracing::debug!("No message received - terminating");
333                        break;
334                    }
335                }
336            }
337        })
338    }
339
340    fn spawn_heartbeat_task(
341        connection_state: Arc<AtomicU8>,
342        heartbeat_secs: u64,
343        message: Option<String>,
344        writer: SharedMessageWriter,
345    ) -> tokio::task::JoinHandle<()> {
346        tracing::debug!("Started task 'heartbeat'");
347
348        tokio::task::spawn(async move {
349            let interval = Duration::from_secs(heartbeat_secs);
350            loop {
351                tokio::time::sleep(interval).await;
352
353                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
354                    ConnectionMode::Active => {
355                        let mut guard = writer.lock().await;
356                        let guard_send_response = match message.clone() {
357                            Some(msg) => guard.send(Message::Text(msg.into())).await,
358                            None => guard.send(Message::Ping(vec![].into())).await,
359                        };
360
361                        match guard_send_response {
362                            Ok(()) => tracing::trace!("Sent ping"),
363                            Err(e) => tracing::error!("Error sending ping: {e}"),
364                        }
365                    }
366                    ConnectionMode::Reconnect => continue,
367                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
368                }
369            }
370
371            tracing::debug!("Completed task 'heartbeat'");
372        })
373    }
374}
375
376/// Shutdown websocket connection.
377///
378/// Performs orderly WebSocket shutdown:
379/// 1. Sends close frame to server
380/// 2. Waits briefly for frame delivery
381/// 3. Aborts read/heartbeat tasks
382/// 4. Closes underlying connection
383///
384/// This sequence ensures proper protocol compliance and clean resource cleanup.
385async fn shutdown(
386    read_task: Option<tokio::task::JoinHandle<()>>,
387    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
388    writer: SharedMessageWriter,
389) {
390    tracing::debug!("Closing");
391
392    let timeout = Duration::from_secs(5);
393    if tokio::time::timeout(timeout, async {
394        // Send close frame first
395        let mut write_half = writer.lock().await;
396        if let Err(e) = write_half.send(Message::Close(None)).await {
397            // Close frame errors during shutdown are expected
398            tracing::debug!("Error sending close frame: {e}");
399        }
400        drop(write_half);
401
402        tokio::time::sleep(Duration::from_millis(100)).await;
403
404        // Abort tasks
405        if let Some(task) = read_task {
406            if !task.is_finished() {
407                task.abort();
408                tracing::debug!("Aborted read task");
409            }
410        }
411        if let Some(task) = heartbeat_task {
412            if !task.is_finished() {
413                task.abort();
414                tracing::debug!("Aborted heartbeat task");
415            }
416        }
417
418        // Final close of writer
419        let mut write_half = writer.lock().await;
420        if let Err(e) = write_half.close().await {
421            tracing::error!("Error closing writer: {e}");
422        }
423    })
424    .await
425    .is_err()
426    {
427        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
428    }
429
430    tracing::debug!("Closed");
431}
432
433impl Drop for WebSocketClientInner {
434    fn drop(&mut self) {
435        if let Some(ref read_task) = self.read_task.take() {
436            if !read_task.is_finished() {
437                read_task.abort();
438            }
439        }
440
441        // Cancel heart beat task
442        if let Some(ref handle) = self.heartbeat_task.take() {
443            if !handle.is_finished() {
444                handle.abort();
445            }
446        }
447    }
448}
449
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Shared write half used by the `send_*` methods.
    pub(crate) writer: SharedMessageWriter,
    /// Handle of the controller task that performs reconnects and the final shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Current `ConnectionMode`, stored as its `u8` representation and shared
    /// with the controller task.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Rate limiter awaited by `send_text` and `send_bytes` before sending.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
464
465impl WebSocketClient {
466    /// Creates a websocket client that returns a stream for reading messages.
467    #[allow(clippy::too_many_arguments)]
468    pub async fn connect_stream(
469        config: WebSocketConfig,
470        keyed_quotas: Vec<(String, Quota)>,
471        default_quota: Option<Quota>,
472    ) -> Result<(MessageReader, Self), Error> {
473        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
474        let (writer, reader) = ws_stream.split();
475        let writer = Arc::new(Mutex::new(writer));
476
477        let inner = WebSocketClientInner::connect_url(config).await?;
478        let connection_mode = inner.connection_mode.clone();
479        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
480
481        let controller_task = Self::spawn_controller_task(
482            inner,
483            connection_mode.clone(),
484            None, // no post_reconnection
485            None, // no post_disconnection
486        );
487
488        Ok((
489            reader,
490            Self {
491                writer: writer.clone(),
492                controller_task,
493                connection_mode,
494                rate_limiter,
495            },
496        ))
497    }
498
499    /// Creates a websocket client.
500    ///
501    /// Creates an inner client and controller task to reconnect or disconnect
502    /// the client. Also assumes ownership of writer from inner client.
503    pub async fn connect(
504        config: WebSocketConfig,
505        post_connection: Option<PyObject>,
506        post_reconnection: Option<PyObject>,
507        post_disconnection: Option<PyObject>,
508        keyed_quotas: Vec<(String, Quota)>,
509        default_quota: Option<Quota>,
510    ) -> Result<Self, Error> {
511        tracing::debug!("Connecting");
512        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
513        let writer = inner.writer.clone();
514        let connection_mode = inner.connection_mode.clone();
515
516        let controller_task = Self::spawn_controller_task(
517            inner,
518            connection_mode.clone(),
519            post_reconnection,
520            post_disconnection,
521        );
522        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
523
524        if let Some(handler) = post_connection {
525            Python::with_gil(|py| match handler.call0(py) {
526                Ok(_) => tracing::debug!("Called `post_connection` handler"),
527                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
528            });
529        };
530
531        Ok(Self {
532            writer,
533            controller_task,
534            connection_mode: connection_mode.clone(),
535            rate_limiter,
536        })
537    }
538
539    /// Returns the current connection mode.
540    pub fn connection_mode(&self) -> ConnectionMode {
541        ConnectionMode::from_atomic(&self.connection_mode)
542    }
543
    /// Check if the client connection is active.
    ///
    /// Returns `true` if the current connection mode reports itself as active,
    /// i.e. the client has not been switched to reconnecting, disconnecting or closed.
    #[inline]
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.connection_mode().is_active()
    }
553
    /// Check if the client is disconnected.
    ///
    /// Returns `true` once the controller task has finished.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.controller_task.is_finished()
    }
559
    /// Check if the client is reconnecting.
    ///
    /// Returns `true` if the current connection mode is the reconnect mode, in
    /// which the controller task attempts to re-establish the connection.
    #[inline]
    #[must_use]
    pub fn is_reconnecting(&self) -> bool {
        self.connection_mode().is_reconnect()
    }
569
    /// Check if the client is disconnecting.
    ///
    /// Returns `true` if the client is in disconnect mode, i.e. a disconnect has
    /// been requested.
    #[inline]
    #[must_use]
    pub fn is_disconnecting(&self) -> bool {
        self.connection_mode().is_disconnect()
    }
578
    /// Check if the client is closed.
    ///
    /// Returns `true` once the connection mode is closed. The controller task
    /// stores that mode after its loop exits, which happens once a disconnect
    /// request has been processed. The controller is finished at that point, so
    /// the client is not reconnected again and a new client must be created for
    /// further connections.
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.connection_mode().is_closed()
    }
589
590    /// Set disconnect mode to true.
591    ///
592    /// Controller task will periodically check the disconnect mode
593    /// and shutdown the client if it is alive
594    pub async fn disconnect(&self) {
595        tracing::debug!("Disconnecting");
596        self.connection_mode
597            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
598
599        match tokio::time::timeout(Duration::from_secs(5), async {
600            while !self.is_disconnected() {
601                tokio::time::sleep(Duration::from_millis(10)).await;
602            }
603
604            if !self.controller_task.is_finished() {
605                self.controller_task.abort();
606                tracing::debug!("Aborted controller task");
607            }
608        })
609        .await
610        {
611            Ok(()) => {
612                tracing::debug!("Controller task finished");
613            }
614            Err(_) => {
615                tracing::error!("Timeout waiting for controller task to finish");
616            }
617        }
618    }
619
620    pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) -> Result<(), Error> {
621        self.rate_limiter.await_keys_ready(keys).await;
622        tracing::trace!("Sending text: {data:?}");
623        let mut guard = self.writer.lock().await;
624        guard.send(Message::Text(data.into())).await
625    }
626
627    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) -> Result<(), Error> {
628        self.rate_limiter.await_keys_ready(keys).await;
629        tracing::trace!("Sending bytes: {data:?}");
630        let mut guard = self.writer.lock().await;
631        guard.send(Message::Binary(data.into())).await
632    }
633
634    pub async fn send_close_message(&self) {
635        let mut guard = self.writer.lock().await;
636        match guard.send(Message::Close(None)).await {
637            Ok(()) => tracing::debug!("Sent close message"),
638            Err(e) => tracing::error!("Error sending close message: {e}"),
639        }
640    }
641
642    fn spawn_controller_task(
643        mut inner: WebSocketClientInner,
644        connection_mode: Arc<AtomicU8>,
645        post_reconnection: Option<PyObject>,
646        post_disconnection: Option<PyObject>,
647    ) -> tokio::task::JoinHandle<()> {
648        tokio::task::spawn(async move {
649            tracing::debug!("Starting task 'controller'");
650
651            let check_interval = Duration::from_millis(10);
652
653            loop {
654                tokio::time::sleep(check_interval).await;
655                let mode = ConnectionMode::from_atomic(&connection_mode);
656
657                if mode.is_disconnect() {
658                    tracing::debug!("Disconnecting");
659                    shutdown(
660                        inner.read_task.take(),
661                        inner.heartbeat_task.take(),
662                        inner.writer.clone(),
663                    )
664                    .await;
665
666                    if let Some(ref handler) = post_disconnection {
667                        Python::with_gil(|py| match handler.call0(py) {
668                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
669                            Err(e) => {
670                                tracing::error!("Error calling `post_disconnection` handler: {e}")
671                            }
672                        });
673                    }
674                    break; // Controller finished
675                }
676
677                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
678                    match inner.reconnect().await {
679                        Ok(()) => {
680                            tracing::debug!("Reconnected successfully");
681                            inner.backoff.reset();
682
683                            if let Some(ref handler) = post_reconnection {
684                                Python::with_gil(|py| match handler.call0(py) {
685                                    Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
686                                    Err(e) => tracing::error!(
687                                        "Error calling `post_reconnection` handler: {e}"
688                                    ),
689                                });
690                            }
691                        }
692                        Err(e) => {
693                            let duration = inner.backoff.next_duration();
694                            tracing::warn!("Reconnect attempt failed: {e}");
695                            if !duration.is_zero() {
696                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
697                            }
698                            tokio::time::sleep(duration).await;
699                        }
700                    }
701                }
702            }
703            inner
704                .connection_mode
705                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
706        })
707    }
708}
709
710////////////////////////////////////////////////////////////////////////////////
711// Tests
712////////////////////////////////////////////////////////////////////////////////
713#[cfg(test)]
714#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
715mod tests {
716    use std::{num::NonZeroU32, sync::Arc};
717
718    use futures_util::{SinkExt, StreamExt};
719    use tokio::{
720        net::TcpListener,
721        task::{self, JoinHandle},
722    };
723    use tokio_tungstenite::{
724        accept_hdr_async,
725        tungstenite::{
726            handshake::server::{self, Callback},
727            http::HeaderValue,
728        },
729    };
730
731    use crate::{
732        ratelimiter::quota::Quota,
733        websocket::{WebSocketClient, WebSocketConfig},
734    };
735
    /// Local WebSocket server used by the tests.
    struct TestServer {
        /// Handle of the spawned accept-loop task.
        task: JoinHandle<()>,
        /// Port the listener was bound to.
        port: u16,
    }
740
    /// Handshake callback that checks the request for an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header is expected to carry.
        value: HeaderValue,
    }
746
747    impl Callback for TestCallback {
748        fn on_request(
749            self,
750            request: &server::Request,
751            response: server::Response,
752        ) -> Result<server::Response, server::ErrorResponse> {
753            let _ = response;
754            let value = request.headers().get(&self.key);
755            assert!(value.is_some());
756
757            if let Some(value) = request.headers().get(&self.key) {
758                assert_eq!(value, self.value);
759            }
760
761            Ok(response)
762        }
763    }
764
765    impl TestServer {
766        async fn setup() -> Self {
767            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
768            let port = TcpListener::local_addr(&server).unwrap().port();
769
770            let header_key = "test".to_string();
771            let header_value = "test".to_string();
772
773            let test_call_back = TestCallback {
774                key: header_key,
775                value: HeaderValue::from_str(&header_value).unwrap(),
776            };
777
778            let task = task::spawn(async move {
779                // Keep accepting connections
780                loop {
781                    let (conn, _) = server.accept().await.unwrap();
782                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
783                        .await
784                        .unwrap();
785
786                    task::spawn(async move {
787                        while let Some(Ok(msg)) = websocket.next().await {
788                            match msg {
789                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
790                                    if txt == "close-now" =>
791                                {
792                                    tracing::debug!("Forcibly closing from server side");
793                                    // This sends a close frame, then stops reading
794                                    let _ = websocket.close(None).await;
795                                    break;
796                                }
797                                // Echo text/binary frames
798                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
799                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
800                                    if websocket.send(msg).await.is_err() {
801                                        break;
802                                    }
803                                }
804                                // If the client closes, we also break
805                                tokio_tungstenite::tungstenite::protocol::Message::Close(
806                                    _frame,
807                                ) => {
808                                    let _ = websocket.close(None).await;
809                                    break;
810                                }
811                                // Ignore pings/pongs
812                                _ => {}
813                            }
814                        }
815                    });
816                }
817            });
818
819            Self { task, port }
820        }
821    }
822
823    impl Drop for TestServer {
824        fn drop(&mut self) {
825            self.task.abort();
826        }
827    }
828
829    async fn setup_test_client(port: u16) -> WebSocketClient {
830        let config = WebSocketConfig {
831            url: format!("ws://127.0.0.1:{port}"),
832            headers: vec![("test".into(), "test".into())],
833            handler: None,
834            heartbeat: None,
835            heartbeat_msg: None,
836            ping_handler: None,
837            reconnect_timeout_ms: None,
838            reconnect_delay_initial_ms: None,
839            reconnect_backoff_factor: None,
840            reconnect_delay_max_ms: None,
841            reconnect_jitter_ms: None,
842        };
843        WebSocketClient::connect(config, None, None, None, vec![], None)
844            .await
845            .expect("Failed to connect")
846    }
847
848    #[tokio::test]
849    async fn test_websocket_basic() {
850        let server = TestServer::setup().await;
851        let client = setup_test_client(server.port).await;
852
853        assert!(!client.is_disconnected());
854
855        client.disconnect().await;
856        assert!(client.is_disconnected());
857    }
858
859    #[tokio::test]
860    async fn test_websocket_heartbeat() {
861        let server = TestServer::setup().await;
862        let client = setup_test_client(server.port).await;
863
864        // Wait ~3s => server should see multiple "ping"
865        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
866
867        // Cleanup
868        client.disconnect().await;
869        assert!(client.is_disconnected());
870    }
871
872    #[tokio::test]
873    async fn test_websocket_reconnect_exhausted() {
874        let config = WebSocketConfig {
875            url: "ws://127.0.0.1:9997".into(), // <-- No server
876            headers: vec![],
877            handler: None,
878            heartbeat: None,
879            heartbeat_msg: None,
880            ping_handler: None,
881            reconnect_timeout_ms: None,
882            reconnect_delay_initial_ms: None,
883            reconnect_backoff_factor: None,
884            reconnect_delay_max_ms: None,
885            reconnect_jitter_ms: None,
886        };
887        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
888        assert!(res.is_err(), "Should fail quickly with no server");
889    }
890
891    #[tokio::test]
892    async fn test_websocket_forced_close_reconnect() {
893        let server = TestServer::setup().await;
894        let client = setup_test_client(server.port).await;
895
896        // 1) Send normal message
897        client.send_text("Hello".into(), None).await.unwrap();
898
899        // 2) Trigger forced close from server
900        client.send_text("close-now".into(), None).await.unwrap();
901
902        // 3) Wait a bit => read loop sees close => reconnect
903        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
904
905        // Confirm not disconnected
906        assert!(!client.is_disconnected());
907
908        // Cleanup
909        client.disconnect().await;
910        assert!(client.is_disconnected());
911    }
912
913    #[tokio::test]
914    async fn test_rate_limiter() {
915        let server = TestServer::setup().await;
916        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
917
918        let config = WebSocketConfig {
919            url: format!("ws://127.0.0.1:{}", server.port),
920            headers: vec![("test".into(), "test".into())],
921            handler: None,
922            heartbeat: None,
923            heartbeat_msg: None,
924            ping_handler: None,
925            reconnect_timeout_ms: None,
926            reconnect_delay_initial_ms: None,
927            reconnect_backoff_factor: None,
928            reconnect_delay_max_ms: None,
929            reconnect_jitter_ms: None,
930        };
931
932        let client = WebSocketClient::connect(
933            config,
934            None,
935            None,
936            None,
937            vec![("default".into(), quota)],
938            None,
939        )
940        .await
941        .unwrap();
942
943        // First 2 should succeed
944        client.send_text("test1".into(), None).await.unwrap();
945        client.send_text("test2".into(), None).await.unwrap();
946
947        // Third should error
948        client.send_text("test3".into(), None).await.unwrap();
949
950        // Cleanup
951        client.disconnect().await;
952        assert!(client.is_disconnected());
953    }
954
955    #[tokio::test]
956    async fn test_concurrent_writers() {
957        let server = TestServer::setup().await;
958        let client = Arc::new(setup_test_client(server.port).await);
959
960        let mut handles = vec![];
961        for i in 0..10 {
962            let client = client.clone();
963            handles.push(task::spawn(async move {
964                client.send_text(format!("test{i}"), None).await.unwrap();
965            }));
966        }
967
968        for handle in handles {
969            handle.await.unwrap();
970        }
971
972        // Cleanup
973        client.disconnect().await;
974        assert!(client.is_disconnected());
975    }
976}