nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19use std::{
20    path::Path,
21    sync::{
22        Arc,
23        atomic::{AtomicU8, Ordering},
24    },
25    time::Duration,
26};
27
28use nautilus_cryptography::providers::install_cryptographic_provider;
29use pyo3::prelude::*;
30use tokio::{
31    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
32    net::TcpStream,
33    sync::Mutex,
34};
35use tokio_tungstenite::{
36    MaybeTlsStream,
37    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
38};
39
40use crate::{
41    backoff::ExponentialBackoff,
42    mode::ConnectionMode,
43    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
44};
45
46type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
47type SharedTcpWriter = Arc<Mutex<WriteHalf<MaybeTlsStream<TcpStream>>>>;
48type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
49
50/// Configuration for TCP socket connection.
51#[derive(Debug, Clone)]
52#[cfg_attr(
53    feature = "python",
54    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
55)]
56pub struct SocketConfig {
57    /// The URL to connect to.
58    pub url: String,
59    /// The connection mode {Plain, TLS}.
60    pub mode: Mode,
61    /// The sequence of bytes which separates lines.
62    pub suffix: Vec<u8>,
63    /// The Python function to handle incoming messages.
64    pub handler: Arc<PyObject>,
65    /// The optional heartbeat with period and beat message.
66    pub heartbeat: Option<(u64, Vec<u8>)>,
67    /// The timeout (milliseconds) for reconnection attempts.
68    pub reconnect_timeout_ms: Option<u64>,
69    /// The initial reconnection delay (milliseconds) for reconnects.
70    pub reconnect_delay_initial_ms: Option<u64>,
71    /// The maximum reconnect delay (milliseconds) for exponential backoff.
72    pub reconnect_delay_max_ms: Option<u64>,
73    /// The exponential backoff factor for reconnection delays.
74    pub reconnect_backoff_factor: Option<f64>,
75    /// The maximum jitter (milliseconds) added to reconnection delays.
76    pub reconnect_jitter_ms: Option<u64>,
77    /// The path to the certificates directory.
78    pub certs_dir: Option<String>,
79}
80
81/// Creates a TcpStream with the server.
82///
83/// The stream can be encrypted with TLS or Plain. The stream is split into
84/// read and write ends:
85/// - The read end is passed to the task that keeps receiving
86///   messages from the server and passing them to a handler.
87/// - The write end is wrapped in an `Arc<Mutex>` and used to send messages
88///   or heart beats.
89///
90/// The heartbeat is optional and can be configured with an interval and data to
91/// send.
92///
93/// The client uses a suffix to separate messages on the byte stream. It is
94/// appended to all sent messages and heartbeats. It is also used to split
95/// the received byte stream.
96#[cfg_attr(
97    feature = "python",
98    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
99)]
100struct SocketClientInner {
101    config: SocketConfig,
102    connector: Option<Connector>,
103    read_task: Arc<tokio::task::JoinHandle<()>>,
104    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
105    writer: SharedTcpWriter,
106    connection_mode: Arc<AtomicU8>,
107    reconnect_timeout: Duration,
108    backoff: ExponentialBackoff,
109}
110
111impl SocketClientInner {
112    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
113        install_cryptographic_provider();
114
115        let SocketConfig {
116            url,
117            mode,
118            heartbeat,
119            suffix,
120            handler,
121            reconnect_timeout_ms,
122            reconnect_delay_initial_ms,
123            reconnect_delay_max_ms,
124            reconnect_backoff_factor,
125            reconnect_jitter_ms,
126            certs_dir,
127        } = &config;
128        let connector = if let Some(dir) = certs_dir {
129            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
130            Some(Connector::Rustls(Arc::new(config)))
131        } else {
132            None
133        };
134
135        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
136        let writer = Arc::new(Mutex::new(writer));
137
138        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
139
140        let handler = Python::with_gil(|py| handler.clone_ref(py));
141        let read_task = Arc::new(Self::spawn_read_task(reader, handler, suffix.clone()));
142
143        // Optionally spawn a heartbeat task to periodically ping server
144        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
145            Self::spawn_heartbeat_task(
146                connection_mode.clone(),
147                heartbeat.clone(),
148                writer.clone(),
149                suffix.clone(),
150            )
151        });
152
153        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
154        let backoff = ExponentialBackoff::new(
155            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
156            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
157            reconnect_backoff_factor.unwrap_or(1.5),
158            reconnect_jitter_ms.unwrap_or(100),
159            true, // immediate-first
160        );
161
162        Ok(Self {
163            config,
164            connector,
165            read_task,
166            heartbeat_task,
167            writer,
168            connection_mode,
169            reconnect_timeout,
170            backoff,
171        })
172    }
173
174    pub async fn tls_connect_with_server(
175        url: &str,
176        mode: Mode,
177        connector: Option<Connector>,
178    ) -> Result<(TcpReader, TcpWriter), Error> {
179        tracing::debug!("Connecting to server");
180        let stream = TcpStream::connect(url).await?;
181        tracing::debug!("Making TLS connection");
182        let request = url.into_client_request()?;
183        tcp_tls(&request, mode, stream, connector)
184            .await
185            .map(tokio::io::split)
186    }
187
188    /// Reconnect with server.
189    ///
190    /// Make a new connection with server. Use the new read and write halves
191    /// to update the shared writer and the read and heartbeat tasks.
192    async fn reconnect(&mut self) -> Result<(), Error> {
193        tracing::debug!("Reconnecting");
194
195        tokio::time::timeout(self.reconnect_timeout, async {
196            // Clean up existing tasks
197            shutdown(
198                self.read_task.clone(),
199                self.heartbeat_task.take(),
200                self.writer.clone(),
201            )
202            .await;
203
204            let SocketConfig {
205                url,
206                mode,
207                heartbeat,
208                suffix,
209                handler,
210                reconnect_timeout_ms: _,
211                reconnect_delay_initial_ms: _,
212                reconnect_backoff_factor: _,
213                reconnect_delay_max_ms: _,
214                reconnect_jitter_ms: _,
215                certs_dir: _,
216            } = &self.config;
217            // Create a fresh connection
218            let connector = self.connector.clone();
219            let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
220            let writer = Arc::new(Mutex::new(writer));
221            self.writer = writer.clone();
222
223            // Spawn new read task
224            let handler_for_read = Python::with_gil(|py| handler.clone_ref(py));
225            self.read_task = Arc::new(Self::spawn_read_task(
226                reader,
227                handler_for_read,
228                suffix.clone(),
229            ));
230
231            // Optionally spawn new heartbeat task
232            self.heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233                Self::spawn_heartbeat_task(
234                    self.connection_mode.clone(),
235                    heartbeat.clone(),
236                    writer.clone(),
237                    suffix.clone(),
238                )
239            });
240
241            self.connection_mode
242                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244            tracing::debug!("Reconnect succeeded");
245            Ok(())
246        })
247        .await
248        .map_err(|_| {
249            Error::Io(std::io::Error::new(
250                std::io::ErrorKind::TimedOut,
251                format!(
252                    "reconnection timed out after {}s",
253                    self.reconnect_timeout.as_secs_f64()
254                ),
255            ))
256        })?
257    }
258
259    /// Check if the client is still alive.
260    ///
261    /// The client is connected if the read task has not finished. It is expected
262    /// that in case of any failure client or server side. The read task will be
263    /// shutdown. There might be some delay between the connection being closed
264    /// and the client detecting it.
265    #[inline]
266    #[must_use]
267    pub fn is_alive(&self) -> bool {
268        !self.read_task.is_finished()
269    }
270
271    #[must_use]
272    fn spawn_read_task(
273        mut reader: TcpReader,
274        handler: PyObject,
275        suffix: Vec<u8>,
276    ) -> tokio::task::JoinHandle<()> {
277        tracing::debug!("Started task 'read'");
278
279        tokio::task::spawn(async move {
280            let mut buf = Vec::new();
281
282            loop {
283                match reader.read_buf(&mut buf).await {
284                    // Connection has been terminated or vector buffer is complete
285                    Ok(0) => {
286                        tracing::debug!("Connection closed by server");
287                        break;
288                    }
289                    Err(e) => {
290                        tracing::debug!("Connection ended: {e}");
291                        break;
292                    }
293                    // Received bytes of data
294                    Ok(bytes) => {
295                        tracing::trace!("Received <binary> {bytes} bytes");
296
297                        // While received data has a line break
298                        // drain it and pass it to the handler
299                        while let Some((i, _)) = &buf
300                            .windows(suffix.len())
301                            .enumerate()
302                            .find(|(_, pair)| pair.eq(&suffix))
303                        {
304                            let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
305                            data.truncate(data.len() - suffix.len());
306
307                            if let Err(e) =
308                                Python::with_gil(|py| handler.call1(py, (data.as_slice(),)))
309                            {
310                                tracing::error!("Call to handler failed: {e}");
311                                break;
312                            }
313                        }
314                    }
315                };
316            }
317
318            tracing::debug!("Completed task 'read'");
319        })
320    }
321
322    fn spawn_heartbeat_task(
323        connection_state: Arc<AtomicU8>,
324        heartbeat: (u64, Vec<u8>),
325        writer: SharedTcpWriter,
326        suffix: Vec<u8>,
327    ) -> tokio::task::JoinHandle<()> {
328        tracing::debug!("Started task 'heartbeat'");
329        let (interval_secs, mut message) = heartbeat;
330
331        tokio::task::spawn(async move {
332            let interval = Duration::from_secs(interval_secs);
333            message.extend(suffix);
334
335            loop {
336                tokio::time::sleep(interval).await;
337
338                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
339                    ConnectionMode::Active => {
340                        let mut guard = writer.lock().await;
341                        match guard.write_all(&message).await {
342                            Ok(()) => tracing::trace!("Sent heartbeat"),
343                            Err(e) => tracing::error!("Failed to send heartbeat: {e}"),
344                        }
345                    }
346                    ConnectionMode::Reconnect => continue,
347                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
348                }
349            }
350
351            tracing::debug!("Completed task 'heartbeat'");
352        })
353    }
354}
355
356/// Shutdown socket connection.
357///
358/// The client must be explicitly shutdown before dropping otherwise
359/// the connection might still be alive for some time before terminating.
360/// Closing the connection is an async call which cannot be done by the
361/// drop method so it must be done explicitly.
362async fn shutdown(
363    read_task: Arc<tokio::task::JoinHandle<()>>,
364    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
365    writer: SharedTcpWriter,
366) {
367    tracing::debug!("Shutting down inner client");
368
369    let timeout = Duration::from_secs(5);
370    if tokio::time::timeout(timeout, async {
371        // Final close of writer
372        let mut writer = writer.lock().await;
373        if let Err(e) = writer.shutdown().await {
374            tracing::error!("Error on shutdown: {e}");
375        }
376        drop(writer);
377
378        tokio::time::sleep(Duration::from_millis(100)).await;
379
380        // Abort tasks
381        if !read_task.is_finished() {
382            read_task.abort();
383            tracing::debug!("Aborted read task");
384        }
385        if let Some(task) = heartbeat_task {
386            if !task.is_finished() {
387                task.abort();
388                tracing::debug!("Aborted heartbeat task");
389            }
390        }
391    })
392    .await
393    .is_err()
394    {
395        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
396    }
397
398    tracing::debug!("Closed");
399}
400
401impl Drop for SocketClientInner {
402    fn drop(&mut self) {
403        if !self.read_task.is_finished() {
404            self.read_task.abort();
405        }
406
407        // Cancel heart beat task
408        if let Some(ref handle) = self.heartbeat_task.take() {
409            if !handle.is_finished() {
410                handle.abort();
411            }
412        }
413    }
414}
415
416#[cfg_attr(
417    feature = "python",
418    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
419)]
420pub struct SocketClient {
421    pub(crate) writer: SharedTcpWriter,
422    pub(crate) controller_task: tokio::task::JoinHandle<()>,
423    pub(crate) connection_mode: Arc<AtomicU8>,
424    pub(crate) suffix: Vec<u8>,
425}
426
427impl SocketClient {
428    /// Connect to the server.
429    ///
430    /// # Errors
431    ///
432    /// Returns any error connecting to the server.
433    pub async fn connect(
434        config: SocketConfig,
435        post_connection: Option<PyObject>,
436        post_reconnection: Option<PyObject>,
437        post_disconnection: Option<PyObject>,
438    ) -> anyhow::Result<Self> {
439        let suffix = config.suffix.clone();
440        let inner = SocketClientInner::connect_url(config).await?;
441        let writer = inner.writer.clone();
442        let connection_mode = inner.connection_mode.clone();
443
444        let controller_task = Self::spawn_controller_task(
445            inner,
446            connection_mode.clone(),
447            post_reconnection,
448            post_disconnection,
449        );
450
451        if let Some(handler) = post_connection {
452            Python::with_gil(|py| match handler.call0(py) {
453                Ok(_) => tracing::debug!("Called `post_connection` handler"),
454                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
455            });
456        }
457
458        Ok(Self {
459            writer,
460            controller_task,
461            connection_mode,
462            suffix,
463        })
464    }
465
466    /// Returns the current connection mode.
467    #[must_use]
468    pub fn connection_mode(&self) -> ConnectionMode {
469        ConnectionMode::from_atomic(&self.connection_mode)
470    }
471
472    /// Check if the client connection is active.
473    ///
474    /// Returns `true` if the client is connected and has not been signalled to disconnect.
475    /// The client will automatically retry connection based on its configuration.
476    #[inline]
477    #[must_use]
478    pub fn is_active(&self) -> bool {
479        self.connection_mode().is_active()
480    }
481
482    /// Check if the client is reconnecting.
483    ///
484    /// Returns `true` if the client lost connection and is attempting to reestablish it.
485    /// The client will automatically retry connection based on its configuration.
486    #[inline]
487    #[must_use]
488    pub fn is_reconnecting(&self) -> bool {
489        self.connection_mode().is_reconnect()
490    }
491
492    /// Check if the client is disconnecting.
493    ///
494    /// Returns `true` if the client is in disconnect mode.
495    #[inline]
496    #[must_use]
497    pub fn is_disconnecting(&self) -> bool {
498        self.connection_mode().is_disconnect()
499    }
500
501    /// Check if the client is closed.
502    ///
503    /// Returns `true` if the client has been explicitly disconnected or reached
504    /// maximum reconnection attempts. In this state, the client cannot be reused
505    /// and a new client must be created for further connections.
506    #[inline]
507    #[must_use]
508    pub fn is_closed(&self) -> bool {
509        self.connection_mode().is_closed()
510    }
511
512    /// Close the client.
513    ///
514    /// Controller task will periodically check the disconnect mode
515    /// and shutdown the client if it is not alive.
516    pub async fn close(&self) {
517        self.connection_mode
518            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
519
520        match tokio::time::timeout(Duration::from_secs(5), async {
521            while !self.is_closed() {
522                tokio::time::sleep(Duration::from_millis(10)).await;
523            }
524
525            if !self.controller_task.is_finished() {
526                self.controller_task.abort();
527                tracing::debug!("Aborted controller task");
528            }
529        })
530        .await
531        {
532            Ok(()) => {
533                tracing::debug!("Controller task finished");
534            }
535            Err(_) => {
536                tracing::error!("Timeout waiting for controller task to finish");
537            }
538        }
539    }
540
541    /// Sends a message of the given `data`.
542    ///
543    /// # Errors
544    ///
545    /// Returns any I/O error.
546    pub async fn send_bytes(&self, data: &[u8]) -> Result<(), std::io::Error> {
547        if self.is_closed() {
548            return Err(std::io::Error::new(
549                std::io::ErrorKind::NotConnected,
550                "Not connected",
551            ));
552        }
553
554        let timeout = Duration::from_secs(2);
555        let check_interval = Duration::from_millis(1);
556
557        if !self.is_active() {
558            tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
559            match tokio::time::timeout(timeout, async {
560                while !self.is_active() {
561                    if matches!(
562                        self.connection_mode(),
563                        ConnectionMode::Disconnect | ConnectionMode::Closed
564                    ) {
565                        return Err("Client disconnected waiting to send");
566                    }
567
568                    tokio::time::sleep(check_interval).await;
569                }
570
571                Ok(())
572            })
573            .await
574            {
575                Ok(Ok(())) => tracing::debug!("Client now active"),
576                Ok(Err(e)) => {
577                    tracing::error!("Cannot send data ({}): {e}", String::from_utf8_lossy(data));
578                    return Ok(());
579                }
580                Err(_) => {
581                    tracing::error!(
582                        "Cannot send data ({}): timeout waiting to become ACTIVE",
583                        String::from_utf8_lossy(data)
584                    );
585                    return Ok(());
586                }
587            }
588        }
589
590        let mut writer = self.writer.lock().await;
591        writer.write_all(data).await?;
592        writer.write_all(&self.suffix).await
593    }
594
595    fn spawn_controller_task(
596        mut inner: SocketClientInner,
597        connection_mode: Arc<AtomicU8>,
598        post_reconnection: Option<PyObject>,
599        post_disconnection: Option<PyObject>,
600    ) -> tokio::task::JoinHandle<()> {
601        tokio::task::spawn(async move {
602            tracing::debug!("Started task 'controller'");
603
604            let check_interval = Duration::from_millis(10);
605
606            loop {
607                tokio::time::sleep(check_interval).await;
608                let mode = ConnectionMode::from_atomic(&connection_mode);
609
610                if mode.is_disconnect() {
611                    tracing::debug!("Disconnecting");
612                    shutdown(
613                        inner.read_task.clone(),
614                        inner.heartbeat_task.take(),
615                        inner.writer.clone(),
616                    )
617                    .await;
618
619                    if let Some(ref handler) = post_disconnection {
620                        Python::with_gil(|py| match handler.call0(py) {
621                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
622                            Err(e) => {
623                                tracing::error!("Error calling `post_disconnection` handler: {e}");
624                            }
625                        });
626                    }
627                    break; // Controller finished
628                }
629
630                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
631                    match inner.reconnect().await {
632                        Ok(()) => {
633                            tracing::debug!("Reconnected successfully");
634                            inner.backoff.reset();
635
636                            if let Some(ref handler) = post_reconnection {
637                                Python::with_gil(|py| match handler.call0(py) {
638                                    Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
639                                    Err(e) => tracing::error!(
640                                        "Error calling `post_reconnection` handler: {e}"
641                                    ),
642                                });
643                            }
644                        }
645                        Err(e) => {
646                            let duration = inner.backoff.next_duration();
647                            tracing::warn!("Reconnect attempt failed: {e}",);
648                            if !duration.is_zero() {
649                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
650                            }
651                            tokio::time::sleep(duration).await;
652                        }
653                    }
654                }
655            }
656            inner
657                .connection_mode
658                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
659        })
660    }
661}
662
663////////////////////////////////////////////////////////////////////////////////
664// Tests
665////////////////////////////////////////////////////////////////////////////////
666#[cfg(test)]
667#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
668mod tests {
669    use std::{ffi::CString, net::TcpListener};
670
671    use nautilus_common::testing::wait_until_async;
672    use nautilus_core::python::IntoPyObjectNautilusExt;
673    use pyo3::prepare_freethreaded_python;
674    use tokio::{
675        io::{AsyncReadExt, AsyncWriteExt},
676        net::TcpStream,
677        task,
678        time::{Duration, sleep},
679    };
680
681    use super::*;
682
683    fn create_handler() -> PyObject {
684        let code_raw = r"
685class Counter:
686    def __init__(self):
687        self.count = 0
688        self.check = False
689
690    def handler(self, bytes):
691        msg = bytes.decode()
692        if msg == 'ping':
693            self.count += 1
694        elif msg == 'heartbeat message':
695            self.check = True
696
697    def get_check(self):
698        return self.check
699
700    def get_count(self):
701        return self.count
702
703counter = Counter()
704";
705        let code = CString::new(code_raw).unwrap();
706        let filename = CString::new("test".to_string()).unwrap();
707        let module = CString::new("test".to_string()).unwrap();
708        Python::with_gil(|py| {
709            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
710            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
711
712            counter
713                .getattr(py, "handler")
714                .unwrap()
715                .into_py_any_unwrap(py)
716        })
717    }
718
719    fn bind_test_server() -> (u16, TcpListener) {
720        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind ephemeral port");
721        let port = listener.local_addr().unwrap().port();
722        (port, listener)
723    }
724
725    async fn run_echo_server(mut socket: TcpStream) {
726        let mut buf = Vec::new();
727        loop {
728            match socket.read_buf(&mut buf).await {
729                Ok(0) => {
730                    break;
731                }
732                Ok(_n) => {
733                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
734                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
735                        // Remove trailing \r\n
736                        line.truncate(line.len() - 2);
737
738                        if line == b"close" {
739                            let _ = socket.shutdown().await;
740                            return;
741                        }
742
743                        let mut echo_data = line;
744                        echo_data.extend_from_slice(b"\r\n");
745                        if socket.write_all(&echo_data).await.is_err() {
746                            break;
747                        }
748                    }
749                }
750                Err(e) => {
751                    eprintln!("Server read error: {e}");
752                    break;
753                }
754            }
755        }
756    }
757
758    #[tokio::test]
759    async fn test_basic_send_receive() {
760        prepare_freethreaded_python();
761
762        let (port, listener) = bind_test_server();
763        let server_task = task::spawn(async move {
764            let (socket, _) = tokio::net::TcpListener::from_std(listener)
765                .unwrap()
766                .accept()
767                .await
768                .unwrap();
769            run_echo_server(socket).await;
770        });
771
772        let config = SocketConfig {
773            url: format!("127.0.0.1:{port}"),
774            mode: Mode::Plain,
775            suffix: b"\r\n".to_vec(),
776            handler: Arc::new(create_handler()),
777            heartbeat: None,
778            reconnect_timeout_ms: None,
779            reconnect_delay_initial_ms: None,
780            reconnect_backoff_factor: None,
781            reconnect_delay_max_ms: None,
782            reconnect_jitter_ms: None,
783            certs_dir: None,
784        };
785
786        let client = SocketClient::connect(config, None, None, None)
787            .await
788            .expect("Client connect failed unexpectedly");
789
790        client.send_bytes(b"Hello").await.unwrap();
791        client.send_bytes(b"World").await.unwrap();
792
793        // Wait a bit for the server to echo them back
794        sleep(Duration::from_millis(100)).await;
795
796        client.send_bytes(b"close").await.unwrap();
797        server_task.await.unwrap();
798        assert!(!client.is_closed());
799    }
800
801    #[tokio::test]
802    async fn test_reconnect_fail_exhausted() {
803        prepare_freethreaded_python();
804
805        let (port, listener) = bind_test_server();
806        drop(listener); // We drop it immediately -> no server is listening
807
808        let config = SocketConfig {
809            url: format!("127.0.0.1:{port}"),
810            mode: Mode::Plain,
811            suffix: b"\r\n".to_vec(),
812            handler: Arc::new(create_handler()),
813            heartbeat: None,
814            reconnect_timeout_ms: None,
815            reconnect_delay_initial_ms: None,
816            reconnect_backoff_factor: None,
817            reconnect_delay_max_ms: None,
818            reconnect_jitter_ms: None,
819            certs_dir: None,
820        };
821
822        let client_res = SocketClient::connect(config, None, None, None).await;
823        assert!(
824            client_res.is_err(),
825            "Should fail quickly with no server listening"
826        );
827    }
828
829    #[tokio::test]
830    async fn test_user_disconnect() {
831        prepare_freethreaded_python();
832
833        let (port, listener) = bind_test_server();
834        let server_task = task::spawn(async move {
835            let (socket, _) = tokio::net::TcpListener::from_std(listener)
836                .unwrap()
837                .accept()
838                .await
839                .unwrap();
840            let mut buf = [0u8; 1024];
841            let _ = socket.try_read(&mut buf);
842
843            loop {
844                sleep(Duration::from_secs(1)).await;
845            }
846        });
847
848        let config = SocketConfig {
849            url: format!("127.0.0.1:{port}"),
850            mode: Mode::Plain,
851            suffix: b"\r\n".to_vec(),
852            handler: Arc::new(create_handler()),
853            heartbeat: None,
854            reconnect_timeout_ms: None,
855            reconnect_delay_initial_ms: None,
856            reconnect_backoff_factor: None,
857            reconnect_delay_max_ms: None,
858            reconnect_jitter_ms: None,
859            certs_dir: None,
860        };
861
862        let client = SocketClient::connect(config, None, None, None)
863            .await
864            .unwrap();
865
866        client.close().await;
867        assert!(client.is_closed());
868        server_task.abort();
869    }
870
871    #[tokio::test]
872    async fn test_heartbeat() {
873        prepare_freethreaded_python();
874
875        let (port, listener) = bind_test_server();
876        let received = Arc::new(Mutex::new(Vec::new()));
877        let received2 = received.clone();
878
879        let server_task = task::spawn(async move {
880            let (socket, _) = tokio::net::TcpListener::from_std(listener)
881                .unwrap()
882                .accept()
883                .await
884                .unwrap();
885
886            let mut buf = Vec::new();
887            loop {
888                match socket.try_read_buf(&mut buf) {
889                    Ok(0) => break,
890                    Ok(_) => {
891                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
892                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
893                            line.truncate(line.len() - 2);
894                            received2.lock().await.push(line);
895                        }
896                    }
897                    Err(_) => {
898                        tokio::time::sleep(Duration::from_millis(10)).await;
899                    }
900                }
901            }
902        });
903
904        // Heartbeat every 1 second
905        let heartbeat = Some((1, b"ping".to_vec()));
906
907        let config = SocketConfig {
908            url: format!("127.0.0.1:{port}"),
909            mode: Mode::Plain,
910            suffix: b"\r\n".to_vec(),
911            handler: Arc::new(create_handler()),
912            heartbeat,
913            reconnect_timeout_ms: None,
914            reconnect_delay_initial_ms: None,
915            reconnect_backoff_factor: None,
916            reconnect_delay_max_ms: None,
917            reconnect_jitter_ms: None,
918            certs_dir: None,
919        };
920
921        let client = SocketClient::connect(config, None, None, None)
922            .await
923            .unwrap();
924
925        // Wait ~3 seconds to collect some heartbeats
926        sleep(Duration::from_secs(3)).await;
927
928        {
929            let lock = received.lock().await;
930            let pings = lock
931                .iter()
932                .filter(|line| line == &&b"ping".to_vec())
933                .count();
934            assert!(
935                pings >= 2,
936                "Expected at least 2 heartbeat pings; got {pings}"
937            );
938        }
939
940        client.close().await;
941        server_task.abort();
942    }
943
944    #[tokio::test]
945    async fn test_python_handler_error() {
946        prepare_freethreaded_python();
947
948        let (port, listener) = bind_test_server();
949        let server_task = task::spawn(async move {
950            let (socket, _) = tokio::net::TcpListener::from_std(listener)
951                .unwrap()
952                .accept()
953                .await
954                .unwrap();
955            run_echo_server(socket).await;
956        });
957
958        let code_raw = r#"
959def handler(bytes_data):
960    txt = bytes_data.decode()
961    if "ERR" in txt:
962        raise ValueError("Simulated error in handler")
963    return
964"#;
965        let code = CString::new(code_raw).unwrap();
966        let filename = CString::new("test".to_string()).unwrap();
967        let module = CString::new("test".to_string()).unwrap();
968
969        let handler = Python::with_gil(|py| {
970            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
971            let func = pymod.getattr("handler").unwrap();
972            Arc::new(func.into_py_any_unwrap(py))
973        });
974
975        let config = SocketConfig {
976            url: format!("127.0.0.1:{port}"),
977            mode: Mode::Plain,
978            suffix: b"\r\n".to_vec(),
979            handler,
980            heartbeat: None,
981            reconnect_timeout_ms: None,
982            reconnect_delay_initial_ms: None,
983            reconnect_backoff_factor: None,
984            reconnect_delay_max_ms: None,
985            reconnect_jitter_ms: None,
986            certs_dir: None,
987        };
988
989        let client = SocketClient::connect(config, None, None, None)
990            .await
991            .expect("Client connect failed unexpectedly");
992
993        client.send_bytes(b"hello").await.unwrap();
994        sleep(Duration::from_millis(100)).await;
995
996        client.send_bytes(b"ERR").await.unwrap();
997        sleep(Duration::from_secs(1)).await;
998
999        assert!(client.is_active());
1000
1001        client.close().await;
1002
1003        assert!(client.is_closed());
1004        server_task.abort();
1005    }
1006
1007    #[tokio::test]
1008    async fn test_reconnect_success() {
1009        prepare_freethreaded_python();
1010
1011        let (port, listener) = bind_test_server();
1012        let listener = tokio::net::TcpListener::from_std(listener).unwrap();
1013
1014        // Spawn a server task that:
1015        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1016        // 2. Waits a bit and then accepts a new connection and runs the echo server
1017        let server_task = task::spawn(async move {
1018            // Accept first connection
1019            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1020
1021            // Wait briefly and then force-close the connection
1022            sleep(Duration::from_millis(500)).await;
1023            let _ = socket.shutdown().await;
1024
1025            // Wait for the client's reconnect attempt
1026            sleep(Duration::from_millis(500)).await;
1027
1028            // Run the echo server on the new connection
1029            let (socket, _) = listener.accept().await.expect("Second accept failed");
1030            run_echo_server(socket).await;
1031        });
1032
1033        let config = SocketConfig {
1034            url: format!("127.0.0.1:{port}"),
1035            mode: Mode::Plain,
1036            suffix: b"\r\n".to_vec(),
1037            handler: Arc::new(create_handler()),
1038            heartbeat: None,
1039            reconnect_timeout_ms: Some(5_000),
1040            reconnect_delay_initial_ms: Some(500),
1041            reconnect_delay_max_ms: Some(5_000),
1042            reconnect_backoff_factor: Some(2.0),
1043            reconnect_jitter_ms: Some(50),
1044            certs_dir: None,
1045        };
1046
1047        let client = SocketClient::connect(config, None, None, None)
1048            .await
1049            .expect("Client connect failed unexpectedly");
1050
1051        // Initially, the client should be active
1052        assert!(client.is_active(), "Client should start as active");
1053
1054        // Wait until the client loses connection (i.e. not active),
1055        // then wait until it reconnects (active again).
1056        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1057
1058        client
1059            .send_bytes(b"TestReconnect")
1060            .await
1061            .expect("Send failed");
1062
1063        client.close().await;
1064        server_task.abort();
1065    }
1066}