nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    path::Path,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use bytes::Bytes;
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use pyo3::prelude::*;
43use tokio::{
44    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
45    net::TcpStream,
46};
47use tokio_tungstenite::{
48    MaybeTlsStream,
49    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
50};
51
52use crate::{
53    backoff::ExponentialBackoff,
54    fix::process_fix_buffer,
55    mode::ConnectionMode,
56    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
57};
58
/// Write half of the split `MaybeTlsStream<TcpStream>` connection.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of the split `MaybeTlsStream<TcpStream>` connection.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Rust-side callback that is handed received bytes as a borrowed slice.
pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
62
/// Configuration for TCP socket connection.
///
/// All `reconnect_*` fields are optional; when `None`, the client falls back to the
/// defaults noted on each field.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// The URL to connect to.
    ///
    /// NOTE(review): this string is passed as-is to `TcpStream::connect`, so it is
    /// presumably expected in `host:port` form — confirm against callers.
    pub url: String,
    /// The connection mode {Plain, TLS}.
    pub mode: Mode,
    /// The sequence of bytes which separates lines.
    ///
    /// It is written after every sent message and used to split the received byte stream.
    pub suffix: Vec<u8>,
    /// The optional Python function to handle incoming messages.
    ///
    /// It is called with each suffix-delimited message when no Rust handler is supplied.
    pub py_handler: Option<Arc<PyObject>>,
    /// The optional heartbeat as `(interval in seconds, message bytes)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// The timeout (milliseconds) for reconnection attempts (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// The path to the certificates directory.
    ///
    /// When set, a rustls TLS configuration is built from this directory at connect time.
    pub certs_dir: Option<String>,
}
93
/// Represents a command for the writer task.
///
/// Commands are delivered over an unbounded channel to the task that owns the write half.
#[derive(Debug)]
pub enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    ///
    /// The writer task shuts down its current writer before switching to the new one.
    Update(TcpWriter),
    /// Send data to the server.
    ///
    /// The writer task writes the configured suffix after the data.
    Send(Bytes),
}
102
/// Internal state of a TCP client connected to a server.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; read again on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `config.certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced with a fresh task on reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task, which owns the write half of the stream.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender used to pass `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared between all tasks, stored as the `u8` form of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff state consulted by the controller after a failed reconnect attempt.
    backoff: ExponentialBackoff,
    /// Optional Rust-side handler for received data.
    handler: Option<Arc<TcpMessageHandler>>,
}
134
135impl SocketClientInner {
136    pub async fn connect_url(
137        config: SocketConfig,
138        handler: Option<Arc<TcpMessageHandler>>,
139    ) -> anyhow::Result<Self> {
140        install_cryptographic_provider();
141
142        let SocketConfig {
143            url,
144            mode,
145            heartbeat,
146            suffix,
147            py_handler,
148            reconnect_timeout_ms,
149            reconnect_delay_initial_ms,
150            reconnect_delay_max_ms,
151            reconnect_backoff_factor,
152            reconnect_jitter_ms,
153            certs_dir,
154        } = &config;
155        let connector = if let Some(dir) = certs_dir {
156            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
157            Some(Connector::Rustls(Arc::new(config)))
158        } else {
159            None
160        };
161
162        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
163        tracing::debug!("Connected");
164
165        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
166
167        let read_task = Arc::new(Self::spawn_read_task(
168            connection_mode.clone(),
169            reader,
170            handler.clone(),
171            py_handler.clone(),
172            suffix.clone(),
173        ));
174
175        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
176
177        let write_task =
178            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
179
180        // Optionally spawn a heartbeat task to periodically ping server
181        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
182            Self::spawn_heartbeat_task(
183                connection_mode.clone(),
184                heartbeat.clone(),
185                writer_tx.clone(),
186            )
187        });
188
189        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
190        let backoff = ExponentialBackoff::new(
191            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
192            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
193            reconnect_backoff_factor.unwrap_or(1.5),
194            reconnect_jitter_ms.unwrap_or(100),
195            true, // immediate-first
196        );
197
198        Ok(Self {
199            config,
200            connector,
201            read_task,
202            write_task,
203            writer_tx,
204            heartbeat_task,
205            connection_mode,
206            reconnect_timeout,
207            backoff,
208            handler,
209        })
210    }
211
212    pub async fn tls_connect_with_server(
213        url: &str,
214        mode: Mode,
215        connector: Option<Connector>,
216    ) -> Result<(TcpReader, TcpWriter), Error> {
217        tracing::debug!("Connecting to {url}");
218        let tcp_result = TcpStream::connect(url).await;
219
220        match tcp_result {
221            Ok(stream) => {
222                tracing::debug!("TCP connection established, proceeding with TLS");
223                let request = url.into_client_request()?;
224                tcp_tls(&request, mode, stream, connector)
225                    .await
226                    .map(tokio::io::split)
227            }
228            Err(e) => {
229                tracing::error!("TCP connection failed: {e:?}");
230                Err(Error::Io(e))
231            }
232        }
233    }
234
235    /// Reconnect with server.
236    ///
237    /// Makes a new connection with server, uses the new read and write halves
238    /// to update the reader and writer.
239    async fn reconnect(&mut self) -> Result<(), Error> {
240        tracing::debug!("Reconnecting");
241
242        tokio::time::timeout(self.reconnect_timeout, async {
243            let SocketConfig {
244                url,
245                mode,
246                heartbeat: _,
247                suffix,
248                py_handler,
249                reconnect_timeout_ms: _,
250                reconnect_delay_initial_ms: _,
251                reconnect_backoff_factor: _,
252                reconnect_delay_max_ms: _,
253                reconnect_jitter_ms: _,
254                certs_dir: _,
255            } = &self.config;
256            // Create a fresh connection
257            let connector = self.connector.clone();
258            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
259            tracing::debug!("Connected");
260
261            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
262                tracing::error!("{e}");
263            }
264
265            // Delay before closing connection
266            tokio::time::sleep(Duration::from_millis(100)).await;
267
268            if !self.read_task.is_finished() {
269                self.read_task.abort();
270                tracing::debug!("Aborted task 'read'");
271            }
272
273            self.connection_mode
274                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
275
276            // Spawn new read task
277            self.read_task = Arc::new(Self::spawn_read_task(
278                self.connection_mode.clone(),
279                reader,
280                self.handler.clone(),
281                py_handler.clone(),
282                suffix.clone(),
283            ));
284
285            tracing::debug!("Reconnect succeeded");
286            Ok(())
287        })
288        .await
289        .map_err(|_| {
290            Error::Io(std::io::Error::new(
291                std::io::ErrorKind::TimedOut,
292                format!(
293                    "reconnection timed out after {}s",
294                    self.reconnect_timeout.as_secs_f64()
295                ),
296            ))
297        })?
298    }
299
    /// Check if the client is still alive.
    ///
    /// The client is considered alive while the read task has not finished. The read
    /// task is expected to shut down on any failure, whether on the client or the
    /// server side. There might be some delay between the connection being closed
    /// and the client detecting it.
    #[inline]
    #[must_use]
    pub fn is_alive(&self) -> bool {
        !self.read_task.is_finished()
    }
311
312    #[must_use]
313    fn spawn_read_task(
314        connection_state: Arc<AtomicU8>,
315        mut reader: TcpReader,
316        handler: Option<Arc<TcpMessageHandler>>,
317        py_handler: Option<Arc<PyObject>>,
318        suffix: Vec<u8>,
319    ) -> tokio::task::JoinHandle<()> {
320        tracing::debug!("Started task 'read'");
321
322        // Interval between checking the connection mode
323        let check_interval = Duration::from_millis(10);
324
325        tokio::task::spawn(async move {
326            let mut buf = Vec::new();
327
328            loop {
329                if !ConnectionMode::from_atomic(&connection_state).is_active() {
330                    break;
331                }
332
333                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
334                    // Connection has been terminated or vector buffer is complete
335                    Ok(Ok(0)) => {
336                        tracing::debug!("Connection closed by server");
337                        break;
338                    }
339                    Ok(Err(e)) => {
340                        tracing::debug!("Connection ended: {e}");
341                        break;
342                    }
343                    // Received bytes of data
344                    Ok(Ok(bytes)) => {
345                        tracing::trace!("Received <binary> {bytes} bytes");
346
347                        if let Some(handler) = &handler {
348                            process_fix_buffer(&mut buf, handler);
349                        } else {
350                            while let Some((i, _)) = &buf
351                                .windows(suffix.len())
352                                .enumerate()
353                                .find(|(_, pair)| pair.eq(&suffix))
354                            {
355                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
356                                data.truncate(data.len() - suffix.len());
357
358                                if let Some(handler) = &handler {
359                                    handler(&data);
360                                }
361
362                                if let Some(py_handler) = &py_handler {
363                                    if let Err(e) = Python::with_gil(|py| {
364                                        py_handler.call1(py, (data.as_slice(),))
365                                    }) {
366                                        tracing::error!("Call to handler failed: {e}");
367                                        break;
368                                    }
369                                }
370                            }
371                        }
372                    }
373                    Err(_) => {
374                        // Timeout - continue loop and check connection mode
375                        continue;
376                    }
377                }
378            }
379
380            tracing::debug!("Completed task 'read'");
381        })
382    }
383
384    fn spawn_write_task(
385        connection_state: Arc<AtomicU8>,
386        writer: TcpWriter,
387        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
388        suffix: Vec<u8>,
389    ) -> tokio::task::JoinHandle<()> {
390        tracing::debug!("Started task 'write'");
391
392        // Interval between checking the connection mode
393        let check_interval = Duration::from_millis(10);
394
395        tokio::task::spawn(async move {
396            let mut active_writer = writer;
397
398            loop {
399                if matches!(
400                    ConnectionMode::from_atomic(&connection_state),
401                    ConnectionMode::Disconnect | ConnectionMode::Closed
402                ) {
403                    break;
404                }
405
406                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
407                    Ok(Some(msg)) => {
408                        // Re-check connection mode after receiving a message
409                        let mode = ConnectionMode::from_atomic(&connection_state);
410                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
411                            break;
412                        }
413
414                        match msg {
415                            WriterCommand::Update(new_writer) => {
416                                tracing::debug!("Received new writer");
417
418                                // Delay before closing connection
419                                tokio::time::sleep(Duration::from_millis(100)).await;
420
421                                // Attempt to shutdown the writer gracefully before updating,
422                                // we ignore any error as the writer may already be closed.
423                                _ = active_writer.shutdown().await;
424
425                                active_writer = new_writer;
426                                tracing::debug!("Updated writer");
427                            }
428                            _ if mode.is_reconnect() => {
429                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
430                                continue;
431                            }
432                            WriterCommand::Send(msg) => {
433                                if let Err(e) = active_writer.write_all(&msg).await {
434                                    tracing::error!("Failed to send message: {e}");
435                                    // Mode is active so trigger reconnection
436                                    tracing::warn!("Writer triggering reconnect");
437                                    connection_state
438                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
439                                    continue;
440                                }
441                                if let Err(e) = active_writer.write_all(&suffix).await {
442                                    tracing::error!("Failed to send message: {e}");
443                                }
444                            }
445                        }
446                    }
447                    Ok(None) => {
448                        // Channel closed - writer task should terminate
449                        tracing::debug!("Writer channel closed, terminating writer task");
450                        break;
451                    }
452                    Err(_) => {
453                        // Timeout - just continue the loop
454                        continue;
455                    }
456                }
457            }
458
459            // Attempt to shutdown the writer gracefully before exiting,
460            // we ignore any error as the writer may already be closed.
461            _ = active_writer.shutdown().await;
462
463            tracing::debug!("Completed task 'write'");
464        })
465    }
466
467    fn spawn_heartbeat_task(
468        connection_state: Arc<AtomicU8>,
469        heartbeat: (u64, Vec<u8>),
470        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
471    ) -> tokio::task::JoinHandle<()> {
472        tracing::debug!("Started task 'heartbeat'");
473        let (interval_secs, message) = heartbeat;
474
475        tokio::task::spawn(async move {
476            let interval = Duration::from_secs(interval_secs);
477
478            loop {
479                tokio::time::sleep(interval).await;
480
481                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
482                    ConnectionMode::Active => {
483                        let msg = WriterCommand::Send(message.clone().into());
484
485                        match writer_tx.send(msg) {
486                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
487                            Err(e) => {
488                                tracing::error!("Failed to send heartbeat to writer task: {e}");
489                            }
490                        }
491                    }
492                    ConnectionMode::Reconnect => continue,
493                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
494                }
495            }
496
497            tracing::debug!("Completed task 'heartbeat'");
498        })
499    }
500}
501
502impl Drop for SocketClientInner {
503    fn drop(&mut self) {
504        if !self.read_task.is_finished() {
505            self.read_task.abort();
506            tracing::debug!("Aborted task 'read'");
507        }
508
509        if !self.write_task.is_finished() {
510            self.write_task.abort();
511            tracing::debug!("Aborted task 'write'");
512        }
513
514        if let Some(ref handle) = self.heartbeat_task.take() {
515            if !handle.is_finished() {
516                handle.abort();
517                tracing::debug!("Aborted task 'heartbeat'");
518            }
519        }
520    }
521}
522
/// Public handle to a TCP socket client.
///
/// The connection itself is owned by a background controller task; this handle
/// shares the connection mode with it and can queue commands for the writer task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task which manages reconnection and disconnection.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the background tasks, stored as a `u8`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender used to queue `WriterCommand`s for the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
532
533impl SocketClient {
534    /// Connect to the server.
535    ///
536    /// # Errors
537    ///
538    /// Returns any error connecting to the server.
539    pub async fn connect(
540        config: SocketConfig,
541        handler: Option<Arc<TcpMessageHandler>>,
542        post_connection: Option<PyObject>,
543        post_reconnection: Option<PyObject>,
544        post_disconnection: Option<PyObject>,
545    ) -> anyhow::Result<Self> {
546        let inner = SocketClientInner::connect_url(config, handler).await?;
547        let writer_tx = inner.writer_tx.clone();
548        let connection_mode = inner.connection_mode.clone();
549
550        let controller_task = Self::spawn_controller_task(
551            inner,
552            connection_mode.clone(),
553            post_reconnection,
554            post_disconnection,
555        );
556
557        if let Some(handler) = post_connection {
558            Python::with_gil(|py| match handler.call0(py) {
559                Ok(_) => tracing::debug!("Called `post_connection` handler"),
560                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
561            });
562        }
563
564        Ok(Self {
565            controller_task,
566            connection_mode,
567            writer_tx,
568        })
569    }
570
    /// Returns the current connection mode.
    ///
    /// The mode is read from the atomic value shared with the background tasks.
    #[must_use]
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
576
577    /// Check if the client connection is active.
578    ///
579    /// Returns `true` if the client is connected and has not been signalled to disconnect.
580    /// The client will automatically retry connection based on its configuration.
581    #[inline]
582    #[must_use]
583    pub fn is_active(&self) -> bool {
584        self.connection_mode().is_active()
585    }
586
587    /// Check if the client is reconnecting.
588    ///
589    /// Returns `true` if the client lost connection and is attempting to reestablish it.
590    /// The client will automatically retry connection based on its configuration.
591    #[inline]
592    #[must_use]
593    pub fn is_reconnecting(&self) -> bool {
594        self.connection_mode().is_reconnect()
595    }
596
597    /// Check if the client is disconnecting.
598    ///
599    /// Returns `true` if the client is in disconnect mode.
600    #[inline]
601    #[must_use]
602    pub fn is_disconnecting(&self) -> bool {
603        self.connection_mode().is_disconnect()
604    }
605
606    /// Check if the client is closed.
607    ///
608    /// Returns `true` if the client has been explicitly disconnected or reached
609    /// maximum reconnection attempts. In this state, the client cannot be reused
610    /// and a new client must be created for further connections.
611    #[inline]
612    #[must_use]
613    pub fn is_closed(&self) -> bool {
614        self.connection_mode().is_closed()
615    }
616
617    /// Close the client.
618    ///
619    /// Controller task will periodically check the disconnect mode
620    /// and shutdown the client if it is not alive.
621    pub async fn close(&self) {
622        self.connection_mode
623            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
624
625        match tokio::time::timeout(Duration::from_secs(5), async {
626            while !self.is_closed() {
627                tokio::time::sleep(Duration::from_millis(10)).await;
628            }
629
630            if !self.controller_task.is_finished() {
631                self.controller_task.abort();
632                tracing::debug!("Aborted controller task");
633            }
634        })
635        .await
636        {
637            Ok(()) => {
638                tracing::debug!("Controller task finished");
639            }
640            Err(_) => {
641                tracing::error!("Timeout waiting for controller task to finish");
642            }
643        }
644    }
645
646    /// Sends a message of the given `data`.
647    ///
648    /// # Errors
649    ///
650    /// Returns any I/O error.
651    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), std::io::Error> {
652        if self.is_closed() {
653            return Err(std::io::Error::new(
654                std::io::ErrorKind::NotConnected,
655                "Not connected",
656            ));
657        }
658
659        let timeout = Duration::from_secs(2);
660        let check_interval = Duration::from_millis(1);
661
662        if !self.is_active() {
663            tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
664            match tokio::time::timeout(timeout, async {
665                while !self.is_active() {
666                    if matches!(
667                        self.connection_mode(),
668                        ConnectionMode::Disconnect | ConnectionMode::Closed
669                    ) {
670                        return Err("Client disconnected waiting to send");
671                    }
672
673                    tokio::time::sleep(check_interval).await;
674                }
675
676                Ok(())
677            })
678            .await
679            {
680                Ok(Ok(())) => tracing::debug!("Client now active"),
681                Ok(Err(e)) => {
682                    tracing::error!(
683                        "Failed to send data ({}): {e}",
684                        String::from_utf8_lossy(&data)
685                    );
686                    return Ok(());
687                }
688                Err(_) => {
689                    tracing::error!(
690                        "Failed to send data ({}): timeout waiting to become ACTIVE",
691                        String::from_utf8_lossy(&data)
692                    );
693                    return Ok(());
694                }
695            }
696        }
697
698        let msg = WriterCommand::Send(data.into());
699        if let Err(e) = self.writer_tx.send(msg) {
700            tracing::error!("{e}");
701        }
702        Ok(())
703    }
704
705    fn spawn_controller_task(
706        mut inner: SocketClientInner,
707        connection_mode: Arc<AtomicU8>,
708        post_reconnection: Option<PyObject>,
709        post_disconnection: Option<PyObject>,
710    ) -> tokio::task::JoinHandle<()> {
711        tokio::task::spawn(async move {
712            tracing::debug!("Started task 'controller'");
713
714            let check_interval = Duration::from_millis(10);
715
716            loop {
717                tokio::time::sleep(check_interval).await;
718                let mode = ConnectionMode::from_atomic(&connection_mode);
719
720                if mode.is_disconnect() {
721                    tracing::debug!("Disconnecting");
722
723                    let timeout = Duration::from_secs(5);
724                    if tokio::time::timeout(timeout, async {
725                        // Delay awaiting graceful shutdown
726                        tokio::time::sleep(Duration::from_millis(100)).await;
727
728                        if !inner.read_task.is_finished() {
729                            inner.read_task.abort();
730                            tracing::debug!("Aborted task 'read'");
731                        }
732
733                        if let Some(task) = &inner.heartbeat_task {
734                            if !task.is_finished() {
735                                task.abort();
736                                tracing::debug!("Aborted task 'heartbeat'");
737                            }
738                        }
739                    })
740                    .await
741                    .is_err()
742                    {
743                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
744                    }
745
746                    tracing::debug!("Closed");
747
748                    if let Some(ref handler) = post_disconnection {
749                        Python::with_gil(|py| match handler.call0(py) {
750                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
751                            Err(e) => {
752                                tracing::error!("Error calling `post_disconnection` handler: {e}");
753                            }
754                        });
755                    }
756                    break; // Controller finished
757                }
758
759                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
760                    match inner.reconnect().await {
761                        Ok(()) => {
762                            tracing::debug!("Reconnected successfully");
763                            inner.backoff.reset();
764
765                            if let Some(ref handler) = post_reconnection {
766                                Python::with_gil(|py| match handler.call0(py) {
767                                    Ok(_) => {
768                                        tracing::debug!("Called `post_reconnection` handler");
769                                    }
770                                    Err(e) => tracing::error!(
771                                        "Error calling `post_reconnection` handler: {e}"
772                                    ),
773                                });
774                            }
775                        }
776                        Err(e) => {
777                            let duration = inner.backoff.next_duration();
778                            tracing::warn!("Reconnect attempt failed: {e}");
779                            if !duration.is_zero() {
780                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
781                            }
782                            tokio::time::sleep(duration).await;
783                        }
784                    }
785                }
786            }
787            inner
788                .connection_mode
789                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
790
791            tracing::debug!("Completed task 'controller'");
792        })
793    }
794}
795
796////////////////////////////////////////////////////////////////////////////////
797// Tests
798////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use nautilus_common::testing::wait_until_async;
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Compiles a small Python `Counter` class and returns the bound `handler` method of a
    /// module-level instance, for use as the client's message callback.
    ///
    /// The handler increments `count` for every `'ping'` message and sets `check` when it
    /// sees `'heartbeat message'`.
    fn create_handler() -> PyObject {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            // The bound method keeps its `counter` instance alive
            pymod
                .getattr("counter")
                .unwrap()
                .getattr("handler")
                .unwrap()
                .into_py_any_unwrap(py)
        })
    }

    /// Builds a plain (non-TLS) [`SocketConfig`] for `127.0.0.1:{port}` using a `\r\n` message
    /// suffix, the given Python `handler`, no heartbeat and no reconnect overrides.
    ///
    /// Tests that need different settings mutate the returned value.
    fn plain_config(port: u16, handler: PyObject) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(handler)),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    /// Removes and returns the first `\r\n`-terminated line from `buf` with the delimiter
    /// stripped, or `None` when no complete line has been buffered yet.
    fn take_line(buf: &mut Vec<u8>) -> Option<Vec<u8>> {
        let idx = buf.windows(2).position(|w| w == b"\r\n")?;
        let mut line: Vec<u8> = buf.drain(..idx + 2).collect();
        line.truncate(idx);
        Some(line)
    }

    /// Binds a listener on an ephemeral localhost port and returns `(port, listener)`.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes every `\r\n`-terminated line back to the peer.
    ///
    /// Returns when the peer closes the connection, a read or write fails, or the line
    /// `close` is received (in which case the socket is shut down first).
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => return,
                Ok(_) => {
                    while let Some(line) = take_line(&mut buf) {
                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        // A failed write means the peer is gone: stop serving entirely
                        // rather than going back to read from a broken socket.
                        if socket.write_all(&echo_data).await.is_err() {
                            return;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    return;
                }
            }
        }
    }

    /// Sends two messages to an echo server, then asks the server to close, and checks the
    /// client has not transitioned to closed.
    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = plain_config(port, create_handler());
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Wait a bit for the server to echo them back
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// Connecting to a port with no listener must return an error.
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        drop(listener); // We drop it immediately -> no server is listening

        let config = plain_config(port, create_handler());
        let client_res = SocketClient::connect(config, None, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    /// A user-initiated `close()` must leave the client in the closed state.
    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            // Hold the connection open until the test aborts this task
            std::future::pending::<()>().await;
        });

        let config = plain_config(port, create_handler());
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    /// With a 1-second heartbeat configured, the server must observe at least two `ping`
    /// lines within roughly three seconds.
    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let server_received = received.clone();

        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            // Await socket readiness instead of polling; stop on EOF or read error
            while matches!(socket.read_buf(&mut buf).await, Ok(n) if n > 0) {
                while let Some(line) = take_line(&mut buf) {
                    server_received.lock().await.push(line);
                }
            }
        });

        let mut config = plain_config(port, create_handler());
        // Heartbeat every 1 second
        config.heartbeat = Some((1, b"ping".to_vec()));

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        // Wait ~3 seconds to collect some heartbeats
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line.as_slice() == b"ping")
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    /// An exception raised inside the Python handler must not take the client out of the
    /// active state; a subsequent `close()` must still succeed.
    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let code_raw = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();

        let failing_handler = Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            pymod.getattr("handler").unwrap().into_py_any_unwrap(py)
        });

        let config = plain_config(port, failing_handler);
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello".into()).await.unwrap();
        sleep(Duration::from_millis(100)).await;

        // The echoed "ERR" makes the handler raise
        client.send_bytes(b"ERR".into()).await.unwrap();
        sleep(Duration::from_secs(1)).await;

        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    /// The server drops the first connection and later accepts a second one; the client is
    /// expected to be active and able to send afterwards.
    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;

        // Spawn a server task that:
        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
        // 2. Waits a bit and then accepts a new connection and runs the echo server
        let server_task = task::spawn(async move {
            // Accept first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Wait briefly and then force-close the connection
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            // Wait for the client's reconnect attempt
            sleep(Duration::from_millis(500)).await;

            // Run the echo server on the new connection
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let mut config = plain_config(port, create_handler());
        config.reconnect_timeout_ms = Some(5_000);
        config.reconnect_delay_initial_ms = Some(500);
        config.reconnect_delay_max_ms = Some(5_000);
        config.reconnect_backoff_factor = Some(2.0);
        config.reconnect_jitter_ms = Some(50);

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Block (up to 10s) until the client reports active.
        // NOTE(review): the client was just asserted active above, so this predicate may be
        // satisfied by the initial connection rather than by a completed reconnect; confirm
        // whether the test should first wait for the client to leave the active state.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}