nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    path::Path,
34    sync::{
35        Arc,
36        atomic::{AtomicU8, Ordering},
37    },
38    time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_core::CleanDrop;
43use nautilus_cryptography::providers::install_cryptographic_provider;
44use tokio::{
45    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
46    net::TcpStream,
47};
48use tokio_tungstenite::{
49    MaybeTlsStream,
50    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
51};
52
53use crate::{
54    backoff::ExponentialBackoff,
55    error::SendError,
56    fix::process_fix_buffer,
57    logging::{log_task_aborted, log_task_started, log_task_stopped},
58    mode::ConnectionMode,
59    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
60};
61
// Connection timing constants

/// Polling period (milliseconds) at which the background tasks and `close` re-check the shared
/// connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) inserted before old tasks are torn down on reconnect and disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on graceful shutdown steps (writer shutdown, `close`, controller stop).
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Maximum time (seconds) `send_bytes` waits for the client to become active.
const SEND_OPERATION_TIMEOUT_SECS: u64 = 2;
/// Polling period (milliseconds) used by `send_bytes` while waiting for the active state.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
68
69type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
70type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
71pub type TcpMessageHandler = Arc<dyn Fn(&[u8]) + Send + Sync>;
72
73/// Configuration for TCP socket connection.
74#[cfg_attr(
75    feature = "python",
76    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
77)]
78pub struct SocketConfig {
79    /// The URL to connect to.
80    pub url: String,
81    /// The connection mode {Plain, TLS}.
82    pub mode: Mode,
83    /// The sequence of bytes which separates lines.
84    pub suffix: Vec<u8>,
85    /// The optional function to handle incoming messages.
86    pub message_handler: Option<TcpMessageHandler>,
87    /// The optional heartbeat with period and beat message.
88    pub heartbeat: Option<(u64, Vec<u8>)>,
89    /// The timeout (milliseconds) for reconnection attempts.
90    pub reconnect_timeout_ms: Option<u64>,
91    /// The initial reconnection delay (milliseconds) for reconnects.
92    pub reconnect_delay_initial_ms: Option<u64>,
93    /// The maximum reconnect delay (milliseconds) for exponential backoff.
94    pub reconnect_delay_max_ms: Option<u64>,
95    /// The exponential backoff factor for reconnection delays.
96    pub reconnect_backoff_factor: Option<f64>,
97    /// The maximum jitter (milliseconds) added to reconnection delays.
98    pub reconnect_jitter_ms: Option<u64>,
99    /// The path to the certificates directory.
100    pub certs_dir: Option<String>,
101}
102
103impl Debug for SocketConfig {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("SocketConfig")
106            .field("url", &self.url)
107            .field("mode", &self.mode)
108            .field("suffix", &self.suffix)
109            .field(
110                "message_handler",
111                &self.message_handler.as_ref().map(|_| "<function>"),
112            )
113            .field("heartbeat", &self.heartbeat)
114            .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
115            .field(
116                "reconnect_delay_initial_ms",
117                &self.reconnect_delay_initial_ms,
118            )
119            .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
120            .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
121            .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
122            .field("certs_dir", &self.certs_dir)
123            .finish()
124    }
125}
126
127impl Clone for SocketConfig {
128    fn clone(&self) -> Self {
129        Self {
130            url: self.url.clone(),
131            mode: self.mode,
132            suffix: self.suffix.clone(),
133            message_handler: self.message_handler.clone(),
134            heartbeat: self.heartbeat.clone(),
135            reconnect_timeout_ms: self.reconnect_timeout_ms,
136            reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
137            reconnect_delay_max_ms: self.reconnect_delay_max_ms,
138            reconnect_backoff_factor: self.reconnect_backoff_factor,
139            reconnect_jitter_ms: self.reconnect_jitter_ms,
140            certs_dir: self.certs_dir.clone(),
141        }
142    }
143}
144
145/// Represents a command for the writer task.
146#[derive(Debug)]
147pub enum WriterCommand {
148    /// Update the writer reference with a new one after reconnection.
149    Update(TcpWriter),
150    /// Send data to the server.
151    Send(Bytes),
152}
153
154/// Creates a `TcpStream` with the server.
155///
156/// The stream can be encrypted with TLS or Plain. The stream is split into
157/// read and write ends:
158/// - The read end is passed to the task that keeps receiving
159///   messages from the server and passing them to a handler.
160/// - The write end is passed to a task which receives messages over a channel
161///   to send to the server.
162///
163/// The heartbeat is optional and can be configured with an interval and data to
164/// send.
165///
166/// The client uses a suffix to separate messages on the byte stream. It is
167/// appended to all sent messages and heartbeats. It is also used to split
168/// the received byte stream.
169#[cfg_attr(
170    feature = "python",
171    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
172)]
173struct SocketClientInner {
174    config: SocketConfig,
175    connector: Option<Connector>,
176    read_task: Arc<tokio::task::JoinHandle<()>>,
177    write_task: tokio::task::JoinHandle<()>,
178    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
179    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
180    connection_mode: Arc<AtomicU8>,
181    reconnect_timeout: Duration,
182    backoff: ExponentialBackoff,
183    handler: Option<TcpMessageHandler>,
184}
185
186impl SocketClientInner {
187    /// Connect to a URL with the specified configuration.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if connection fails or configuration is invalid.
192    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
193        install_cryptographic_provider();
194
195        let SocketConfig {
196            url,
197            mode,
198            heartbeat,
199            suffix,
200            message_handler,
201            reconnect_timeout_ms,
202            reconnect_delay_initial_ms,
203            reconnect_delay_max_ms,
204            reconnect_backoff_factor,
205            reconnect_jitter_ms,
206            certs_dir,
207        } = &config.clone();
208        let connector = if let Some(dir) = certs_dir {
209            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
210            Some(Connector::Rustls(Arc::new(config)))
211        } else {
212            None
213        };
214
215        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
216        tracing::debug!("Connected");
217
218        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
219
220        let read_task = Arc::new(Self::spawn_read_task(
221            connection_mode.clone(),
222            reader,
223            message_handler.clone(),
224            suffix.clone(),
225        ));
226
227        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
228
229        let write_task =
230            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
231
232        // Optionally spawn a heartbeat task to periodically ping server
233        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
234            Self::spawn_heartbeat_task(
235                connection_mode.clone(),
236                heartbeat.clone(),
237                writer_tx.clone(),
238            )
239        });
240
241        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
242        let backoff = ExponentialBackoff::new(
243            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
244            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
245            reconnect_backoff_factor.unwrap_or(1.5),
246            reconnect_jitter_ms.unwrap_or(100),
247            true, // immediate-first
248        )?;
249
250        Ok(Self {
251            config,
252            connector,
253            read_task,
254            write_task,
255            writer_tx,
256            heartbeat_task,
257            connection_mode,
258            reconnect_timeout,
259            backoff,
260            handler: message_handler.clone(),
261        })
262    }
263
264    /// Establish a TLS or plain TCP connection with the server.
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if the connection cannot be established.
269    pub async fn tls_connect_with_server(
270        url: &str,
271        mode: Mode,
272        connector: Option<Connector>,
273    ) -> Result<(TcpReader, TcpWriter), Error> {
274        tracing::debug!("Connecting to {url}");
275        let tcp_result = TcpStream::connect(url).await;
276
277        match tcp_result {
278            Ok(stream) => {
279                tracing::debug!("TCP connection established, proceeding with TLS");
280                let request = url.into_client_request()?;
281                tcp_tls(&request, mode, stream, connector)
282                    .await
283                    .map(tokio::io::split)
284            }
285            Err(e) => {
286                tracing::error!("TCP connection failed: {e:?}");
287                Err(Error::Io(e))
288            }
289        }
290    }
291
292    /// Reconnect with server.
293    ///
294    /// Makes a new connection with server, uses the new read and write halves
295    /// to update the reader and writer.
296    async fn reconnect(&mut self) -> Result<(), Error> {
297        tracing::debug!("Reconnecting");
298
299        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
300            tracing::debug!("Reconnect aborted due to disconnect state");
301            return Ok(());
302        }
303
304        tokio::time::timeout(self.reconnect_timeout, async {
305            let SocketConfig {
306                url,
307                mode,
308                heartbeat: _,
309                suffix,
310                message_handler: _,
311                reconnect_timeout_ms: _,
312                reconnect_delay_initial_ms: _,
313                reconnect_backoff_factor: _,
314                reconnect_delay_max_ms: _,
315                reconnect_jitter_ms: _,
316                certs_dir: _,
317            } = &self.config;
318            // Create a fresh connection
319            let connector = self.connector.clone();
320            // Attempt to connect; abort early if a disconnect was requested
321            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
322
323            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
324                tracing::debug!("Reconnect aborted mid-flight (after connect)");
325                return Ok(());
326            }
327            tracing::debug!("Connected");
328
329            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
330                tracing::error!("{e}");
331            }
332
333            // Delay before closing connection
334            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
335
336            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
337                tracing::debug!("Reconnect aborted mid-flight (after delay)");
338                return Ok(());
339            }
340
341            if !self.read_task.is_finished() {
342                self.read_task.abort();
343                log_task_aborted("read");
344            }
345
346            // If a disconnect was requested during reconnect, do not proceed to reactivate
347            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
348                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
349                return Ok(());
350            }
351
352            // Mark as active only if not disconnecting
353            self.connection_mode
354                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
355
356            // Spawn new read task
357            self.read_task = Arc::new(Self::spawn_read_task(
358                self.connection_mode.clone(),
359                reader,
360                self.handler.clone(),
361                suffix.clone(),
362            ));
363
364            tracing::debug!("Reconnect succeeded");
365            Ok(())
366        })
367        .await
368        .map_err(|_| {
369            Error::Io(std::io::Error::new(
370                std::io::ErrorKind::TimedOut,
371                format!(
372                    "reconnection timed out after {}s",
373                    self.reconnect_timeout.as_secs_f64()
374                ),
375            ))
376        })?
377    }
378
379    /// Check if the client is still alive.
380    ///
381    /// The client is connected if the read task has not finished. It is expected
382    /// that in case of any failure client or server side. The read task will be
383    /// shutdown. There might be some delay between the connection being closed
384    /// and the client detecting it.
385    #[inline]
386    #[must_use]
387    pub fn is_alive(&self) -> bool {
388        !self.read_task.is_finished()
389    }
390
391    #[must_use]
392    fn spawn_read_task(
393        connection_state: Arc<AtomicU8>,
394        mut reader: TcpReader,
395        handler: Option<TcpMessageHandler>,
396        suffix: Vec<u8>,
397    ) -> tokio::task::JoinHandle<()> {
398        log_task_started("read");
399
400        // Interval between checking the connection mode
401        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
402
403        tokio::task::spawn(async move {
404            let mut buf = Vec::new();
405
406            loop {
407                if !ConnectionMode::from_atomic(&connection_state).is_active() {
408                    break;
409                }
410
411                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
412                    // Connection has been terminated or vector buffer is complete
413                    Ok(Ok(0)) => {
414                        tracing::debug!("Connection closed by server");
415                        break;
416                    }
417                    Ok(Err(e)) => {
418                        tracing::debug!("Connection ended: {e}");
419                        break;
420                    }
421                    // Received bytes of data
422                    Ok(Ok(bytes)) => {
423                        tracing::trace!("Received <binary> {bytes} bytes");
424
425                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
426                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
427
428                        if is_fix && handler.is_some() {
429                            // FIX protocol processing
430                            if let Some(ref handler) = handler {
431                                process_fix_buffer(&mut buf, handler);
432                            }
433                        } else {
434                            // Regular suffix-based message processing
435                            while let Some((i, _)) = &buf
436                                .windows(suffix.len())
437                                .enumerate()
438                                .find(|(_, pair)| pair.eq(&suffix))
439                            {
440                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
441                                data.truncate(data.len() - suffix.len());
442
443                                if let Some(ref handler) = handler {
444                                    handler(&data);
445                                }
446                            }
447                        }
448                    }
449                    Err(_) => {
450                        // Timeout - continue loop and check connection mode
451                        continue;
452                    }
453                }
454            }
455
456            log_task_stopped("read");
457        })
458    }
459
460    fn spawn_write_task(
461        connection_state: Arc<AtomicU8>,
462        writer: TcpWriter,
463        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
464        suffix: Vec<u8>,
465    ) -> tokio::task::JoinHandle<()> {
466        log_task_started("write");
467
468        // Interval between checking the connection mode
469        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
470
471        tokio::task::spawn(async move {
472            let mut active_writer = writer;
473
474            loop {
475                if matches!(
476                    ConnectionMode::from_atomic(&connection_state),
477                    ConnectionMode::Disconnect | ConnectionMode::Closed
478                ) {
479                    break;
480                }
481
482                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
483                    Ok(Some(msg)) => {
484                        // Re-check connection mode after receiving a message
485                        let mode = ConnectionMode::from_atomic(&connection_state);
486                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
487                            break;
488                        }
489
490                        match msg {
491                            WriterCommand::Update(new_writer) => {
492                                tracing::debug!("Received new writer");
493
494                                // Delay before closing connection
495                                tokio::time::sleep(Duration::from_millis(100)).await;
496
497                                // Attempt to shutdown the writer gracefully before updating,
498                                // we ignore any error as the writer may already be closed.
499                                _ = tokio::time::timeout(
500                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
501                                    active_writer.shutdown(),
502                                )
503                                .await;
504
505                                active_writer = new_writer;
506                                tracing::debug!("Updated writer");
507                            }
508                            _ if mode.is_reconnect() => {
509                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
510                                continue;
511                            }
512                            WriterCommand::Send(msg) => {
513                                if let Err(e) = active_writer.write_all(&msg).await {
514                                    tracing::error!("Failed to send message: {e}");
515                                    // Mode is active so trigger reconnection
516                                    tracing::warn!("Writer triggering reconnect");
517                                    connection_state
518                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
519                                    continue;
520                                }
521                                if let Err(e) = active_writer.write_all(&suffix).await {
522                                    tracing::error!("Failed to send message: {e}");
523                                }
524                            }
525                        }
526                    }
527                    Ok(None) => {
528                        // Channel closed - writer task should terminate
529                        tracing::debug!("Writer channel closed, terminating writer task");
530                        break;
531                    }
532                    Err(_) => {
533                        // Timeout - just continue the loop
534                        continue;
535                    }
536                }
537            }
538
539            // Attempt to shutdown the writer gracefully before exiting,
540            // we ignore any error as the writer may already be closed.
541            _ = tokio::time::timeout(
542                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
543                active_writer.shutdown(),
544            )
545            .await;
546
547            log_task_stopped("write");
548        })
549    }
550
551    fn spawn_heartbeat_task(
552        connection_state: Arc<AtomicU8>,
553        heartbeat: (u64, Vec<u8>),
554        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
555    ) -> tokio::task::JoinHandle<()> {
556        log_task_started("heartbeat");
557        let (interval_secs, message) = heartbeat;
558
559        tokio::task::spawn(async move {
560            let interval = Duration::from_secs(interval_secs);
561
562            loop {
563                tokio::time::sleep(interval).await;
564
565                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
566                    ConnectionMode::Active => {
567                        let msg = WriterCommand::Send(message.clone().into());
568
569                        match writer_tx.send(msg) {
570                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
571                            Err(e) => {
572                                tracing::error!("Failed to send heartbeat to writer task: {e}");
573                            }
574                        }
575                    }
576                    ConnectionMode::Reconnect => continue,
577                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
578                }
579            }
580
581            log_task_stopped("heartbeat");
582        })
583    }
584}
585
586impl Drop for SocketClientInner {
587    fn drop(&mut self) {
588        // Delegate to explicit cleanup handler
589        self.clean_drop();
590    }
591}
592
593impl CleanDrop for SocketClientInner {
594    fn clean_drop(&mut self) {
595        if !self.read_task.is_finished() {
596            self.read_task.abort();
597            log_task_aborted("read");
598        }
599
600        if !self.write_task.is_finished() {
601            self.write_task.abort();
602            log_task_aborted("write");
603        }
604
605        if let Some(ref handle) = self.heartbeat_task.take()
606            && !handle.is_finished()
607        {
608            handle.abort();
609            log_task_aborted("heartbeat");
610        }
611
612        #[cfg(feature = "python")]
613        {
614            // Remove stored handler to break ref cycle
615            self.config.message_handler = None;
616        }
617    }
618}
619
620#[cfg_attr(
621    feature = "python",
622    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
623)]
624pub struct SocketClient {
625    pub(crate) controller_task: tokio::task::JoinHandle<()>,
626    pub(crate) connection_mode: Arc<AtomicU8>,
627    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
628}
629
630impl Debug for SocketClient {
631    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632        f.debug_struct(stringify!(SocketClient)).finish()
633    }
634}
635
636impl SocketClient {
637    /// Connect to the server.
638    ///
639    /// # Errors
640    ///
641    /// Returns any error connecting to the server.
642    pub async fn connect(
643        config: SocketConfig,
644        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
645        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
646        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
647    ) -> anyhow::Result<Self> {
648        let inner = SocketClientInner::connect_url(config).await?;
649        let writer_tx = inner.writer_tx.clone();
650        let connection_mode = inner.connection_mode.clone();
651
652        let controller_task = Self::spawn_controller_task(
653            inner,
654            connection_mode.clone(),
655            post_reconnection,
656            post_disconnection,
657        );
658
659        if let Some(handler) = post_connection {
660            handler();
661            tracing::debug!("Called `post_connection` handler");
662        }
663
664        Ok(Self {
665            controller_task,
666            connection_mode,
667            writer_tx,
668        })
669    }
670
671    /// Returns the current connection mode.
672    #[must_use]
673    pub fn connection_mode(&self) -> ConnectionMode {
674        ConnectionMode::from_atomic(&self.connection_mode)
675    }
676
677    /// Check if the client connection is active.
678    ///
679    /// Returns `true` if the client is connected and has not been signalled to disconnect.
680    /// The client will automatically retry connection based on its configuration.
681    #[inline]
682    #[must_use]
683    pub fn is_active(&self) -> bool {
684        self.connection_mode().is_active()
685    }
686
687    /// Check if the client is reconnecting.
688    ///
689    /// Returns `true` if the client lost connection and is attempting to reestablish it.
690    /// The client will automatically retry connection based on its configuration.
691    #[inline]
692    #[must_use]
693    pub fn is_reconnecting(&self) -> bool {
694        self.connection_mode().is_reconnect()
695    }
696
697    /// Check if the client is disconnecting.
698    ///
699    /// Returns `true` if the client is in disconnect mode.
700    #[inline]
701    #[must_use]
702    pub fn is_disconnecting(&self) -> bool {
703        self.connection_mode().is_disconnect()
704    }
705
706    /// Check if the client is closed.
707    ///
708    /// Returns `true` if the client has been explicitly disconnected or reached
709    /// maximum reconnection attempts. In this state, the client cannot be reused
710    /// and a new client must be created for further connections.
711    #[inline]
712    #[must_use]
713    pub fn is_closed(&self) -> bool {
714        self.connection_mode().is_closed()
715    }
716
717    /// Close the client.
718    ///
719    /// Controller task will periodically check the disconnect mode
720    /// and shutdown the client if it is not alive.
721    pub async fn close(&self) {
722        self.connection_mode
723            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
724
725        match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
726            while !self.is_closed() {
727                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
728            }
729
730            if !self.controller_task.is_finished() {
731                self.controller_task.abort();
732                log_task_aborted("controller");
733            }
734        })
735        .await
736        {
737            Ok(()) => {
738                log_task_stopped("controller");
739            }
740            Err(_) => {
741                tracing::error!("Timeout waiting for controller task to finish");
742            }
743        }
744    }
745
746    /// Sends a message of the given `data`.
747    ///
748    /// # Errors
749    ///
750    /// Returns an error if sending fails.
751    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
752        if self.is_closed() {
753            return Err(SendError::Closed);
754        }
755
756        let timeout = Duration::from_secs(SEND_OPERATION_TIMEOUT_SECS);
757        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
758
759        if !self.is_active() {
760            tracing::debug!("Waiting for client to become ACTIVE before sending...");
761
762            let inner = tokio::time::timeout(timeout, async {
763                loop {
764                    if self.is_active() {
765                        return Ok(());
766                    }
767                    if matches!(
768                        self.connection_mode(),
769                        ConnectionMode::Disconnect | ConnectionMode::Closed
770                    ) {
771                        return Err(());
772                    }
773                    tokio::time::sleep(check_interval).await;
774                }
775            })
776            .await
777            .map_err(|_| SendError::Timeout)?;
778            inner.map_err(|()| SendError::Closed)?;
779        }
780
781        let msg = WriterCommand::Send(data.into());
782        self.writer_tx
783            .send(msg)
784            .map_err(|e| SendError::BrokenPipe(e.to_string()))
785    }
786
787    fn spawn_controller_task(
788        mut inner: SocketClientInner,
789        connection_mode: Arc<AtomicU8>,
790        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
791        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
792    ) -> tokio::task::JoinHandle<()> {
793        tokio::task::spawn(async move {
794            log_task_started("controller");
795
796            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
797
798            loop {
799                tokio::time::sleep(check_interval).await;
800                let mode = ConnectionMode::from_atomic(&connection_mode);
801
802                if mode.is_disconnect() {
803                    tracing::debug!("Disconnecting");
804
805                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
806                    if tokio::time::timeout(timeout, async {
807                        // Delay awaiting graceful shutdown
808                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
809
810                        if !inner.read_task.is_finished() {
811                            inner.read_task.abort();
812                            log_task_aborted("read");
813                        }
814
815                        if let Some(task) = &inner.heartbeat_task
816                            && !task.is_finished()
817                        {
818                            task.abort();
819                            log_task_aborted("heartbeat");
820                        }
821                    })
822                    .await
823                    .is_err()
824                    {
825                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
826                    }
827
828                    tracing::debug!("Closed");
829
830                    if let Some(ref handler) = post_disconnection {
831                        handler();
832                        tracing::debug!("Called `post_disconnection` handler");
833                    }
834                    break; // Controller finished
835                }
836
837                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
838                    match inner.reconnect().await {
839                        Ok(()) => {
840                            tracing::debug!("Reconnected successfully");
841                            inner.backoff.reset();
842                            // Only invoke reconnect handler if still active
843                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
844                                if let Some(ref handler) = post_reconnection {
845                                    handler();
846                                    tracing::debug!("Called `post_reconnection` handler");
847                                }
848                            } else {
849                                tracing::debug!(
850                                    "Skipping post_reconnection handlers due to disconnect state"
851                                );
852                            }
853                        }
854                        Err(e) => {
855                            let duration = inner.backoff.next_duration();
856                            tracing::warn!("Reconnect attempt failed: {e}");
857                            if !duration.is_zero() {
858                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
859                            }
860                            tokio::time::sleep(duration).await;
861                        }
862                    }
863                }
864            }
865            inner
866                .connection_mode
867                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
868
869            log_task_stopped("controller");
870        })
871    }
872}
873
874// Abort controller task on drop to clean up background tasks
875impl Drop for SocketClient {
876    fn drop(&mut self) {
877        if !self.controller_task.is_finished() {
878            self.controller_task.abort();
879            log_task_aborted("controller");
880        }
881    }
882}
883
884////////////////////////////////////////////////////////////////////////////////
885// Tests
886////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an ephemeral localhost port and returns `(port, listener)`.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let addr = listener.local_addr().unwrap();
        (addr.port(), listener)
    }

    /// Builds a plain-TCP config targeting `127.0.0.1:{port}` with a `\r\n` suffix and
    /// every optional setting left as `None`.
    fn plain_config(port: u16) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    /// Echoes every `\r\n`-terminated line back to the peer.
    ///
    /// A line equal to `close` shuts the socket down and ends the server. The server also
    /// ends on EOF or on a read error.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut pending = Vec::new();
        loop {
            match socket.read_buf(&mut pending).await {
                Ok(0) => break,
                Ok(_) => {
                    while let Some(idx) = pending.windows(2).position(|w| w == b"\r\n") {
                        // `frame` is the line including its trailing \r\n
                        let frame: Vec<u8> = pending.drain(..idx + 2).collect();
                        let payload = &frame[..idx];

                        if payload == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        // Echoing the frame as-is is the payload followed by \r\n
                        if socket.write_all(&frame).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    /// Sends a few messages to an echo server and checks the client is not closed after
    /// the server has shut the connection down.
    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let client = SocketClient::connect(plain_config(port), None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        for payload in [b"Hello", b"World"] {
            client.send_bytes(payload.into()).await.unwrap();
        }

        // Give the server a moment to echo the messages back
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// Connecting to a port with no listener must return an error.
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        drop(listener); // Nothing is listening on the port from here on

        let client_res = SocketClient::connect(plain_config(port), None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    /// Closing the client explicitly must leave it in the closed state.
    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut scratch = [0u8; 1024];
            let _ = socket.try_read(&mut scratch);

            // Keep the connection open until the test aborts this task
            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let client = SocketClient::connect(plain_config(port), None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    /// With a 1-second heartbeat configured, the server must see at least two `ping`
    /// lines within roughly three seconds.
    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut pending = Vec::new();
            loop {
                match socket.try_read_buf(&mut pending) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = pending.windows(2).position(|w| w == b"\r\n") {
                            let mut line: Vec<u8> = pending.drain(..idx + 2).collect();
                            // Drop the trailing \r\n
                            line.truncate(idx);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        // Nothing readable yet (or a read error); retry shortly
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat every 1 second
        let mut config = plain_config(port);
        config.heartbeat = Some((1, b"ping".to_vec()));

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait ~3 seconds to collect some heartbeats
        sleep(Duration::from_secs(3)).await;

        {
            let lines = received.lock().await;
            let pings = lines
                .iter()
                .filter(|line| line.as_slice() == b"ping")
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    /// The server drops the first connection and later accepts a second one; the client
    /// must report active and accept a send afterwards.
    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;

        // Server script:
        // 1. Accept the first connection, then shut it down after a short delay
        // 2. Pause, then accept a second connection and serve it as an echo server
        let server_task = task::spawn(async move {
            let (mut first, _) = listener.accept().await.expect("First accept failed");

            // Simulate a dropped connection
            sleep(Duration::from_millis(500)).await;
            let _ = first.shutdown().await;

            // Leave time for the client to attempt its reconnect
            sleep(Duration::from_millis(500)).await;

            let (second, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(second).await;
        });

        let mut config = plain_config(port);
        config.reconnect_timeout_ms = Some(5_000);
        config.reconnect_delay_initial_ms = Some(500);
        config.reconnect_delay_max_ms = Some(5_000);
        config.reconnect_backoff_factor = Some(2.0);
        config.reconnect_jitter_ms = Some(50);

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Wait (up to 10s) for the client to report active.
        // NOTE(review): the client was asserted active just above, so this wait may return
        // before the disconnect/reconnect cycle has happened - confirm this is intended.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1177
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// The server drops the only connection it accepts; closing the client shortly
    /// afterwards must still leave it in the closed state.
    #[tokio::test]
    async fn test_reconnect_then_close() {
        // Bind an ephemeral port
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Server task: accept one connection and then drop it
        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // The shutdown future must be awaited: futures are lazy, so merely
                // dropping it would never perform the shutdown
                let _ = sock.shutdown().await;
            }
            // Keep listener alive briefly to avoid premature exit
            sleep(Duration::from_secs(1)).await;
        });

        // Configure client with a short reconnect backoff
        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            certs_dir: None,
        };

        // Connect client with no handlers; the call is the same with or without the
        // `python` feature, so no cfg split is needed
        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Allow server to drop connection and client to notice
        sleep(Duration::from_millis(100)).await;

        // Now close the client
        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }
}