nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19use std::{
20    path::Path,
21    sync::{
22        atomic::{AtomicU8, Ordering},
23        Arc,
24    },
25    time::Duration,
26};
27
28use nautilus_cryptography::providers::install_cryptographic_provider;
29use pyo3::prelude::*;
30use tokio::{
31    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
32    net::TcpStream,
33    sync::Mutex,
34};
35use tokio_tungstenite::{
36    tungstenite::{client::IntoClientRequest, stream::Mode, Error},
37    MaybeTlsStream,
38};
39
40use crate::{
41    backoff::ExponentialBackoff,
42    mode::ConnectionMode,
43    tls::{create_tls_config_from_certs_dir, tcp_tls, Connector},
44};
45
46type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
47type SharedTcpWriter = Arc<Mutex<WriteHalf<MaybeTlsStream<TcpStream>>>>;
48type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
49
/// Configuration for TCP socket connection.
///
/// Each optional `reconnect_*` setting falls back to a built-in default when it is
/// `None`; the defaults are listed on the individual fields.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// The URL to connect to. It is passed to `TcpStream::connect` and is also
    /// converted into the client request used for the TLS handshake.
    pub url: String,
    /// The connection mode {Plain, TLS}.
    pub mode: Mode,
    /// The sequence of bytes which separates lines. It is appended to every sent
    /// message and heartbeat, and used to split the received byte stream.
    pub suffix: Vec<u8>,
    /// The Python function to handle incoming messages. It is called with the
    /// bytes of one received message at a time.
    pub handler: Arc<PyObject>,
    /// The optional heartbeat as `(interval in seconds, beat message)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// The timeout (milliseconds) for reconnection attempts (10,000 when `None`).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (2,000 when `None`).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (30,000 when `None`).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (1.5 when `None`).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (100 when `None`).
    pub reconnect_jitter_ms: Option<u64>,
    /// The path to the certificates directory. When set, a rustls connector is
    /// built from it; when `None`, no custom connector is used.
    pub certs_dir: Option<String>,
}
80
/// Owns one live connection to the server together with its background tasks.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is wrapped in an `Arc<Mutex>` and used to send messages
///   or heart beats.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// The configuration the connection was created from; reused on reconnect.
    config: SocketConfig,
    /// The TLS connector built from `config.certs_dir`, if one was given.
    connector: Option<Connector>,
    /// Handle of the task reading from the socket; liveness is derived from it.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// The shared write half of the stream.
    writer: SharedTcpWriter,
    /// The current `ConnectionMode`, stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Produces the delay to wait after a failed reconnect attempt.
    backoff: ExponentialBackoff,
}
110
111impl SocketClientInner {
112    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
113        install_cryptographic_provider();
114
115        let SocketConfig {
116            url,
117            mode,
118            heartbeat,
119            suffix,
120            handler,
121            reconnect_timeout_ms,
122            reconnect_delay_initial_ms,
123            reconnect_delay_max_ms,
124            reconnect_backoff_factor,
125            reconnect_jitter_ms,
126            certs_dir,
127        } = &config;
128        let connector = if let Some(dir) = certs_dir {
129            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
130            Some(Connector::Rustls(Arc::new(config)))
131        } else {
132            None
133        };
134
135        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
136        let writer = Arc::new(Mutex::new(writer));
137
138        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
139
140        let handler = Python::with_gil(|py| handler.clone_ref(py));
141        let read_task = Arc::new(Self::spawn_read_task(reader, handler, suffix.clone()));
142
143        // Optionally spawn a heartbeat task to periodically ping server
144        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
145            Self::spawn_heartbeat_task(
146                connection_mode.clone(),
147                heartbeat.clone(),
148                writer.clone(),
149                suffix.clone(),
150            )
151        });
152
153        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
154        let backoff = ExponentialBackoff::new(
155            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
156            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
157            reconnect_backoff_factor.unwrap_or(1.5),
158            reconnect_jitter_ms.unwrap_or(100),
159            true, // immediate-first
160        );
161
162        Ok(Self {
163            config,
164            connector,
165            read_task,
166            heartbeat_task,
167            writer,
168            connection_mode,
169            reconnect_timeout,
170            backoff,
171        })
172    }
173
174    pub async fn tls_connect_with_server(
175        url: &str,
176        mode: Mode,
177        connector: Option<Connector>,
178    ) -> Result<(TcpReader, TcpWriter), Error> {
179        tracing::debug!("Connecting to server");
180        let stream = TcpStream::connect(url).await?;
181        tracing::debug!("Making TLS connection");
182        let request = url.into_client_request()?;
183        tcp_tls(&request, mode, stream, connector)
184            .await
185            .map(tokio::io::split)
186    }
187
188    /// Reconnect with server.
189    ///
190    /// Make a new connection with server. Use the new read and write halves
191    /// to update the shared writer and the read and heartbeat tasks.
192    async fn reconnect(&mut self) -> Result<(), Error> {
193        tracing::debug!("Reconnecting");
194
195        tokio::time::timeout(self.reconnect_timeout, async {
196            // Clean up existing tasks
197            shutdown(
198                self.read_task.clone(),
199                self.heartbeat_task.take(),
200                self.writer.clone(),
201            )
202            .await;
203
204            let SocketConfig {
205                url,
206                mode,
207                heartbeat,
208                suffix,
209                handler,
210                reconnect_timeout_ms: _,
211                reconnect_delay_initial_ms: _,
212                reconnect_backoff_factor: _,
213                reconnect_delay_max_ms: _,
214                reconnect_jitter_ms: _,
215                certs_dir: _,
216            } = &self.config;
217            // Create a fresh connection
218            let connector = self.connector.clone();
219            let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
220            let writer = Arc::new(Mutex::new(writer));
221            self.writer = writer.clone();
222
223            // Spawn new read task
224            let handler_for_read = Python::with_gil(|py| handler.clone_ref(py));
225            self.read_task = Arc::new(Self::spawn_read_task(
226                reader,
227                handler_for_read,
228                suffix.clone(),
229            ));
230
231            // Optionally spawn new heartbeat task
232            self.heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233                Self::spawn_heartbeat_task(
234                    self.connection_mode.clone(),
235                    heartbeat.clone(),
236                    writer.clone(),
237                    suffix.clone(),
238                )
239            });
240
241            self.connection_mode
242                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244            tracing::debug!("Reconnect succeeded");
245            Ok(())
246        })
247        .await
248        .map_err(|_| {
249            Error::Io(std::io::Error::new(
250                std::io::ErrorKind::TimedOut,
251                format!(
252                    "reconnection timed out after {}s",
253                    self.reconnect_timeout.as_secs_f64()
254                ),
255            ))
256        })?
257    }
258
259    /// Check if the client is still alive.
260    ///
261    /// The client is connected if the read task has not finished. It is expected
262    /// that in case of any failure client or server side. The read task will be
263    /// shutdown. There might be some delay between the connection being closed
264    /// and the client detecting it.
265    #[inline]
266    #[must_use]
267    pub fn is_alive(&self) -> bool {
268        !self.read_task.is_finished()
269    }
270
271    #[must_use]
272    fn spawn_read_task(
273        mut reader: TcpReader,
274        handler: PyObject,
275        suffix: Vec<u8>,
276    ) -> tokio::task::JoinHandle<()> {
277        tracing::debug!("Started task 'read'");
278
279        tokio::task::spawn(async move {
280            let mut buf = Vec::new();
281
282            loop {
283                match reader.read_buf(&mut buf).await {
284                    // Connection has been terminated or vector buffer is complete
285                    Ok(0) => {
286                        tracing::debug!("Connection closed by server");
287                        break;
288                    }
289                    Err(e) => {
290                        tracing::debug!("Connection ended: {e}");
291                        break;
292                    }
293                    // Received bytes of data
294                    Ok(bytes) => {
295                        tracing::trace!("Received <binary> {bytes} bytes");
296
297                        // While received data has a line break
298                        // drain it and pass it to the handler
299                        while let Some((i, _)) = &buf
300                            .windows(suffix.len())
301                            .enumerate()
302                            .find(|(_, pair)| pair.eq(&suffix))
303                        {
304                            let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
305                            data.truncate(data.len() - suffix.len());
306
307                            if let Err(e) =
308                                Python::with_gil(|py| handler.call1(py, (data.as_slice(),)))
309                            {
310                                tracing::error!("Call to handler failed: {e}");
311                                break;
312                            }
313                        }
314                    }
315                };
316            }
317
318            tracing::debug!("Completed task 'read'");
319        })
320    }
321
322    fn spawn_heartbeat_task(
323        connection_state: Arc<AtomicU8>,
324        heartbeat: (u64, Vec<u8>),
325        writer: SharedTcpWriter,
326        suffix: Vec<u8>,
327    ) -> tokio::task::JoinHandle<()> {
328        tracing::debug!("Started task 'heartbeat'");
329        let (interval_secs, mut message) = heartbeat;
330
331        tokio::task::spawn(async move {
332            let interval = Duration::from_secs(interval_secs);
333            message.extend(suffix);
334
335            loop {
336                tokio::time::sleep(interval).await;
337
338                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
339                    ConnectionMode::Active => {
340                        let mut guard = writer.lock().await;
341                        match guard.write_all(&message).await {
342                            Ok(()) => tracing::trace!("Sent heartbeat"),
343                            Err(e) => tracing::error!("Failed to send heartbeat: {e}"),
344                        }
345                    }
346                    ConnectionMode::Reconnect => continue,
347                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
348                }
349            }
350
351            tracing::debug!("Completed task 'heartbeat'");
352        })
353    }
354}
355
356/// Shutdown socket connection.
357///
358/// The client must be explicitly shutdown before dropping otherwise
359/// the connection might still be alive for some time before terminating.
360/// Closing the connection is an async call which cannot be done by the
361/// drop method so it must be done explicitly.
362async fn shutdown(
363    read_task: Arc<tokio::task::JoinHandle<()>>,
364    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
365    writer: SharedTcpWriter,
366) {
367    tracing::debug!("Shutting down inner client");
368
369    let timeout = Duration::from_secs(5);
370    if tokio::time::timeout(timeout, async {
371        // Final close of writer
372        let mut writer = writer.lock().await;
373        if let Err(e) = writer.shutdown().await {
374            tracing::error!("Error on shutdown: {e}");
375        }
376        drop(writer);
377
378        tokio::time::sleep(Duration::from_millis(100)).await;
379
380        // Abort tasks
381        if !read_task.is_finished() {
382            read_task.abort();
383            tracing::debug!("Aborted read task");
384        }
385        if let Some(task) = heartbeat_task {
386            if !task.is_finished() {
387                task.abort();
388                tracing::debug!("Aborted heartbeat task");
389            }
390        }
391    })
392    .await
393    .is_err()
394    {
395        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
396    }
397
398    tracing::debug!("Closed");
399}
400
401impl Drop for SocketClientInner {
402    fn drop(&mut self) {
403        if !self.read_task.is_finished() {
404            self.read_task.abort();
405        }
406
407        // Cancel heart beat task
408        if let Some(ref handle) = self.heartbeat_task.take() {
409            if !handle.is_finished() {
410                handle.abort();
411            }
412        }
413    }
414}
415
/// Handle to a socket connection that is managed by a background controller task.
///
/// The controller task owns the inner client; this handle shares the writer and
/// the connection mode with it.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// The shared write half used to send data.
    pub(crate) writer: SharedTcpWriter,
    /// Handle of the controller task that performs reconnects and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// The current `ConnectionMode`, stored as its `u8` representation.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// The bytes appended to every message that is sent.
    pub(crate) suffix: Vec<u8>,
}
426
427impl SocketClient {
428    pub async fn connect(
429        config: SocketConfig,
430        post_connection: Option<PyObject>,
431        post_reconnection: Option<PyObject>,
432        post_disconnection: Option<PyObject>,
433    ) -> anyhow::Result<Self> {
434        let suffix = config.suffix.clone();
435        let inner = SocketClientInner::connect_url(config).await?;
436        let writer = inner.writer.clone();
437        let connection_mode = inner.connection_mode.clone();
438
439        let controller_task = Self::spawn_controller_task(
440            inner,
441            connection_mode.clone(),
442            post_reconnection,
443            post_disconnection,
444        );
445
446        if let Some(handler) = post_connection {
447            Python::with_gil(|py| match handler.call0(py) {
448                Ok(_) => tracing::debug!("Called `post_connection` handler"),
449                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
450            });
451        }
452
453        Ok(Self {
454            writer,
455            controller_task,
456            connection_mode,
457            suffix,
458        })
459    }
460
    /// Returns the current connection mode.
    ///
    /// The mode is read from the atomic value shared with the controller task.
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
465
466    /// Check if the client connection is active.
467    ///
468    /// Returns `true` if the client is connected and has not been signalled to disconnect.
469    /// The client will automatically retry connection based on its configuration.
470    #[inline]
471    #[must_use]
472    pub fn is_active(&self) -> bool {
473        self.connection_mode().is_active()
474    }
475
476    /// Check if the client is reconnecting.
477    ///
478    /// Returns `true` if the client lost connection and is attempting to reestablish it.
479    /// The client will automatically retry connection based on its configuration.
480    #[inline]
481    #[must_use]
482    pub fn is_reconnecting(&self) -> bool {
483        self.connection_mode().is_reconnect()
484    }
485
486    /// Check if the client is disconnecting.
487    ///
488    /// Returns `true` if the client is in disconnect mode.
489    #[inline]
490    #[must_use]
491    pub fn is_disconnecting(&self) -> bool {
492        self.connection_mode().is_disconnect()
493    }
494
495    /// Check if the client is closed.
496    ///
497    /// Returns `true` if the client has been explicitly disconnected or reached
498    /// maximum reconnection attempts. In this state, the client cannot be reused
499    /// and a new client must be created for further connections.
500    #[inline]
501    #[must_use]
502    pub fn is_closed(&self) -> bool {
503        self.connection_mode().is_closed()
504    }
505
506    /// Close the client.
507    ///
508    /// Controller task will periodically check the disconnect mode
509    /// and shutdown the client if it is not alive.
510    pub async fn close(&self) {
511        self.connection_mode
512            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
513
514        match tokio::time::timeout(Duration::from_secs(5), async {
515            while !self.is_closed() {
516                tokio::time::sleep(Duration::from_millis(10)).await;
517            }
518
519            if !self.controller_task.is_finished() {
520                self.controller_task.abort();
521                tracing::debug!("Aborted controller task");
522            }
523        })
524        .await
525        {
526            Ok(()) => {
527                tracing::debug!("Controller task finished");
528            }
529            Err(_) => {
530                tracing::error!("Timeout waiting for controller task to finish");
531            }
532        }
533    }
534
535    pub async fn send_bytes(&self, data: &[u8]) -> Result<(), std::io::Error> {
536        if self.is_closed() {
537            return Err(std::io::Error::new(
538                std::io::ErrorKind::NotConnected,
539                "Not connected",
540            ));
541        }
542
543        let timeout = Duration::from_secs(2);
544        let check_interval = Duration::from_millis(1);
545
546        if !self.is_active() {
547            tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
548            match tokio::time::timeout(timeout, async {
549                while !self.is_active() {
550                    if matches!(
551                        self.connection_mode(),
552                        ConnectionMode::Disconnect | ConnectionMode::Closed
553                    ) {
554                        return Err("Client disconnected waiting to send");
555                    }
556
557                    tokio::time::sleep(check_interval).await;
558                }
559
560                Ok(())
561            })
562            .await
563            {
564                Ok(Ok(())) => tracing::debug!("Client now active"),
565                Ok(Err(e)) => {
566                    tracing::error!("Cannot send data ({}): {e}", String::from_utf8_lossy(data));
567                    return Ok(());
568                }
569                Err(_) => {
570                    tracing::error!(
571                        "Cannot send data ({}): timeout waiting to become ACTIVE",
572                        String::from_utf8_lossy(data)
573                    );
574                    return Ok(());
575                }
576            }
577        }
578
579        let mut writer = self.writer.lock().await;
580        writer.write_all(data).await?;
581        writer.write_all(&self.suffix).await
582    }
583
584    fn spawn_controller_task(
585        mut inner: SocketClientInner,
586        connection_mode: Arc<AtomicU8>,
587        post_reconnection: Option<PyObject>,
588        post_disconnection: Option<PyObject>,
589    ) -> tokio::task::JoinHandle<()> {
590        tokio::task::spawn(async move {
591            tracing::debug!("Starting task 'controller'");
592
593            let check_interval = Duration::from_millis(10);
594
595            loop {
596                tokio::time::sleep(check_interval).await;
597                let mode = ConnectionMode::from_atomic(&connection_mode);
598
599                if mode.is_disconnect() {
600                    tracing::debug!("Disconnecting");
601                    shutdown(
602                        inner.read_task.clone(),
603                        inner.heartbeat_task.take(),
604                        inner.writer.clone(),
605                    )
606                    .await;
607
608                    if let Some(ref handler) = post_disconnection {
609                        Python::with_gil(|py| match handler.call0(py) {
610                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
611                            Err(e) => {
612                                tracing::error!("Error calling `post_disconnection` handler: {e}")
613                            }
614                        });
615                    }
616                    break; // Controller finished
617                }
618
619                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
620                    match inner.reconnect().await {
621                        Ok(()) => {
622                            tracing::debug!("Reconnected successfully");
623                            inner.backoff.reset();
624
625                            if let Some(ref handler) = post_reconnection {
626                                Python::with_gil(|py| match handler.call0(py) {
627                                    Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
628                                    Err(e) => tracing::error!(
629                                        "Error calling `post_reconnection` handler: {e}"
630                                    ),
631                                });
632                            }
633                        }
634                        Err(e) => {
635                            let duration = inner.backoff.next_duration();
636                            tracing::warn!("Reconnect attempt failed: {e}",);
637                            if !duration.is_zero() {
638                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
639                            }
640                            tokio::time::sleep(duration).await;
641                        }
642                    }
643                }
644            }
645            inner
646                .connection_mode
647                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
648        })
649    }
650}
651
652////////////////////////////////////////////////////////////////////////////////
653// Tests
654////////////////////////////////////////////////////////////////////////////////
655#[cfg(test)]
656#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
657mod tests {
658    use std::{ffi::CString, net::TcpListener};
659
660    use nautilus_common::testing::wait_until_async;
661    use pyo3::prepare_freethreaded_python;
662    use tokio::{
663        io::{AsyncReadExt, AsyncWriteExt},
664        net::TcpStream,
665        task,
666        time::{sleep, Duration},
667    };
668
669    use super::*;
670
671    fn create_handler() -> PyObject {
672        let code_raw = r#"
673class Counter:
674    def __init__(self):
675        self.count = 0
676        self.check = False
677
678    def handler(self, bytes):
679        msg = bytes.decode()
680        if msg == 'ping':
681            self.count += 1
682        elif msg == 'heartbeat message':
683            self.check = True
684
685    def get_check(self):
686        return self.check
687
688    def get_count(self):
689        return self.count
690
691counter = Counter()
692"#;
693        let code = CString::new(code_raw).unwrap();
694        let filename = CString::new("test".to_string()).unwrap();
695        let module = CString::new("test".to_string()).unwrap();
696        Python::with_gil(|py| {
697            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
698            let counter = pymod.getattr("counter").unwrap().into_py(py);
699            let handler = counter.getattr(py, "handler").unwrap().into_py(py);
700            handler
701        })
702    }
703
704    fn bind_test_server() -> (u16, TcpListener) {
705        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind ephemeral port");
706        let port = listener.local_addr().unwrap().port();
707        (port, listener)
708    }
709
710    async fn run_echo_server(mut socket: TcpStream) {
711        let mut buf = Vec::new();
712        loop {
713            match socket.read_buf(&mut buf).await {
714                Ok(0) => {
715                    break;
716                }
717                Ok(_n) => {
718                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
719                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
720                        // Remove trailing \r\n
721                        line.truncate(line.len() - 2);
722
723                        if line == b"close" {
724                            let _ = socket.shutdown().await;
725                            return;
726                        }
727
728                        let mut echo_data = line;
729                        echo_data.extend_from_slice(b"\r\n");
730                        if socket.write_all(&echo_data).await.is_err() {
731                            break;
732                        }
733                    }
734                }
735                Err(e) => {
736                    eprintln!("Server read error: {e}");
737                    break;
738                }
739            }
740        }
741    }
742
743    #[tokio::test]
744    async fn test_basic_send_receive() {
745        prepare_freethreaded_python();
746
747        let (port, listener) = bind_test_server();
748        let server_task = task::spawn(async move {
749            let (socket, _) = tokio::net::TcpListener::from_std(listener)
750                .unwrap()
751                .accept()
752                .await
753                .unwrap();
754            run_echo_server(socket).await;
755        });
756
757        let config = SocketConfig {
758            url: format!("127.0.0.1:{port}"),
759            mode: Mode::Plain,
760            suffix: b"\r\n".to_vec(),
761            handler: Arc::new(create_handler()),
762            heartbeat: None,
763            reconnect_timeout_ms: None,
764            reconnect_delay_initial_ms: None,
765            reconnect_backoff_factor: None,
766            reconnect_delay_max_ms: None,
767            reconnect_jitter_ms: None,
768            certs_dir: None,
769        };
770
771        let client = SocketClient::connect(config, None, None, None)
772            .await
773            .expect("Client connect failed unexpectedly");
774
775        client.send_bytes(b"Hello").await.unwrap();
776        client.send_bytes(b"World").await.unwrap();
777
778        // Wait a bit for the server to echo them back
779        sleep(Duration::from_millis(100)).await;
780
781        client.send_bytes(b"close").await.unwrap();
782        server_task.await.unwrap();
783        assert!(!client.is_closed());
784    }
785
786    #[tokio::test]
787    async fn test_reconnect_fail_exhausted() {
788        prepare_freethreaded_python();
789
790        let (port, listener) = bind_test_server();
791        drop(listener); // We drop it immediately -> no server is listening
792
793        let config = SocketConfig {
794            url: format!("127.0.0.1:{port}"),
795            mode: Mode::Plain,
796            suffix: b"\r\n".to_vec(),
797            handler: Arc::new(create_handler()),
798            heartbeat: None,
799            reconnect_timeout_ms: None,
800            reconnect_delay_initial_ms: None,
801            reconnect_backoff_factor: None,
802            reconnect_delay_max_ms: None,
803            reconnect_jitter_ms: None,
804            certs_dir: None,
805        };
806
807        let client_res = SocketClient::connect(config, None, None, None).await;
808        assert!(
809            client_res.is_err(),
810            "Should fail quickly with no server listening"
811        );
812    }
813
814    #[tokio::test]
815    async fn test_user_disconnect() {
816        prepare_freethreaded_python();
817
818        let (port, listener) = bind_test_server();
819        let server_task = task::spawn(async move {
820            let (socket, _) = tokio::net::TcpListener::from_std(listener)
821                .unwrap()
822                .accept()
823                .await
824                .unwrap();
825            let mut buf = [0u8; 1024];
826            let _ = socket.try_read(&mut buf);
827
828            loop {
829                sleep(Duration::from_secs(1)).await;
830            }
831        });
832
833        let config = SocketConfig {
834            url: format!("127.0.0.1:{port}"),
835            mode: Mode::Plain,
836            suffix: b"\r\n".to_vec(),
837            handler: Arc::new(create_handler()),
838            heartbeat: None,
839            reconnect_timeout_ms: None,
840            reconnect_delay_initial_ms: None,
841            reconnect_backoff_factor: None,
842            reconnect_delay_max_ms: None,
843            reconnect_jitter_ms: None,
844            certs_dir: None,
845        };
846
847        let client = SocketClient::connect(config, None, None, None)
848            .await
849            .unwrap();
850
851        client.close().await;
852        assert!(client.is_closed());
853        server_task.abort();
854    }
855
856    #[tokio::test]
857    async fn test_heartbeat() {
858        prepare_freethreaded_python();
859
860        let (port, listener) = bind_test_server();
861        let received = Arc::new(Mutex::new(Vec::new()));
862        let received2 = received.clone();
863
864        let server_task = task::spawn(async move {
865            let (socket, _) = tokio::net::TcpListener::from_std(listener)
866                .unwrap()
867                .accept()
868                .await
869                .unwrap();
870
871            let mut buf = Vec::new();
872            loop {
873                match socket.try_read_buf(&mut buf) {
874                    Ok(0) => break,
875                    Ok(_) => {
876                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
877                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
878                            line.truncate(line.len() - 2);
879                            received2.lock().await.push(line);
880                        }
881                    }
882                    Err(_) => {
883                        tokio::time::sleep(Duration::from_millis(10)).await;
884                    }
885                }
886            }
887        });
888
889        // Heartbeat every 1 second
890        let heartbeat = Some((1, b"ping".to_vec()));
891
892        let config = SocketConfig {
893            url: format!("127.0.0.1:{port}"),
894            mode: Mode::Plain,
895            suffix: b"\r\n".to_vec(),
896            handler: Arc::new(create_handler().into()),
897            heartbeat,
898            reconnect_timeout_ms: None,
899            reconnect_delay_initial_ms: None,
900            reconnect_backoff_factor: None,
901            reconnect_delay_max_ms: None,
902            reconnect_jitter_ms: None,
903            certs_dir: None,
904        };
905
906        let client = SocketClient::connect(config, None, None, None)
907            .await
908            .unwrap();
909
910        // Wait ~3 seconds to collect some heartbeats
911        sleep(Duration::from_secs(3)).await;
912
913        {
914            let lock = received.lock().await;
915            let pings = lock
916                .iter()
917                .filter(|line| line == &&b"ping".to_vec())
918                .count();
919            assert!(
920                pings >= 2,
921                "Expected at least 2 heartbeat pings; got {pings}"
922            );
923        }
924
925        client.close().await;
926        server_task.abort();
927    }
928
929    #[tokio::test]
930    async fn test_python_handler_error() {
931        prepare_freethreaded_python();
932
933        let (port, listener) = bind_test_server();
934        let server_task = task::spawn(async move {
935            let (socket, _) = tokio::net::TcpListener::from_std(listener)
936                .unwrap()
937                .accept()
938                .await
939                .unwrap();
940            run_echo_server(socket).await;
941        });
942
943        let code_raw = r#"
944def handler(bytes_data):
945    txt = bytes_data.decode()
946    if "ERR" in txt:
947        raise ValueError("Simulated error in handler")
948    return
949"#;
950        let code = CString::new(code_raw).unwrap();
951        let filename = CString::new("test".to_string()).unwrap();
952        let module = CString::new("test".to_string()).unwrap();
953
954        let handler = Python::with_gil(|py| {
955            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
956            let func = pymod.getattr("handler").unwrap();
957            Arc::new(func.into_py(py))
958        });
959
960        let config = SocketConfig {
961            url: format!("127.0.0.1:{port}"),
962            mode: Mode::Plain,
963            suffix: b"\r\n".to_vec(),
964            handler,
965            heartbeat: None,
966            reconnect_timeout_ms: None,
967            reconnect_delay_initial_ms: None,
968            reconnect_backoff_factor: None,
969            reconnect_delay_max_ms: None,
970            reconnect_jitter_ms: None,
971            certs_dir: None,
972        };
973
974        let client = SocketClient::connect(config, None, None, None)
975            .await
976            .expect("Client connect failed unexpectedly");
977
978        client.send_bytes(b"hello").await.unwrap();
979        sleep(Duration::from_millis(100)).await;
980
981        client.send_bytes(b"ERR").await.unwrap();
982        sleep(Duration::from_secs(1)).await;
983
984        assert!(client.is_active());
985
986        client.close().await;
987
988        assert!(client.is_closed());
989        server_task.abort();
990    }
991
992    #[tokio::test]
993    async fn test_reconnect_success() {
994        prepare_freethreaded_python();
995
996        let (port, listener) = bind_test_server();
997        let listener = tokio::net::TcpListener::from_std(listener).unwrap();
998
999        // Spawn a server task that:
1000        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1001        // 2. Waits a bit and then accepts a new connection and runs the echo server
1002        let server_task = task::spawn(async move {
1003            // Accept first connection
1004            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1005
1006            // Wait briefly and then force-close the connection
1007            sleep(Duration::from_millis(500)).await;
1008            let _ = socket.shutdown().await;
1009
1010            // Wait for the client's reconnect attempt
1011            sleep(Duration::from_millis(500)).await;
1012
1013            // Run the echo server on the new connection
1014            let (socket, _) = listener.accept().await.expect("Second accept failed");
1015            run_echo_server(socket).await;
1016        });
1017
1018        let config = SocketConfig {
1019            url: format!("127.0.0.1:{port}"),
1020            mode: Mode::Plain,
1021            suffix: b"\r\n".to_vec(),
1022            handler: Arc::new(create_handler()),
1023            heartbeat: None,
1024            reconnect_timeout_ms: Some(5_000),
1025            reconnect_delay_initial_ms: Some(500),
1026            reconnect_delay_max_ms: Some(5_000),
1027            reconnect_backoff_factor: Some(2.0),
1028            reconnect_jitter_ms: Some(50),
1029            certs_dir: None,
1030        };
1031
1032        let client = SocketClient::connect(config, None, None, None)
1033            .await
1034            .expect("Client connect failed unexpectedly");
1035
1036        // Initially, the client should be active
1037        assert!(client.is_active(), "Client should start as active");
1038
1039        // Wait until the client loses connection (i.e. not active),
1040        // then wait until it reconnects (active again).
1041        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1042
1043        client
1044            .send_bytes(b"TestReconnect")
1045            .await
1046            .expect("Send failed");
1047
1048        client.close().await;
1049        server_task.abort();
1050    }
1051}