nautilus_network/socket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18//!
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    collections::VecDeque,
33    fmt::Debug,
34    path::Path,
35    sync::{
36        Arc,
37        atomic::{AtomicU8, Ordering},
38    },
39    time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49    SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52    backoff::ExponentialBackoff,
53    error::SendError,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
// Connection timing constants

/// Interval (ms) at which the reader and writer tasks time out of a pending read/receive
/// to re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Settling delay (ms) applied during reconnection before the previous connection's
/// read task is torn down.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on waiting for a writer half to shut down gracefully.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Polling interval (ms) for send operations.
/// NOTE(review): not referenced in this part of the file; presumably used by the
/// `SocketClient` send path — confirm.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
65
/// Inner state of a raw TCP socket client: owns the connection tasks and reconnect settings.
///
/// The client connects to the server with a `TcpStream`, which can be encrypted with TLS
/// or left plain. The stream is split into read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream, except for buffers starting with `8=FIX` when a
/// handler is set (those are framed by the FIX processor).
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    // Configuration the client was created with; `reconnect` reads url/mode/suffix from it.
    config: SocketConfig,
    // Optional TLS connector built from `certs_dir`; cloned for every (re)connection.
    connector: Option<Connector>,
    // Reader task for the current connection; replaced with a fresh task on each reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    // Long-lived writer task; it is not respawned on reconnect but receives the new write
    // half through `WriterCommand::Update`.
    write_task: tokio::task::JoinHandle<()>,
    // Sender side of the writer task's command channel.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    // Heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    // Shared `ConnectionMode` (stored as its `u8` value) observed by all tasks.
    connection_mode: Arc<AtomicU8>,
    // Upper bound on a single `reconnect` call.
    reconnect_timeout: Duration,
    // Backoff between reconnect attempts.
    // NOTE(review): not used in this part of the file; presumably driven by the controller task — confirm.
    backoff: ExponentialBackoff,
    // Message handler passed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
    // Optional cap on reconnect attempts, copied from the config.
    // NOTE(review): enforcement is not visible here — confirm where it is checked.
    reconnect_max_attempts: Option<u32>,
    // Reconnect attempts made so far; initialised to 0.
    // NOTE(review): updated outside this view — confirm.
    reconnect_attempt_count: u32,
}
99
100impl SocketClientInner {
101    /// Connect to a URL with the specified configuration.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if connection fails or configuration is invalid.
106    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
107        install_cryptographic_provider();
108
109        let SocketConfig {
110            url,
111            mode,
112            heartbeat,
113            suffix,
114            message_handler,
115            reconnect_timeout_ms,
116            reconnect_delay_initial_ms,
117            reconnect_delay_max_ms,
118            reconnect_backoff_factor,
119            reconnect_jitter_ms,
120            connection_max_retries,
121            reconnect_max_attempts,
122            certs_dir,
123        } = &config.clone();
124        let connector = if let Some(dir) = certs_dir {
125            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
126            Some(Connector::Rustls(Arc::new(config)))
127        } else {
128            None
129        };
130
131        // Retry initial connection with exponential backoff to handle transient DNS/network issues
132        const CONNECTION_TIMEOUT_SECS: u64 = 10;
133        let max_retries = connection_max_retries.unwrap_or(5);
134
135        let mut backoff = ExponentialBackoff::new(
136            Duration::from_millis(500),
137            Duration::from_millis(5000),
138            2.0,
139            250,
140            false,
141        )?;
142
143        #[allow(unused_assignments)]
144        let mut last_error = String::new();
145        let mut attempt = 0;
146        let (reader, writer) = loop {
147            attempt += 1;
148
149            match tokio::time::timeout(
150                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
151                Self::tls_connect_with_server(url, *mode, connector.clone()),
152            )
153            .await
154            {
155                Ok(Ok(result)) => {
156                    if attempt > 1 {
157                        tracing::info!("Socket connection established after {attempt} attempts");
158                    }
159                    break result;
160                }
161                Ok(Err(e)) => {
162                    last_error = e.to_string();
163                    tracing::warn!(
164                        attempt,
165                        max_retries,
166                        url = %url,
167                        error = %last_error,
168                        "Socket connection attempt failed"
169                    );
170                }
171                Err(_) => {
172                    last_error = format!(
173                        "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
174                    );
175                    tracing::warn!(
176                        attempt,
177                        max_retries,
178                        url = %url,
179                        "Socket connection attempt timed out"
180                    );
181                }
182            }
183
184            if attempt >= max_retries {
185                anyhow::bail!(
186                    "Failed to connect to {} after {} attempts: {}. \
187                    If this is a DNS error, check your network configuration and DNS settings.",
188                    url,
189                    max_retries,
190                    if last_error.is_empty() {
191                        "unknown error"
192                    } else {
193                        &last_error
194                    }
195                );
196            }
197
198            let delay = backoff.next_duration();
199            tracing::debug!(
200                "Retrying in {delay:?} (attempt {}/{})",
201                attempt + 1,
202                max_retries
203            );
204            tokio::time::sleep(delay).await;
205        };
206
207        tracing::debug!("Connected");
208
209        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
210
211        let read_task = Arc::new(Self::spawn_read_task(
212            connection_mode.clone(),
213            reader,
214            message_handler.clone(),
215            suffix.clone(),
216        ));
217
218        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219
220        let write_task =
221            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
222
223        // Optionally spawn a heartbeat task to periodically ping server
224        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
225            Self::spawn_heartbeat_task(
226                connection_mode.clone(),
227                heartbeat.clone(),
228                writer_tx.clone(),
229            )
230        });
231
232        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
233        let backoff = ExponentialBackoff::new(
234            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
235            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
236            reconnect_backoff_factor.unwrap_or(1.5),
237            reconnect_jitter_ms.unwrap_or(100),
238            true, // immediate-first
239        )?;
240
241        Ok(Self {
242            config,
243            connector,
244            read_task,
245            write_task,
246            writer_tx,
247            heartbeat_task,
248            connection_mode,
249            reconnect_timeout,
250            backoff,
251            handler: message_handler.clone(),
252            reconnect_max_attempts: *reconnect_max_attempts,
253            reconnect_attempt_count: 0,
254        })
255    }
256
257    /// Parse URL and extract socket address and request URL.
258    ///
259    /// Accepts either:
260    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
261    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the URL is invalid or missing required components.
266    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
267        if url.contains("://") {
268            // URL with scheme (e.g., "wss://host:port/path")
269            let parsed = url.parse::<http::Uri>().map_err(|e| {
270                Error::Io(std::io::Error::new(
271                    std::io::ErrorKind::InvalidInput,
272                    format!("Invalid URL: {e}"),
273                ))
274            })?;
275
276            let host = parsed.host().ok_or_else(|| {
277                Error::Io(std::io::Error::new(
278                    std::io::ErrorKind::InvalidInput,
279                    "URL missing host",
280                ))
281            })?;
282
283            let port = parsed
284                .port_u16()
285                .unwrap_or_else(|| match parsed.scheme_str() {
286                    Some("wss" | "https") => 443,
287                    Some("ws" | "http") => 80,
288                    _ => match mode {
289                        Mode::Tls => 443,
290                        Mode::Plain => 80,
291                    },
292                });
293
294            Ok((format!("{host}:{port}"), url.to_string()))
295        } else {
296            // Raw socket address (e.g., "host:port")
297            // Construct a proper URL for the request based on mode
298            let scheme = match mode {
299                Mode::Tls => "wss",
300                Mode::Plain => "ws",
301            };
302            Ok((url.to_string(), format!("{scheme}://{url}")))
303        }
304    }
305
306    /// Establish a TLS or plain TCP connection with the server.
307    ///
308    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
309    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
310    /// For WebSocket-style connections, include the scheme.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the connection cannot be established.
315    pub async fn tls_connect_with_server(
316        url: &str,
317        mode: Mode,
318        connector: Option<Connector>,
319    ) -> Result<(TcpReader, TcpWriter), Error> {
320        tracing::debug!("Connecting to {url}");
321
322        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
323        let tcp_result = TcpStream::connect(&socket_addr).await;
324
325        match tcp_result {
326            Ok(stream) => {
327                tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
328                if let Err(e) = stream.set_nodelay(true) {
329                    tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
330                }
331                let request = request_url.into_client_request()?;
332                tcp_tls(&request, mode, stream, connector)
333                    .await
334                    .map(tokio::io::split)
335            }
336            Err(e) => {
337                tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
338                Err(Error::Io(e))
339            }
340        }
341    }
342
    /// Reconnects to the server, replacing the current connection.
    ///
    /// Opens a new connection using the stored URL, mode and TLS connector, then:
    /// 1. Sends the new write half to the writer task (`WriterCommand::Update`) and waits for
    ///    it to confirm that every message buffered during the reconnect was sent.
    /// 2. Waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the old read task if it is still running,
    ///    and atomically moves the state from RECONNECT to ACTIVE.
    /// 3. Spawns a new read task on the new read half.
    ///
    /// The whole sequence is bounded by `reconnect_timeout`. If `is_disconnect()` reports
    /// true at any checkpoint, the method returns `Ok(())` without completing the reconnect.
    /// It does the same if the state is no longer RECONNECT at the final transition; in that
    /// case the old read task has already been aborted and no new one is spawned.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the connection cannot be established;
    /// - the update command cannot be delivered to the writer task;
    /// - the writer reports a failed buffer drain, or drops the response channel;
    /// - the overall timeout elapses.
    async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Only the connection target and framing suffix are needed here
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                connection_max_retries: _,
                reconnect_max_attempts: _,
                certs_dir: _,
            } = &self.config;
            // Create a fresh connection
            let connector = self.connector.clone();
            // Attempt to connect; abort early if a disconnect was requested
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            tracing::debug!("Connected");

            // Use a oneshot channel to synchronize with the writer task.
            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
            // to prevent silent message loss if the new connection drops immediately.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                tracing::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it has drained the buffer
            match rx.await {
                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                    // Return error to trigger retry logic in controller
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    // The writer task dropped the sender without replying (e.g. it exited)
                    tracing::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Settling delay before tearing down the old connection's read task
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The old reader belongs to the previous connection and must not keep running
            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn new read task (only after ACTIVE, since the reader exits when not ACTIVE)
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
            ));

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
464
465    /// Check if the client is still alive.
466    ///
467    /// The client is connected if the read task has not finished. It is expected
468    /// that in case of any failure client or server side. The read task will be
469    /// shutdown. There might be some delay between the connection being closed
470    /// and the client detecting it.
471    #[inline]
472    #[must_use]
473    pub fn is_alive(&self) -> bool {
474        !self.read_task.is_finished()
475    }
476
477    #[must_use]
478    fn spawn_read_task(
479        connection_state: Arc<AtomicU8>,
480        mut reader: TcpReader,
481        handler: Option<TcpMessageHandler>,
482        suffix: Vec<u8>,
483    ) -> tokio::task::JoinHandle<()> {
484        log_task_started("read");
485
486        // Interval between checking the connection mode
487        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
488
489        tokio::task::spawn(async move {
490            let mut buf = Vec::new();
491
492            loop {
493                if !ConnectionMode::from_atomic(&connection_state).is_active() {
494                    break;
495                }
496
497                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
498                    // Connection has been terminated or vector buffer is complete
499                    Ok(Ok(0)) => {
500                        tracing::debug!("Connection closed by server");
501                        break;
502                    }
503                    Ok(Err(e)) => {
504                        tracing::debug!("Connection ended: {e}");
505                        break;
506                    }
507                    // Received bytes of data
508                    Ok(Ok(bytes)) => {
509                        tracing::trace!("Received <binary> {bytes} bytes");
510
511                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
512                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
513
514                        if is_fix && handler.is_some() {
515                            // FIX protocol processing
516                            if let Some(ref handler) = handler {
517                                process_fix_buffer(&mut buf, handler);
518                            }
519                        } else {
520                            // Regular suffix-based message processing
521                            while let Some((i, _)) = &buf
522                                .windows(suffix.len())
523                                .enumerate()
524                                .find(|(_, pair)| pair.eq(&suffix))
525                            {
526                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
527                                data.truncate(data.len() - suffix.len());
528
529                                if let Some(ref handler) = handler {
530                                    handler(&data);
531                                }
532                            }
533                        }
534                    }
535                    Err(_) => {
536                        // Timeout - continue loop and check connection mode
537                        continue;
538                    }
539                }
540            }
541
542            log_task_stopped("read");
543        })
544    }
545
546    /// Drains buffered messages after reconnection completes.
547    ///
548    /// Attempts to send all buffered messages that were queued during reconnection.
549    /// Uses a peek-and-pop pattern to preserve messages if sending fails midway through the buffer.
550    ///
551    /// # Returns
552    ///
553    /// Returns `true` if a send error occurred (buffer may still contain unsent messages),
554    /// `false` if all messages were sent successfully (buffer is empty).
555    async fn drain_reconnect_buffer(
556        buffer: &mut VecDeque<Bytes>,
557        writer: &mut TcpWriter,
558        suffix: &[u8],
559    ) -> bool {
560        if buffer.is_empty() {
561            return false;
562        }
563
564        let initial_buffer_len = buffer.len();
565        tracing::info!(
566            "Sending {} buffered messages after reconnection",
567            initial_buffer_len
568        );
569
570        let mut send_error_occurred = false;
571
572        while let Some(buffered_msg) = buffer.front() {
573            let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
574            combined_msg.extend_from_slice(buffered_msg);
575            combined_msg.extend_from_slice(suffix);
576
577            if let Err(e) = writer.write_all(&combined_msg).await {
578                tracing::error!(
579                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
580                    buffer.len()
581                );
582                send_error_occurred = true;
583                break;
584            }
585
586            buffer.pop_front();
587        }
588
589        if buffer.is_empty() {
590            tracing::info!(
591                "Successfully sent all {} buffered messages",
592                initial_buffer_len
593            );
594        }
595
596        send_error_occurred
597    }
598
599    fn spawn_write_task(
600        connection_state: Arc<AtomicU8>,
601        writer: TcpWriter,
602        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
603        suffix: Vec<u8>,
604    ) -> tokio::task::JoinHandle<()> {
605        log_task_started("write");
606
607        // Interval between checking the connection mode
608        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
609
610        tokio::task::spawn(async move {
611            let mut active_writer = writer;
612            let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
613
614            loop {
615                if matches!(
616                    ConnectionMode::from_atomic(&connection_state),
617                    ConnectionMode::Disconnect | ConnectionMode::Closed
618                ) {
619                    break;
620                }
621
622                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
623                    Ok(Some(msg)) => {
624                        // Re-check connection mode after receiving a message
625                        let mode = ConnectionMode::from_atomic(&connection_state);
626                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
627                            break;
628                        }
629
630                        match msg {
631                            WriterCommand::Update(new_writer, tx) => {
632                                tracing::debug!("Received new writer");
633
634                                // Delay before closing connection
635                                tokio::time::sleep(Duration::from_millis(100)).await;
636
637                                // Attempt to shutdown the writer gracefully before updating,
638                                // we ignore any error as the writer may already be closed.
639                                _ = tokio::time::timeout(
640                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
641                                    active_writer.shutdown(),
642                                )
643                                .await;
644
645                                active_writer = new_writer;
646                                tracing::debug!("Updated writer");
647
648                                let send_error = Self::drain_reconnect_buffer(
649                                    &mut reconnect_buffer,
650                                    &mut active_writer,
651                                    &suffix,
652                                )
653                                .await;
654
655                                if let Err(e) = tx.send(!send_error) {
656                                    tracing::error!(
657                                        "Failed to report drain status to controller: {e:?}"
658                                    );
659                                }
660                            }
661                            _ if mode.is_reconnect() => {
662                                if let WriterCommand::Send(data) = msg {
663                                    tracing::debug!(
664                                        "Buffering message while reconnecting ({} bytes)",
665                                        data.len()
666                                    );
667                                    reconnect_buffer.push_back(data);
668                                }
669                                continue;
670                            }
671                            WriterCommand::Send(msg) => {
672                                if let Err(e) = active_writer.write_all(&msg).await {
673                                    tracing::error!("Failed to send message: {e}");
674                                    tracing::warn!("Writer triggering reconnect");
675                                    reconnect_buffer.push_back(msg);
676                                    connection_state
677                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
678                                    continue;
679                                }
680                                if let Err(e) = active_writer.write_all(&suffix).await {
681                                    tracing::error!("Failed to send suffix: {e}");
682                                    tracing::warn!("Writer triggering reconnect");
683                                    // Buffer this message before triggering reconnect since suffix failed
684                                    reconnect_buffer.push_back(msg);
685                                    connection_state
686                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
687                                    continue;
688                                }
689                            }
690                        }
691                    }
692                    Ok(None) => {
693                        // Channel closed - writer task should terminate
694                        tracing::debug!("Writer channel closed, terminating writer task");
695                        break;
696                    }
697                    Err(_) => {
698                        // Timeout - just continue the loop
699                        continue;
700                    }
701                }
702            }
703
704            // Attempt to shutdown the writer gracefully before exiting,
705            // we ignore any error as the writer may already be closed.
706            _ = tokio::time::timeout(
707                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
708                active_writer.shutdown(),
709            )
710            .await;
711
712            log_task_stopped("write");
713        })
714    }
715
716    fn spawn_heartbeat_task(
717        connection_state: Arc<AtomicU8>,
718        heartbeat: (u64, Vec<u8>),
719        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
720    ) -> tokio::task::JoinHandle<()> {
721        log_task_started("heartbeat");
722        let (interval_secs, message) = heartbeat;
723
724        tokio::task::spawn(async move {
725            let interval = Duration::from_secs(interval_secs);
726
727            loop {
728                tokio::time::sleep(interval).await;
729
730                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
731                    ConnectionMode::Active => {
732                        let msg = WriterCommand::Send(message.clone().into());
733
734                        match writer_tx.send(msg) {
735                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
736                            Err(e) => {
737                                tracing::error!("Failed to send heartbeat to writer task: {e}");
738                            }
739                        }
740                    }
741                    ConnectionMode::Reconnect => continue,
742                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
743                }
744            }
745
746            log_task_stopped("heartbeat");
747        })
748    }
749}
750
751impl Drop for SocketClientInner {
752    fn drop(&mut self) {
753        // Delegate to explicit cleanup handler
754        self.clean_drop();
755    }
756}
757
758/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
759impl CleanDrop for SocketClientInner {
760    fn clean_drop(&mut self) {
761        if !self.read_task.is_finished() {
762            self.read_task.abort();
763            log_task_aborted("read");
764        }
765
766        if !self.write_task.is_finished() {
767            self.write_task.abort();
768            log_task_aborted("write");
769        }
770
771        if let Some(ref handle) = self.heartbeat_task.take()
772            && !handle.is_finished()
773        {
774            handle.abort();
775            log_task_aborted("heartbeat");
776        }
777
778        #[cfg(feature = "python")]
779        {
780            // Remove stored handler to break ref cycle
781            self.config.message_handler = None;
782        }
783    }
784}
785
/// Handle to a raw TCP client (optionally TLS) with automatic reconnection.
///
/// The connection itself is driven by a background controller task; this handle exposes the
/// shared connection state and a channel to the writer task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Background task that watches `connection_mode` and performs reconnects and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection state, stored as `ConnectionMode::as_u8` values.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time `send_bytes` waits for the client to become ACTIVE before failing.
    pub(crate) reconnect_timeout: Duration,
    /// Sender side of the channel feeding `WriterCommand`s to the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
796
797impl Debug for SocketClient {
798    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
799        f.debug_struct(stringify!(SocketClient)).finish()
800    }
801}
802
803impl SocketClient {
804    /// Connect to the server.
805    ///
806    /// # Errors
807    ///
808    /// Returns any error connecting to the server.
809    pub async fn connect(
810        config: SocketConfig,
811        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
812        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
813        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
814    ) -> anyhow::Result<Self> {
815        let inner = SocketClientInner::connect_url(config).await?;
816        let writer_tx = inner.writer_tx.clone();
817        let connection_mode = inner.connection_mode.clone();
818        let reconnect_timeout = inner.reconnect_timeout;
819
820        let controller_task = Self::spawn_controller_task(
821            inner,
822            connection_mode.clone(),
823            post_reconnection,
824            post_disconnection,
825        );
826
827        if let Some(handler) = post_connection {
828            handler();
829            tracing::debug!("Called `post_connection` handler");
830        }
831
832        Ok(Self {
833            controller_task,
834            connection_mode,
835            reconnect_timeout,
836            writer_tx,
837        })
838    }
839
840    /// Returns the current connection mode.
841    #[must_use]
842    pub fn connection_mode(&self) -> ConnectionMode {
843        ConnectionMode::from_atomic(&self.connection_mode)
844    }
845
846    /// Check if the client connection is active.
847    ///
848    /// Returns `true` if the client is connected and has not been signalled to disconnect.
849    /// The client will automatically retry connection based on its configuration.
850    #[inline]
851    #[must_use]
852    pub fn is_active(&self) -> bool {
853        self.connection_mode().is_active()
854    }
855
856    /// Check if the client is reconnecting.
857    ///
858    /// Returns `true` if the client lost connection and is attempting to reestablish it.
859    /// The client will automatically retry connection based on its configuration.
860    #[inline]
861    #[must_use]
862    pub fn is_reconnecting(&self) -> bool {
863        self.connection_mode().is_reconnect()
864    }
865
866    /// Check if the client is disconnecting.
867    ///
868    /// Returns `true` if the client is in disconnect mode.
869    #[inline]
870    #[must_use]
871    pub fn is_disconnecting(&self) -> bool {
872        self.connection_mode().is_disconnect()
873    }
874
875    /// Check if the client is closed.
876    ///
877    /// Returns `true` if the client has been explicitly disconnected or reached
878    /// maximum reconnection attempts. In this state, the client cannot be reused
879    /// and a new client must be created for further connections.
880    #[inline]
881    #[must_use]
882    pub fn is_closed(&self) -> bool {
883        self.connection_mode().is_closed()
884    }
885
886    /// Close the client.
887    ///
888    /// Controller task will periodically check the disconnect mode
889    /// and shutdown the client if it is not alive.
890    pub async fn close(&self) {
891        self.connection_mode
892            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
893
894        if let Ok(()) =
895            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
896                while !self.is_closed() {
897                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
898                        .await;
899                }
900
901                if !self.controller_task.is_finished() {
902                    self.controller_task.abort();
903                    log_task_aborted("controller");
904                }
905            })
906            .await
907        {
908            log_task_stopped("controller");
909        } else {
910            tracing::error!("Timeout waiting for controller task to finish");
911            if !self.controller_task.is_finished() {
912                self.controller_task.abort();
913                log_task_aborted("controller");
914            }
915        }
916    }
917
918    /// Sends a message of the given `data`.
919    ///
920    /// # Errors
921    ///
922    /// Returns an error if sending fails.
923    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
924        // Check connection state to fail fast
925        if self.is_closed() || self.is_disconnecting() {
926            return Err(SendError::Closed);
927        }
928
929        let timeout = self.reconnect_timeout;
930        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
931
932        if !self.is_active() {
933            tracing::debug!("Waiting for client to become ACTIVE before sending...");
934
935            let inner = tokio::time::timeout(timeout, async {
936                loop {
937                    if self.is_active() {
938                        return Ok(());
939                    }
940                    if matches!(
941                        self.connection_mode(),
942                        ConnectionMode::Disconnect | ConnectionMode::Closed
943                    ) {
944                        return Err(());
945                    }
946                    tokio::time::sleep(check_interval).await;
947                }
948            })
949            .await
950            .map_err(|_| SendError::Timeout)?;
951            inner.map_err(|()| SendError::Closed)?;
952        }
953
954        let msg = WriterCommand::Send(data.into());
955        self.writer_tx
956            .send(msg)
957            .map_err(|e| SendError::BrokenPipe(e.to_string()))
958    }
959
960    fn spawn_controller_task(
961        mut inner: SocketClientInner,
962        connection_mode: Arc<AtomicU8>,
963        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
964        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
965    ) -> tokio::task::JoinHandle<()> {
966        tokio::task::spawn(async move {
967            log_task_started("controller");
968
969            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
970
971            loop {
972                tokio::time::sleep(check_interval).await;
973                let mut mode = ConnectionMode::from_atomic(&connection_mode);
974
975                if mode.is_disconnect() {
976                    tracing::debug!("Disconnecting");
977
978                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
979                    if tokio::time::timeout(timeout, async {
980                        // Delay awaiting graceful shutdown
981                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
982
983                        if !inner.read_task.is_finished() {
984                            inner.read_task.abort();
985                            log_task_aborted("read");
986                        }
987
988                        if let Some(task) = &inner.heartbeat_task
989                            && !task.is_finished()
990                        {
991                            task.abort();
992                            log_task_aborted("heartbeat");
993                        }
994                    })
995                    .await
996                    .is_err()
997                    {
998                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
999                    }
1000
1001                    tracing::debug!("Closed");
1002
1003                    if let Some(ref handler) = post_disconnection {
1004                        handler();
1005                        tracing::debug!("Called `post_disconnection` handler");
1006                    }
1007                    break; // Controller finished
1008                }
1009
1010                if mode.is_active() && !inner.is_alive() {
1011                    if connection_mode
1012                        .compare_exchange(
1013                            ConnectionMode::Active.as_u8(),
1014                            ConnectionMode::Reconnect.as_u8(),
1015                            Ordering::SeqCst,
1016                            Ordering::SeqCst,
1017                        )
1018                        .is_ok()
1019                    {
1020                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1021                    }
1022                    mode = ConnectionMode::from_atomic(&connection_mode);
1023                }
1024
1025                if mode.is_reconnect() {
1026                    // Check max reconnection attempts before attempting reconnect
1027                    if let Some(max_attempts) = inner.reconnect_max_attempts
1028                        && inner.reconnect_attempt_count >= max_attempts
1029                    {
1030                        tracing::error!(
1031                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1032                            max_attempts
1033                        );
1034                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1035                        break;
1036                    }
1037
1038                    inner.reconnect_attempt_count += 1;
1039                    match inner.reconnect().await {
1040                        Ok(()) => {
1041                            tracing::debug!("Reconnected successfully");
1042                            inner.backoff.reset();
1043                            inner.reconnect_attempt_count = 0; // Reset counter on success
1044                            // Only invoke reconnect handler if still active
1045                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1046                                if let Some(ref handler) = post_reconnection {
1047                                    handler();
1048                                    tracing::debug!("Called `post_reconnection` handler");
1049                                }
1050                            } else {
1051                                tracing::debug!(
1052                                    "Skipping post_reconnection handlers due to disconnect state"
1053                                );
1054                            }
1055                        }
1056                        Err(e) => {
1057                            let duration = inner.backoff.next_duration();
1058                            tracing::warn!(
1059                                "Reconnect attempt {} failed: {e}",
1060                                inner.reconnect_attempt_count
1061                            );
1062                            if !duration.is_zero() {
1063                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1064                            }
1065                            tokio::time::sleep(duration).await;
1066                        }
1067                    }
1068                }
1069            }
1070            inner
1071                .connection_mode
1072                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1073
1074            log_task_stopped("controller");
1075        })
1076    }
1077}
1078
1079// Abort controller task on drop to clean up background tasks
1080impl Drop for SocketClient {
1081    fn drop(&mut self) {
1082        if !self.controller_task.is_finished() {
1083            self.controller_task.abort();
1084            log_task_aborted("controller");
1085        }
1086    }
1087}
1088
1089////////////////////////////////////////////////////////////////////////////////
1090// Tests
1091////////////////////////////////////////////////////////////////////////////////
1092
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::sync::atomic::AtomicBool;

    use nautilus_common::testing::wait_until_async;
    use pyo3::Python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Remove trailing \r\n
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Wait a bit for the server to echo them back
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        drop(listener); // We drop it immediately -> no server is listening

        // Wait until port is truly unavailable (OS has released it)
        wait_until_async(
            || async {
                TcpStream::connect(format!("127.0.0.1:{port}"))
                    .await
                    .is_err()
            },
            Duration::from_secs(2),
        )
        .await;

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(100),
            reconnect_delay_initial_ms: Some(50),
            reconnect_backoff_factor: Some(1.0),
            reconnect_delay_max_ms: Some(50),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat every 1 second
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait ~3 seconds to collect some heartbeats
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        // Spawn a server task that:
        // 1. Accepts the first connection and then closes it after a short delay (simulate disconnect)
        // 2. Waits a bit and then accepts a new connection and runs the echo server
        let server_task = task::spawn(async move {
            // Accept first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Wait briefly and then close the connection. Drop the socket explicitly:
            // shadowing it below would otherwise keep it (half-open) alive until the task ends.
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;
            drop(socket);

            // Wait for the client's reconnect attempt
            sleep(Duration::from_millis(500)).await;

            // Run the echo server on the new connection
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        // Set by the `post_reconnection` handler so the test observes a completed reconnect
        // rather than relying on catching the (possibly very brief) RECONNECT state
        let reconnected = Arc::new(AtomicBool::new(false));
        let reconnected_flag = reconnected.clone();
        let post_reconnection: Arc<dyn Fn() + Send + Sync> =
            Arc::new(move || reconnected_flag.store(true, Ordering::SeqCst));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, Some(post_reconnection), None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Wait until the connection has been lost and re-established
        wait_until_async(
            || async { reconnected.load(Ordering::SeqCst) },
            Duration::from_secs(10),
        )
        .await;
        assert!(
            reconnected.load(Ordering::SeqCst),
            "Client should have reconnected"
        );
        assert!(client.is_active(), "Client should be active after reconnect");

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1404
1405#[cfg(test)]
1406#[cfg(not(feature = "turmoil"))]
1407mod rust_tests {
1408    use nautilus_common::testing::wait_until_async;
1409    use rstest::rstest;
1410    use tokio::{
1411        io::{AsyncReadExt, AsyncWriteExt},
1412        net::TcpListener,
1413        task,
1414        time::{Duration, sleep},
1415    };
1416
1417    use super::*;
1418
1419    #[rstest]
1420    #[tokio::test]
1421    async fn test_reconnect_then_close() {
1422        // Bind an ephemeral port
1423        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1424        let port = listener.local_addr().unwrap().port();
1425
1426        // Server task: accept one connection and then drop it
1427        let server = task::spawn(async move {
1428            if let Ok((mut sock, _)) = listener.accept().await {
1429                drop(sock.shutdown());
1430            }
1431            // Keep listener alive briefly to avoid premature exit
1432            sleep(Duration::from_secs(1)).await;
1433        });
1434
1435        // Configure client with a short reconnect backoff
1436        let config = SocketConfig {
1437            url: format!("127.0.0.1:{port}"),
1438            mode: Mode::Plain,
1439            suffix: b"\r\n".to_vec(),
1440            message_handler: None,
1441            heartbeat: None,
1442            reconnect_timeout_ms: Some(1_000),
1443            reconnect_delay_initial_ms: Some(50),
1444            reconnect_delay_max_ms: Some(100),
1445            reconnect_backoff_factor: Some(1.0),
1446            reconnect_jitter_ms: Some(0),
1447            connection_max_retries: Some(1),
1448            reconnect_max_attempts: None,
1449            certs_dir: None,
1450        };
1451
1452        // Connect client (handler=None)
1453        let client = SocketClient::connect(config.clone(), None, None, None)
1454            .await
1455            .unwrap();
1456
1457        // Wait for client to detect dropped connection and enter reconnect state
1458        wait_until_async(
1459            || async { client.is_reconnecting() },
1460            Duration::from_secs(2),
1461        )
1462        .await;
1463
1464        // Now close the client
1465        client.close().await;
1466        assert!(client.is_closed());
1467        server.abort();
1468    }
1469
1470    #[rstest]
1471    #[tokio::test]
1472    async fn test_reconnect_state_flips_when_reader_stops() {
1473        // Bind an ephemeral port and accept a single connection which we immediately close.
1474        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1475        let port = listener.local_addr().unwrap().port();
1476
1477        let server = task::spawn(async move {
1478            if let Ok((sock, _)) = listener.accept().await {
1479                drop(sock);
1480            }
1481            // Give the client a moment to observe the closed connection.
1482            sleep(Duration::from_millis(50)).await;
1483        });
1484
1485        let config = SocketConfig {
1486            url: format!("127.0.0.1:{port}"),
1487            mode: Mode::Plain,
1488            suffix: b"\r\n".to_vec(),
1489            message_handler: None,
1490            heartbeat: None,
1491            reconnect_timeout_ms: Some(1_000),
1492            reconnect_delay_initial_ms: Some(50),
1493            reconnect_delay_max_ms: Some(100),
1494            reconnect_backoff_factor: Some(1.0),
1495            reconnect_jitter_ms: Some(0),
1496            connection_max_retries: Some(1),
1497            reconnect_max_attempts: None,
1498            certs_dir: None,
1499        };
1500
1501        let client = SocketClient::connect(config, None, None, None)
1502            .await
1503            .unwrap();
1504
1505        wait_until_async(
1506            || async { client.is_reconnecting() },
1507            Duration::from_secs(2),
1508        )
1509        .await;
1510
1511        client.close().await;
1512        server.abort();
1513    }
1514
1515    #[rstest]
1516    fn test_parse_socket_url_raw_address() {
1517        // Raw socket address with TLS mode
1518        let (socket_addr, request_url) =
1519            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1520        assert_eq!(socket_addr, "example.com:6130");
1521        assert_eq!(request_url, "wss://example.com:6130");
1522
1523        // Raw socket address with Plain mode
1524        let (socket_addr, request_url) =
1525            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1526        assert_eq!(socket_addr, "localhost:8080");
1527        assert_eq!(request_url, "ws://localhost:8080");
1528    }
1529
1530    #[rstest]
1531    fn test_parse_socket_url_with_scheme() {
1532        // Full URL with wss scheme
1533        let (socket_addr, request_url) =
1534            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1535        assert_eq!(socket_addr, "example.com:443");
1536        assert_eq!(request_url, "wss://example.com:443/path");
1537
1538        // Full URL with ws scheme
1539        let (socket_addr, request_url) =
1540            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1541        assert_eq!(socket_addr, "localhost:8080");
1542        assert_eq!(request_url, "ws://localhost:8080");
1543    }
1544
1545    #[rstest]
1546    fn test_parse_socket_url_default_ports() {
1547        // wss without explicit port defaults to 443
1548        let (socket_addr, _) =
1549            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1550        assert_eq!(socket_addr, "example.com:443");
1551
1552        // ws without explicit port defaults to 80
1553        let (socket_addr, _) =
1554            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1555        assert_eq!(socket_addr, "example.com:80");
1556
1557        // https defaults to 443
1558        let (socket_addr, _) =
1559            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1560        assert_eq!(socket_addr, "example.com:443");
1561
1562        // http defaults to 80
1563        let (socket_addr, _) =
1564            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1565        assert_eq!(socket_addr, "example.com:80");
1566    }
1567
1568    #[rstest]
1569    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1570        // Unknown scheme defaults to mode-based port
1571        let (socket_addr, _) =
1572            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1573        assert_eq!(socket_addr, "example.com:443");
1574
1575        let (socket_addr, _) =
1576            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1577        assert_eq!(socket_addr, "example.com:80");
1578    }
1579
1580    #[rstest]
1581    fn test_parse_socket_url_ipv6() {
1582        // IPv6 address with port
1583        let (socket_addr, request_url) =
1584            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1585        assert_eq!(socket_addr, "[::1]:8080");
1586        assert_eq!(request_url, "ws://[::1]:8080");
1587
1588        // IPv6 in URL
1589        let (socket_addr, _) =
1590            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1591        assert_eq!(socket_addr, "[::1]:8080");
1592    }
1593
1594    #[rstest]
1595    #[tokio::test]
1596    async fn test_url_parsing_raw_socket_address() {
1597        // Test that raw socket addresses (host:port) work correctly
1598        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1599        let port = listener.local_addr().unwrap().port();
1600
1601        let server = task::spawn(async move {
1602            if let Ok((sock, _)) = listener.accept().await {
1603                drop(sock);
1604            }
1605            sleep(Duration::from_millis(50)).await;
1606        });
1607
1608        let config = SocketConfig {
1609            url: format!("127.0.0.1:{port}"), // Raw socket address format
1610            mode: Mode::Plain,
1611            suffix: b"\r\n".to_vec(),
1612            message_handler: None,
1613            heartbeat: None,
1614            reconnect_timeout_ms: Some(1_000),
1615            reconnect_delay_initial_ms: Some(50),
1616            reconnect_delay_max_ms: Some(100),
1617            reconnect_backoff_factor: Some(1.0),
1618            reconnect_jitter_ms: Some(0),
1619            connection_max_retries: Some(1),
1620            reconnect_max_attempts: None,
1621            certs_dir: None,
1622        };
1623
1624        // Should successfully connect with raw socket address
1625        let client = SocketClient::connect(config, None, None, None).await;
1626        assert!(
1627            client.is_ok(),
1628            "Client should connect with raw socket address format"
1629        );
1630
1631        if let Ok(client) = client {
1632            client.close().await;
1633        }
1634        server.abort();
1635    }
1636
1637    #[rstest]
1638    #[tokio::test]
1639    async fn test_url_parsing_with_scheme() {
1640        // Test that URLs with schemes also work
1641        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1642        let port = listener.local_addr().unwrap().port();
1643
1644        let server = task::spawn(async move {
1645            if let Ok((sock, _)) = listener.accept().await {
1646                drop(sock);
1647            }
1648            sleep(Duration::from_millis(50)).await;
1649        });
1650
1651        let config = SocketConfig {
1652            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1653            mode: Mode::Plain,
1654            suffix: b"\r\n".to_vec(),
1655            message_handler: None,
1656            heartbeat: None,
1657            reconnect_timeout_ms: Some(1_000),
1658            reconnect_delay_initial_ms: Some(50),
1659            reconnect_delay_max_ms: Some(100),
1660            reconnect_backoff_factor: Some(1.0),
1661            reconnect_jitter_ms: Some(0),
1662            connection_max_retries: Some(1),
1663            reconnect_max_attempts: None,
1664            certs_dir: None,
1665        };
1666
1667        // Should successfully connect with URL format
1668        let client = SocketClient::connect(config, None, None, None).await;
1669        assert!(
1670            client.is_ok(),
1671            "Client should connect with URL scheme format"
1672        );
1673
1674        if let Ok(client) = client {
1675            client.close().await;
1676        }
1677        server.abort();
1678    }
1679
1680    #[rstest]
1681    fn test_parse_socket_url_ipv6_with_zone() {
1682        // IPv6 with zone ID (link-local address)
1683        let (socket_addr, request_url) =
1684            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1685        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1686        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1687
1688        // Verify zone is preserved in URL format too
1689        let (socket_addr, request_url) =
1690            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1691        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1692        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1693    }
1694
1695    #[rstest]
1696    #[tokio::test]
1697    async fn test_ipv6_loopback_connection() {
1698        // Test IPv6 loopback address connection
1699        // Skip if IPv6 is not available on the system
1700        if TcpListener::bind("[::1]:0").await.is_err() {
1701            eprintln!("IPv6 not available, skipping test");
1702            return;
1703        }
1704
1705        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1706        let port = listener.local_addr().unwrap().port();
1707
1708        let server = task::spawn(async move {
1709            if let Ok((mut sock, _)) = listener.accept().await {
1710                let mut buf = vec![0u8; 1024];
1711                if let Ok(n) = sock.read(&mut buf).await {
1712                    // Echo back
1713                    let _ = sock.write_all(&buf[..n]).await;
1714                }
1715            }
1716            sleep(Duration::from_millis(50)).await;
1717        });
1718
1719        let config = SocketConfig {
1720            url: format!("[::1]:{port}"), // IPv6 loopback
1721            mode: Mode::Plain,
1722            suffix: b"\r\n".to_vec(),
1723            message_handler: None,
1724            heartbeat: None,
1725            reconnect_timeout_ms: Some(1_000),
1726            reconnect_delay_initial_ms: Some(50),
1727            reconnect_delay_max_ms: Some(100),
1728            reconnect_backoff_factor: Some(1.0),
1729            reconnect_jitter_ms: Some(0),
1730            connection_max_retries: Some(1),
1731            reconnect_max_attempts: None,
1732            certs_dir: None,
1733        };
1734
1735        let client = SocketClient::connect(config, None, None, None).await;
1736        assert!(
1737            client.is_ok(),
1738            "Client should connect to IPv6 loopback address"
1739        );
1740
1741        if let Ok(client) = client {
1742            client.close().await;
1743        }
1744        server.abort();
1745    }
1746
1747    #[rstest]
1748    #[tokio::test]
1749    async fn test_send_waits_during_reconnection() {
1750        // Test that send operations wait for reconnection to complete (up to configured timeout)
1751        use nautilus_common::testing::wait_until_async;
1752
1753        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1754        let port = listener.local_addr().unwrap().port();
1755
1756        let server = task::spawn(async move {
1757            // First connection - accept and immediately close
1758            if let Ok((sock, _)) = listener.accept().await {
1759                drop(sock);
1760            }
1761
1762            // Wait before accepting second connection
1763            sleep(Duration::from_millis(500)).await;
1764
1765            // Second connection - accept and keep alive
1766            if let Ok((mut sock, _)) = listener.accept().await {
1767                // Echo messages
1768                let mut buf = vec![0u8; 1024];
1769                while let Ok(n) = sock.read(&mut buf).await {
1770                    if n == 0 {
1771                        break;
1772                    }
1773                    if sock.write_all(&buf[..n]).await.is_err() {
1774                        break;
1775                    }
1776                }
1777            }
1778        });
1779
1780        let config = SocketConfig {
1781            url: format!("127.0.0.1:{port}"),
1782            mode: Mode::Plain,
1783            suffix: b"\r\n".to_vec(),
1784            message_handler: None,
1785            heartbeat: None,
1786            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1787            reconnect_delay_initial_ms: Some(100),
1788            reconnect_delay_max_ms: Some(200),
1789            reconnect_backoff_factor: Some(1.0),
1790            reconnect_jitter_ms: Some(0),
1791            connection_max_retries: Some(1),
1792            reconnect_max_attempts: None,
1793            certs_dir: None,
1794        };
1795
1796        let client = SocketClient::connect(config, None, None, None)
1797            .await
1798            .unwrap();
1799
1800        // Wait for reconnection to trigger
1801        wait_until_async(
1802            || async { client.is_reconnecting() },
1803            Duration::from_secs(2),
1804        )
1805        .await;
1806
1807        // Try to send while reconnecting - should wait and succeed after reconnect
1808        let send_result = tokio::time::timeout(
1809            Duration::from_secs(3),
1810            client.send_bytes(b"test_message".to_vec()),
1811        )
1812        .await;
1813
1814        assert!(
1815            send_result.is_ok() && send_result.unwrap().is_ok(),
1816            "Send should succeed after waiting for reconnection"
1817        );
1818
1819        client.close().await;
1820        server.abort();
1821    }
1822
1823    #[rstest]
1824    #[tokio::test]
1825    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1826        // Test that send_bytes operations respect the configured reconnect_timeout.
1827        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1828        use nautilus_common::testing::wait_until_async;
1829
1830        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1831        let port = listener.local_addr().unwrap().port();
1832
1833        let server = task::spawn(async move {
1834            // Accept first connection and immediately close it
1835            if let Ok((sock, _)) = listener.accept().await {
1836                drop(sock);
1837            }
1838            // Drop listener entirely so reconnection fails completely
1839            drop(listener);
1840            sleep(Duration::from_secs(60)).await;
1841        });
1842
1843        let config = SocketConfig {
1844            url: format!("127.0.0.1:{port}"),
1845            mode: Mode::Plain,
1846            suffix: b"\r\n".to_vec(),
1847            message_handler: None,
1848            heartbeat: None,
1849            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1850            reconnect_delay_initial_ms: Some(200), // Short backoff (but > timeout) to keep client in RECONNECT
1851            reconnect_delay_max_ms: Some(200),
1852            reconnect_backoff_factor: Some(1.0),
1853            reconnect_jitter_ms: Some(0),
1854            connection_max_retries: Some(1),
1855            reconnect_max_attempts: None,
1856            certs_dir: None,
1857        };
1858
1859        let client = SocketClient::connect(config, None, None, None)
1860            .await
1861            .unwrap();
1862
1863        // Wait for client to enter RECONNECT state
1864        wait_until_async(
1865            || async { client.is_reconnecting() },
1866            Duration::from_secs(3),
1867        )
1868        .await;
1869
1870        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1871        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1872        let start = std::time::Instant::now();
1873        let send_result = client.send_bytes(b"test".to_vec()).await;
1874        let elapsed = start.elapsed();
1875
1876        assert!(
1877            send_result.is_err(),
1878            "Send should fail when client stuck in RECONNECT, was: {:?}",
1879            send_result
1880        );
1881        assert!(
1882            matches!(send_result, Err(crate::error::SendError::Timeout)),
1883            "Send should return Timeout error, was: {:?}",
1884            send_result
1885        );
1886        // Verify timeout respects configured value (1s), but don't check upper bound
1887        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1888        assert!(
1889            elapsed >= Duration::from_millis(900),
1890            "Send should timeout after at least 1s (configured timeout), took {:?}",
1891            elapsed
1892        );
1893
1894        client.close().await;
1895        server.abort();
1896    }
1897}