nautilus_network/socket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18//!
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    collections::VecDeque,
33    fmt::Debug,
34    path::Path,
35    sync::{
36        Arc,
37        atomic::{AtomicU8, Ordering},
38    },
39    time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49    SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52    backoff::ExponentialBackoff,
53    error::SendError,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
// Connection timing constants

/// Poll timeout (ms) used by the read and write tasks: each pending read/receive is bounded by
/// this interval so the task can re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) applied during reconnection before the previous read task is aborted.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on waiting for a writer half to shut down gracefully.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// NOTE(review): not referenced in this part of the file; presumably the poll interval (ms)
/// used by send operations defined further down — confirm against its usage.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
65
/// Internal state of a socket client holding a `TcpStream` connection to the server.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; re-read on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `certs_dir` when one is configured, otherwise `None`.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced with a fresh task on reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel used to pass `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode (stored as `u8`) shared with the background tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff built from the `reconnect_*` settings.
    // NOTE(review): not consumed in this part of the file; presumably driven by the controller task.
    backoff: ExponentialBackoff,
    /// Clone of the configured message handler, handed to each read task.
    handler: Option<TcpMessageHandler>,
    /// Copied from the configuration.
    // NOTE(review): presumably caps the number of reconnect attempts (`None` = no cap) — confirm.
    reconnect_max_attempts: Option<u32>,
    /// Initialized to zero on connect.
    // NOTE(review): presumably counts reconnect attempts made so far — confirm against the controller.
    reconnect_attempt_count: u32,
}
99
100impl SocketClientInner {
101    /// Connect to a URL with the specified configuration.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if connection fails or configuration is invalid.
106    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
107        install_cryptographic_provider();
108
109        let SocketConfig {
110            url,
111            mode,
112            heartbeat,
113            suffix,
114            message_handler,
115            reconnect_timeout_ms,
116            reconnect_delay_initial_ms,
117            reconnect_delay_max_ms,
118            reconnect_backoff_factor,
119            reconnect_jitter_ms,
120            connection_max_retries,
121            reconnect_max_attempts,
122            certs_dir,
123        } = &config.clone();
124        let connector = if let Some(dir) = certs_dir {
125            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
126            Some(Connector::Rustls(Arc::new(config)))
127        } else {
128            None
129        };
130
131        // Retry initial connection with exponential backoff to handle transient DNS/network issues
132        const CONNECTION_TIMEOUT_SECS: u64 = 10;
133        let max_retries = connection_max_retries.unwrap_or(5);
134
135        let mut backoff = ExponentialBackoff::new(
136            Duration::from_millis(500),
137            Duration::from_millis(5000),
138            2.0,
139            250,
140            false,
141        )?;
142
143        #[allow(unused_assignments)]
144        let mut last_error = String::new();
145        let mut attempt = 0;
146        let (reader, writer) = loop {
147            attempt += 1;
148
149            match tokio::time::timeout(
150                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
151                Self::tls_connect_with_server(url, *mode, connector.clone()),
152            )
153            .await
154            {
155                Ok(Ok(result)) => {
156                    if attempt > 1 {
157                        tracing::info!("Socket connection established after {attempt} attempts");
158                    }
159                    break result;
160                }
161                Ok(Err(e)) => {
162                    last_error = e.to_string();
163                    tracing::warn!(
164                        attempt,
165                        max_retries,
166                        url = %url,
167                        error = %last_error,
168                        "Socket connection attempt failed"
169                    );
170                }
171                Err(_) => {
172                    last_error = format!(
173                        "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
174                    );
175                    tracing::warn!(
176                        attempt,
177                        max_retries,
178                        url = %url,
179                        "Socket connection attempt timed out"
180                    );
181                }
182            }
183
184            if attempt >= max_retries {
185                anyhow::bail!(
186                    "Failed to connect to {} after {} attempts: {}. \
187                    If this is a DNS error, check your network configuration and DNS settings.",
188                    url,
189                    max_retries,
190                    if last_error.is_empty() {
191                        "unknown error"
192                    } else {
193                        &last_error
194                    }
195                );
196            }
197
198            let delay = backoff.next_duration();
199            tracing::debug!(
200                "Retrying in {delay:?} (attempt {}/{})",
201                attempt + 1,
202                max_retries
203            );
204            tokio::time::sleep(delay).await;
205        };
206
207        tracing::debug!("Connected");
208
209        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
210
211        let read_task = Arc::new(Self::spawn_read_task(
212            connection_mode.clone(),
213            reader,
214            message_handler.clone(),
215            suffix.clone(),
216        ));
217
218        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219
220        let write_task =
221            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
222
223        // Optionally spawn a heartbeat task to periodically ping server
224        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
225            Self::spawn_heartbeat_task(
226                connection_mode.clone(),
227                heartbeat.clone(),
228                writer_tx.clone(),
229            )
230        });
231
232        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
233        let backoff = ExponentialBackoff::new(
234            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
235            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
236            reconnect_backoff_factor.unwrap_or(1.5),
237            reconnect_jitter_ms.unwrap_or(100),
238            true, // immediate-first
239        )?;
240
241        Ok(Self {
242            config,
243            connector,
244            read_task,
245            write_task,
246            writer_tx,
247            heartbeat_task,
248            connection_mode,
249            reconnect_timeout,
250            backoff,
251            handler: message_handler.clone(),
252            reconnect_max_attempts: *reconnect_max_attempts,
253            reconnect_attempt_count: 0,
254        })
255    }
256
257    /// Parse URL and extract socket address and request URL.
258    ///
259    /// Accepts either:
260    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
261    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the URL is invalid or missing required components.
266    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
267        if url.contains("://") {
268            // URL with scheme (e.g., "wss://host:port/path")
269            let parsed = url.parse::<http::Uri>().map_err(|e| {
270                Error::Io(std::io::Error::new(
271                    std::io::ErrorKind::InvalidInput,
272                    format!("Invalid URL: {e}"),
273                ))
274            })?;
275
276            let host = parsed.host().ok_or_else(|| {
277                Error::Io(std::io::Error::new(
278                    std::io::ErrorKind::InvalidInput,
279                    "URL missing host",
280                ))
281            })?;
282
283            let port = parsed
284                .port_u16()
285                .unwrap_or_else(|| match parsed.scheme_str() {
286                    Some("wss" | "https") => 443,
287                    Some("ws" | "http") => 80,
288                    _ => match mode {
289                        Mode::Tls => 443,
290                        Mode::Plain => 80,
291                    },
292                });
293
294            Ok((format!("{host}:{port}"), url.to_string()))
295        } else {
296            // Raw socket address (e.g., "host:port")
297            // Construct a proper URL for the request based on mode
298            let scheme = match mode {
299                Mode::Tls => "wss",
300                Mode::Plain => "ws",
301            };
302            Ok((url.to_string(), format!("{scheme}://{url}")))
303        }
304    }
305
306    /// Establish a TLS or plain TCP connection with the server.
307    ///
308    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
309    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
310    /// For WebSocket-style connections, include the scheme.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the connection cannot be established.
315    pub async fn tls_connect_with_server(
316        url: &str,
317        mode: Mode,
318        connector: Option<Connector>,
319    ) -> Result<(TcpReader, TcpWriter), Error> {
320        tracing::debug!("Connecting to {url}");
321
322        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
323        let tcp_result = TcpStream::connect(&socket_addr).await;
324
325        match tcp_result {
326            Ok(stream) => {
327                tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
328                if let Err(e) = stream.set_nodelay(true) {
329                    tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
330                }
331                let request = request_url.into_client_request()?;
332                tcp_tls(&request, mode, stream, connector)
333                    .await
334                    .map(tokio::io::split)
335            }
336            Err(e) => {
337                tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
338                Err(Error::Io(e))
339            }
340        }
341    }
342
    /// Reconnect with server.
    ///
    /// Makes a new connection with server, uses the new read and write halves
    /// to update the reader and writer.
    ///
    /// The whole sequence is bounded by `reconnect_timeout`. The steps are:
    /// 1. Open a fresh connection using the stored configuration and connector.
    /// 2. Hand the new write half to the write task and wait for it to report whether the
    ///    buffered messages were flushed.
    /// 3. After a short delay, abort the previous read task if it is still running.
    /// 4. Atomically switch the connection mode from RECONNECT to ACTIVE and spawn a new read task.
    ///
    /// If a disconnect is observed at any checkpoint, or the mode is no longer RECONNECT at
    /// step 4, the method returns `Ok(())` without completing the reconnect.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established, the write task cannot be
    /// reached or fails to flush its buffer, or the sequence exceeds `reconnect_timeout`.
    async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        // Nothing to do when a disconnect has already been requested
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Only `url`, `mode` and `suffix` are needed here; the remaining fields are listed
            // explicitly, so adding a field to `SocketConfig` forces this pattern to be revisited.
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                connection_max_retries: _,
                reconnect_max_attempts: _,
                certs_dir: _,
            } = &self.config;
            // Create a fresh connection
            let connector = self.connector.clone();
            // Attempt to connect; a connection failure propagates to the caller
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            // Abort (dropping the new connection) if a disconnect was requested while connecting
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            tracing::debug!("Connected");

            // Use a oneshot channel to synchronize with the writer task.
            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
            // to prevent silent message loss if the new connection drops immediately.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                tracing::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it has drained the buffer
            match rx.await {
                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                    // Return error to trigger retry logic in controller
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    // The sender was dropped without a reply
                    tracing::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Delay before closing connection
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the previous read task if it has not already finished
            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn new read task; it is started only after the switch to ACTIVE because the
            // read loop exits as soon as it observes a non-active mode
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
            ));

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // An elapsed timeout is mapped to an I/O `TimedOut` error; the trailing `?` then unwraps
        // the outer layer so the inner result of the async block becomes the return value
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
464
465    /// Check if the client is still alive.
466    ///
467    /// The client is connected if the read task has not finished. It is expected
468    /// that in case of any failure client or server side. The read task will be
469    /// shutdown. There might be some delay between the connection being closed
470    /// and the client detecting it.
471    #[inline]
472    #[must_use]
473    pub fn is_alive(&self) -> bool {
474        !self.read_task.is_finished()
475    }
476
477    #[must_use]
478    fn spawn_read_task(
479        connection_state: Arc<AtomicU8>,
480        mut reader: TcpReader,
481        handler: Option<TcpMessageHandler>,
482        suffix: Vec<u8>,
483    ) -> tokio::task::JoinHandle<()> {
484        log_task_started("read");
485
486        // Interval between checking the connection mode
487        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
488
489        tokio::task::spawn(async move {
490            let mut buf = Vec::new();
491
492            loop {
493                if !ConnectionMode::from_atomic(&connection_state).is_active() {
494                    break;
495                }
496
497                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
498                    // Connection has been terminated or vector buffer is complete
499                    Ok(Ok(0)) => {
500                        tracing::debug!("Connection closed by server");
501                        break;
502                    }
503                    Ok(Err(e)) => {
504                        tracing::debug!("Connection ended: {e}");
505                        break;
506                    }
507                    // Received bytes of data
508                    Ok(Ok(bytes)) => {
509                        tracing::trace!("Received <binary> {bytes} bytes");
510
511                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
512                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
513
514                        if is_fix && handler.is_some() {
515                            // FIX protocol processing
516                            if let Some(ref handler) = handler {
517                                process_fix_buffer(&mut buf, handler);
518                            }
519                        } else {
520                            // Regular suffix-based message processing
521                            while let Some((i, _)) = &buf
522                                .windows(suffix.len())
523                                .enumerate()
524                                .find(|(_, pair)| pair.eq(&suffix))
525                            {
526                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
527                                data.truncate(data.len() - suffix.len());
528
529                                if let Some(ref handler) = handler {
530                                    handler(&data);
531                                }
532                            }
533                        }
534                    }
535                    Err(_) => {
536                        // Timeout - continue loop and check connection mode
537                        continue;
538                    }
539                }
540            }
541
542            log_task_stopped("read");
543        })
544    }
545
546    /// Drains buffered messages after reconnection completes.
547    ///
548    /// Attempts to send all buffered messages that were queued during reconnection.
549    /// Uses a peek-and-pop pattern to preserve messages if sending fails midway through the buffer.
550    ///
551    /// # Returns
552    ///
553    /// Returns `true` if a send error occurred (buffer may still contain unsent messages),
554    /// `false` if all messages were sent successfully (buffer is empty).
555    async fn drain_reconnect_buffer(
556        buffer: &mut VecDeque<Bytes>,
557        writer: &mut TcpWriter,
558        suffix: &[u8],
559    ) -> bool {
560        if buffer.is_empty() {
561            return false;
562        }
563
564        let initial_buffer_len = buffer.len();
565        tracing::info!(
566            "Sending {} buffered messages after reconnection",
567            initial_buffer_len
568        );
569
570        let mut send_error_occurred = false;
571
572        while let Some(buffered_msg) = buffer.front() {
573            let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
574            combined_msg.extend_from_slice(buffered_msg);
575            combined_msg.extend_from_slice(suffix);
576
577            if let Err(e) = writer.write_all(&combined_msg).await {
578                tracing::error!(
579                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
580                    buffer.len()
581                );
582                send_error_occurred = true;
583                break;
584            }
585
586            buffer.pop_front();
587        }
588
589        if buffer.is_empty() {
590            tracing::info!(
591                "Successfully sent all {} buffered messages",
592                initial_buffer_len
593            );
594        }
595
596        send_error_occurred
597    }
598
599    fn spawn_write_task(
600        connection_state: Arc<AtomicU8>,
601        writer: TcpWriter,
602        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
603        suffix: Vec<u8>,
604    ) -> tokio::task::JoinHandle<()> {
605        log_task_started("write");
606
607        // Interval between checking the connection mode
608        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
609
610        tokio::task::spawn(async move {
611            let mut active_writer = writer;
612            let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
613
614            loop {
615                if matches!(
616                    ConnectionMode::from_atomic(&connection_state),
617                    ConnectionMode::Disconnect | ConnectionMode::Closed
618                ) {
619                    break;
620                }
621
622                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
623                    Ok(Some(msg)) => {
624                        // Re-check connection mode after receiving a message
625                        let mode = ConnectionMode::from_atomic(&connection_state);
626                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
627                            break;
628                        }
629
630                        match msg {
631                            WriterCommand::Update(new_writer, tx) => {
632                                tracing::debug!("Received new writer");
633
634                                // Delay before closing connection
635                                tokio::time::sleep(Duration::from_millis(100)).await;
636
637                                // Attempt to shutdown the writer gracefully before updating,
638                                // we ignore any error as the writer may already be closed.
639                                _ = tokio::time::timeout(
640                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
641                                    active_writer.shutdown(),
642                                )
643                                .await;
644
645                                active_writer = new_writer;
646                                tracing::debug!("Updated writer");
647
648                                let send_error = Self::drain_reconnect_buffer(
649                                    &mut reconnect_buffer,
650                                    &mut active_writer,
651                                    &suffix,
652                                )
653                                .await;
654
655                                if let Err(e) = tx.send(!send_error) {
656                                    tracing::error!(
657                                        "Failed to report drain status to controller: {e:?}"
658                                    );
659                                }
660                            }
661                            _ if mode.is_reconnect() => {
662                                if let WriterCommand::Send(data) = msg {
663                                    tracing::debug!(
664                                        "Buffering message while reconnecting ({} bytes)",
665                                        data.len()
666                                    );
667                                    reconnect_buffer.push_back(data);
668                                }
669                                continue;
670                            }
671                            WriterCommand::Send(msg) => {
672                                if let Err(e) = active_writer.write_all(&msg).await {
673                                    tracing::error!("Failed to send message: {e}");
674                                    tracing::warn!("Writer triggering reconnect");
675                                    reconnect_buffer.push_back(msg);
676                                    connection_state
677                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
678                                    continue;
679                                }
680                                if let Err(e) = active_writer.write_all(&suffix).await {
681                                    tracing::error!("Failed to send suffix: {e}");
682                                    tracing::warn!("Writer triggering reconnect");
683                                    // Buffer this message before triggering reconnect since suffix failed
684                                    reconnect_buffer.push_back(msg);
685                                    connection_state
686                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
687                                    continue;
688                                }
689                            }
690                        }
691                    }
692                    Ok(None) => {
693                        // Channel closed - writer task should terminate
694                        tracing::debug!("Writer channel closed, terminating writer task");
695                        break;
696                    }
697                    Err(_) => {
698                        // Timeout - just continue the loop
699                        continue;
700                    }
701                }
702            }
703
704            // Attempt to shutdown the writer gracefully before exiting,
705            // we ignore any error as the writer may already be closed.
706            _ = tokio::time::timeout(
707                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
708                active_writer.shutdown(),
709            )
710            .await;
711
712            log_task_stopped("write");
713        })
714    }
715
716    fn spawn_heartbeat_task(
717        connection_state: Arc<AtomicU8>,
718        heartbeat: (u64, Vec<u8>),
719        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
720    ) -> tokio::task::JoinHandle<()> {
721        log_task_started("heartbeat");
722        let (interval_secs, message) = heartbeat;
723
724        tokio::task::spawn(async move {
725            let interval = Duration::from_secs(interval_secs);
726
727            loop {
728                tokio::time::sleep(interval).await;
729
730                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
731                    ConnectionMode::Active => {
732                        let msg = WriterCommand::Send(message.clone().into());
733
734                        match writer_tx.send(msg) {
735                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
736                            Err(e) => {
737                                tracing::error!("Failed to send heartbeat to writer task: {e}");
738                            }
739                        }
740                    }
741                    ConnectionMode::Reconnect => continue,
742                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
743                }
744            }
745
746            log_task_stopped("heartbeat");
747        })
748    }
749}
750
751impl Drop for SocketClientInner {
752    fn drop(&mut self) {
753        // Delegate to explicit cleanup handler
754        self.clean_drop();
755    }
756}
757
758/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
759impl CleanDrop for SocketClientInner {
760    fn clean_drop(&mut self) {
761        if !self.read_task.is_finished() {
762            self.read_task.abort();
763            log_task_aborted("read");
764        }
765
766        if !self.write_task.is_finished() {
767            self.write_task.abort();
768            log_task_aborted("write");
769        }
770
771        if let Some(ref handle) = self.heartbeat_task.take()
772            && !handle.is_finished()
773        {
774            handle.abort();
775            log_task_aborted("heartbeat");
776        }
777
778        #[cfg(feature = "python")]
779        {
780            // Remove stored handler to break ref cycle
781            self.config.message_handler = None;
782        }
783    }
784}
785
786#[cfg_attr(
787    feature = "python",
788    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
789)]
790pub struct SocketClient {
791    pub(crate) controller_task: tokio::task::JoinHandle<()>,
792    pub(crate) connection_mode: Arc<AtomicU8>,
793    pub(crate) reconnect_timeout: Duration,
794    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
795}
796
797impl Debug for SocketClient {
798    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
799        f.debug_struct(stringify!(SocketClient)).finish()
800    }
801}
802
803impl SocketClient {
804    /// Connect to the server.
805    ///
806    /// # Errors
807    ///
808    /// Returns any error connecting to the server.
809    pub async fn connect(
810        config: SocketConfig,
811        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
812        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
813        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
814    ) -> anyhow::Result<Self> {
815        let inner = SocketClientInner::connect_url(config).await?;
816        let writer_tx = inner.writer_tx.clone();
817        let connection_mode = inner.connection_mode.clone();
818        let reconnect_timeout = inner.reconnect_timeout;
819
820        let controller_task = Self::spawn_controller_task(
821            inner,
822            connection_mode.clone(),
823            post_reconnection,
824            post_disconnection,
825        );
826
827        if let Some(handler) = post_connection {
828            handler();
829            tracing::debug!("Called `post_connection` handler");
830        }
831
832        Ok(Self {
833            controller_task,
834            connection_mode,
835            reconnect_timeout,
836            writer_tx,
837        })
838    }
839
840    /// Returns the current connection mode.
841    #[must_use]
842    pub fn connection_mode(&self) -> ConnectionMode {
843        ConnectionMode::from_atomic(&self.connection_mode)
844    }
845
846    /// Check if the client connection is active.
847    ///
848    /// Returns `true` if the client is connected and has not been signalled to disconnect.
849    /// The client will automatically retry connection based on its configuration.
850    #[inline]
851    #[must_use]
852    pub fn is_active(&self) -> bool {
853        self.connection_mode().is_active()
854    }
855
856    /// Check if the client is reconnecting.
857    ///
858    /// Returns `true` if the client lost connection and is attempting to reestablish it.
859    /// The client will automatically retry connection based on its configuration.
860    #[inline]
861    #[must_use]
862    pub fn is_reconnecting(&self) -> bool {
863        self.connection_mode().is_reconnect()
864    }
865
866    /// Check if the client is disconnecting.
867    ///
868    /// Returns `true` if the client is in disconnect mode.
869    #[inline]
870    #[must_use]
871    pub fn is_disconnecting(&self) -> bool {
872        self.connection_mode().is_disconnect()
873    }
874
875    /// Check if the client is closed.
876    ///
877    /// Returns `true` if the client has been explicitly disconnected or reached
878    /// maximum reconnection attempts. In this state, the client cannot be reused
879    /// and a new client must be created for further connections.
880    #[inline]
881    #[must_use]
882    pub fn is_closed(&self) -> bool {
883        self.connection_mode().is_closed()
884    }
885
886    /// Close the client.
887    ///
888    /// Controller task will periodically check the disconnect mode
889    /// and shutdown the client if it is not alive.
890    pub async fn close(&self) {
891        self.connection_mode
892            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
893
894        if let Ok(()) =
895            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
896                while !self.is_closed() {
897                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
898                        .await;
899                }
900
901                if !self.controller_task.is_finished() {
902                    self.controller_task.abort();
903                    log_task_aborted("controller");
904                }
905            })
906            .await
907        {
908            log_task_stopped("controller");
909        } else {
910            tracing::error!("Timeout waiting for controller task to finish");
911            if !self.controller_task.is_finished() {
912                self.controller_task.abort();
913                log_task_aborted("controller");
914            }
915        }
916    }
917
918    /// Sends a message of the given `data`.
919    ///
920    /// # Errors
921    ///
922    /// Returns an error if sending fails.
923    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
924        // Check connection state to fail fast
925        if self.is_closed() || self.is_disconnecting() {
926            return Err(SendError::Closed);
927        }
928
929        let timeout = self.reconnect_timeout;
930        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
931
932        if !self.is_active() {
933            tracing::debug!("Waiting for client to become ACTIVE before sending...");
934
935            let inner = tokio::time::timeout(timeout, async {
936                loop {
937                    if self.is_active() {
938                        return Ok(());
939                    }
940                    if matches!(
941                        self.connection_mode(),
942                        ConnectionMode::Disconnect | ConnectionMode::Closed
943                    ) {
944                        return Err(());
945                    }
946                    tokio::time::sleep(check_interval).await;
947                }
948            })
949            .await
950            .map_err(|_| SendError::Timeout)?;
951            inner.map_err(|()| SendError::Closed)?;
952        }
953
954        let msg = WriterCommand::Send(data.into());
955        self.writer_tx
956            .send(msg)
957            .map_err(|e| SendError::BrokenPipe(e.to_string()))
958    }
959
960    fn spawn_controller_task(
961        mut inner: SocketClientInner,
962        connection_mode: Arc<AtomicU8>,
963        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
964        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
965    ) -> tokio::task::JoinHandle<()> {
966        tokio::task::spawn(async move {
967            log_task_started("controller");
968
969            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
970
971            loop {
972                tokio::time::sleep(check_interval).await;
973                let mut mode = ConnectionMode::from_atomic(&connection_mode);
974
975                if mode.is_disconnect() {
976                    tracing::debug!("Disconnecting");
977
978                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
979                    if tokio::time::timeout(timeout, async {
980                        // Delay awaiting graceful shutdown
981                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
982
983                        if !inner.read_task.is_finished() {
984                            inner.read_task.abort();
985                            log_task_aborted("read");
986                        }
987
988                        if let Some(task) = &inner.heartbeat_task
989                            && !task.is_finished()
990                        {
991                            task.abort();
992                            log_task_aborted("heartbeat");
993                        }
994                    })
995                    .await
996                    .is_err()
997                    {
998                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
999                    }
1000
1001                    tracing::debug!("Closed");
1002
1003                    if let Some(ref handler) = post_disconnection {
1004                        handler();
1005                        tracing::debug!("Called `post_disconnection` handler");
1006                    }
1007                    break; // Controller finished
1008                }
1009
1010                if mode.is_active() && !inner.is_alive() {
1011                    if connection_mode
1012                        .compare_exchange(
1013                            ConnectionMode::Active.as_u8(),
1014                            ConnectionMode::Reconnect.as_u8(),
1015                            Ordering::SeqCst,
1016                            Ordering::SeqCst,
1017                        )
1018                        .is_ok()
1019                    {
1020                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1021                    }
1022                    mode = ConnectionMode::from_atomic(&connection_mode);
1023                }
1024
1025                if mode.is_reconnect() {
1026                    // Check max reconnection attempts before attempting reconnect
1027                    if let Some(max_attempts) = inner.reconnect_max_attempts
1028                        && inner.reconnect_attempt_count >= max_attempts
1029                    {
1030                        tracing::error!(
1031                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1032                            max_attempts
1033                        );
1034                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1035                        break;
1036                    }
1037
1038                    inner.reconnect_attempt_count += 1;
1039                    match inner.reconnect().await {
1040                        Ok(()) => {
1041                            tracing::debug!("Reconnected successfully");
1042                            inner.backoff.reset();
1043                            inner.reconnect_attempt_count = 0; // Reset counter on success
1044                            // Only invoke reconnect handler if still active
1045                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1046                                if let Some(ref handler) = post_reconnection {
1047                                    handler();
1048                                    tracing::debug!("Called `post_reconnection` handler");
1049                                }
1050                            } else {
1051                                tracing::debug!(
1052                                    "Skipping post_reconnection handlers due to disconnect state"
1053                                );
1054                            }
1055                        }
1056                        Err(e) => {
1057                            let duration = inner.backoff.next_duration();
1058                            tracing::warn!(
1059                                "Reconnect attempt {} failed: {e}",
1060                                inner.reconnect_attempt_count
1061                            );
1062                            if !duration.is_zero() {
1063                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1064                            }
1065                            tokio::time::sleep(duration).await;
1066                        }
1067                    }
1068                }
1069            }
1070            inner
1071                .connection_mode
1072                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1073
1074            log_task_stopped("controller");
1075        })
1076    }
1077}
1078
1079// Abort controller task on drop to clean up background tasks
1080impl Drop for SocketClient {
1081    fn drop(&mut self) {
1082        if !self.controller_task.is_finished() {
1083            self.controller_task.abort();
1084            log_task_aborted("controller");
1085        }
1086    }
1087}
1088
1089#[cfg(test)]
1090#[cfg(feature = "python")]
1091#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1092mod tests {
1093    use nautilus_common::testing::wait_until_async;
1094    use pyo3::Python;
1095    use tokio::{
1096        io::{AsyncReadExt, AsyncWriteExt},
1097        net::{TcpListener, TcpStream},
1098        sync::Mutex,
1099        task,
1100        time::{Duration, sleep},
1101    };
1102
1103    use super::*;
1104
1105    async fn bind_test_server() -> (u16, TcpListener) {
1106        let listener = TcpListener::bind("127.0.0.1:0")
1107            .await
1108            .expect("Failed to bind ephemeral port");
1109        let port = listener.local_addr().unwrap().port();
1110        (port, listener)
1111    }
1112
1113    async fn run_echo_server(mut socket: TcpStream) {
1114        let mut buf = Vec::new();
1115        loop {
1116            match socket.read_buf(&mut buf).await {
1117                Ok(0) => {
1118                    break;
1119                }
1120                Ok(_n) => {
1121                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1122                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1123                        // Remove trailing \r\n
1124                        line.truncate(line.len() - 2);
1125
1126                        if line == b"close" {
1127                            let _ = socket.shutdown().await;
1128                            return;
1129                        }
1130
1131                        let mut echo_data = line;
1132                        echo_data.extend_from_slice(b"\r\n");
1133                        if socket.write_all(&echo_data).await.is_err() {
1134                            break;
1135                        }
1136                    }
1137                }
1138                Err(e) => {
1139                    eprintln!("Server read error: {e}");
1140                    break;
1141                }
1142            }
1143        }
1144    }
1145
1146    #[tokio::test]
1147    async fn test_basic_send_receive() {
1148        Python::initialize();
1149
1150        let (port, listener) = bind_test_server().await;
1151        let server_task = task::spawn(async move {
1152            let (socket, _) = listener.accept().await.unwrap();
1153            run_echo_server(socket).await;
1154        });
1155
1156        let config = SocketConfig {
1157            url: format!("127.0.0.1:{port}"),
1158            mode: Mode::Plain,
1159            suffix: b"\r\n".to_vec(),
1160            message_handler: None,
1161            heartbeat: None,
1162            reconnect_timeout_ms: None,
1163            reconnect_delay_initial_ms: None,
1164            reconnect_backoff_factor: None,
1165            reconnect_delay_max_ms: None,
1166            reconnect_jitter_ms: None,
1167            reconnect_max_attempts: None,
1168            connection_max_retries: None,
1169            certs_dir: None,
1170        };
1171
1172        let client = SocketClient::connect(config, None, None, None)
1173            .await
1174            .expect("Client connect failed unexpectedly");
1175
1176        client.send_bytes(b"Hello".into()).await.unwrap();
1177        client.send_bytes(b"World".into()).await.unwrap();
1178
1179        // Wait a bit for the server to echo them back
1180        sleep(Duration::from_millis(100)).await;
1181
1182        client.send_bytes(b"close".into()).await.unwrap();
1183        server_task.await.unwrap();
1184        assert!(!client.is_closed());
1185    }
1186
1187    #[tokio::test]
1188    async fn test_reconnect_fail_exhausted() {
1189        Python::initialize();
1190
1191        let (port, listener) = bind_test_server().await;
1192        drop(listener); // We drop it immediately -> no server is listening
1193
1194        // Wait until port is truly unavailable (OS has released it)
1195        wait_until_async(
1196            || async {
1197                TcpStream::connect(format!("127.0.0.1:{port}"))
1198                    .await
1199                    .is_err()
1200            },
1201            Duration::from_secs(2),
1202        )
1203        .await;
1204
1205        let config = SocketConfig {
1206            url: format!("127.0.0.1:{port}"),
1207            mode: Mode::Plain,
1208            suffix: b"\r\n".to_vec(),
1209            message_handler: None,
1210            heartbeat: None,
1211            reconnect_timeout_ms: Some(100),
1212            reconnect_delay_initial_ms: Some(50),
1213            reconnect_backoff_factor: Some(1.0),
1214            reconnect_delay_max_ms: Some(50),
1215            reconnect_jitter_ms: Some(0),
1216            connection_max_retries: Some(1),
1217            reconnect_max_attempts: None,
1218            certs_dir: None,
1219        };
1220
1221        let client_res = SocketClient::connect(config, None, None, None).await;
1222        assert!(
1223            client_res.is_err(),
1224            "Should fail quickly with no server listening"
1225        );
1226    }
1227
1228    #[tokio::test]
1229    async fn test_user_disconnect() {
1230        Python::initialize();
1231
1232        let (port, listener) = bind_test_server().await;
1233        let server_task = task::spawn(async move {
1234            let (socket, _) = listener.accept().await.unwrap();
1235            let mut buf = [0u8; 1024];
1236            let _ = socket.try_read(&mut buf);
1237
1238            loop {
1239                sleep(Duration::from_secs(1)).await;
1240            }
1241        });
1242
1243        let config = SocketConfig {
1244            url: format!("127.0.0.1:{port}"),
1245            mode: Mode::Plain,
1246            suffix: b"\r\n".to_vec(),
1247            message_handler: None,
1248            heartbeat: None,
1249            reconnect_timeout_ms: None,
1250            reconnect_delay_initial_ms: None,
1251            reconnect_backoff_factor: None,
1252            reconnect_delay_max_ms: None,
1253            reconnect_jitter_ms: None,
1254            reconnect_max_attempts: None,
1255            connection_max_retries: None,
1256            certs_dir: None,
1257        };
1258
1259        let client = SocketClient::connect(config, None, None, None)
1260            .await
1261            .unwrap();
1262
1263        client.close().await;
1264        assert!(client.is_closed());
1265        server_task.abort();
1266    }
1267
1268    #[tokio::test]
1269    async fn test_heartbeat() {
1270        Python::initialize();
1271
1272        let (port, listener) = bind_test_server().await;
1273        let received = Arc::new(Mutex::new(Vec::new()));
1274        let received2 = received.clone();
1275
1276        let server_task = task::spawn(async move {
1277            let (socket, _) = listener.accept().await.unwrap();
1278
1279            let mut buf = Vec::new();
1280            loop {
1281                match socket.try_read_buf(&mut buf) {
1282                    Ok(0) => break,
1283                    Ok(_) => {
1284                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1285                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1286                            line.truncate(line.len() - 2);
1287                            received2.lock().await.push(line);
1288                        }
1289                    }
1290                    Err(_) => {
1291                        tokio::time::sleep(Duration::from_millis(10)).await;
1292                    }
1293                }
1294            }
1295        });
1296
1297        // Heartbeat every 1 second
1298        let heartbeat = Some((1, b"ping".to_vec()));
1299
1300        let config = SocketConfig {
1301            url: format!("127.0.0.1:{port}"),
1302            mode: Mode::Plain,
1303            suffix: b"\r\n".to_vec(),
1304            message_handler: None,
1305            heartbeat,
1306            reconnect_timeout_ms: None,
1307            reconnect_delay_initial_ms: None,
1308            reconnect_backoff_factor: None,
1309            reconnect_delay_max_ms: None,
1310            reconnect_jitter_ms: None,
1311            reconnect_max_attempts: None,
1312            connection_max_retries: None,
1313            certs_dir: None,
1314        };
1315
1316        let client = SocketClient::connect(config, None, None, None)
1317            .await
1318            .unwrap();
1319
1320        // Wait ~3 seconds to collect some heartbeats
1321        sleep(Duration::from_secs(3)).await;
1322
1323        {
1324            let lock = received.lock().await;
1325            let pings = lock
1326                .iter()
1327                .filter(|line| line == &&b"ping".to_vec())
1328                .count();
1329            assert!(
1330                pings >= 2,
1331                "Expected at least 2 heartbeat pings; got {pings}"
1332            );
1333        }
1334
1335        client.close().await;
1336        server_task.abort();
1337    }
1338
1339    #[tokio::test]
1340    async fn test_reconnect_success() {
1341        Python::initialize();
1342
1343        let (port, listener) = bind_test_server().await;
1344
1345        // Spawn a server task that:
1346        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1347        // 2. Waits a bit and then accepts a new connection and runs the echo server
1348        let server_task = task::spawn(async move {
1349            // Accept first connection
1350            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1351
1352            // Wait briefly and then force-close the connection
1353            sleep(Duration::from_millis(500)).await;
1354            let _ = socket.shutdown().await;
1355
1356            // Wait for the client's reconnect attempt
1357            sleep(Duration::from_millis(500)).await;
1358
1359            // Run the echo server on the new connection
1360            let (socket, _) = listener.accept().await.expect("Second accept failed");
1361            run_echo_server(socket).await;
1362        });
1363
1364        let config = SocketConfig {
1365            url: format!("127.0.0.1:{port}"),
1366            mode: Mode::Plain,
1367            suffix: b"\r\n".to_vec(),
1368            message_handler: None,
1369            heartbeat: None,
1370            reconnect_timeout_ms: Some(5_000),
1371            reconnect_delay_initial_ms: Some(500),
1372            reconnect_delay_max_ms: Some(5_000),
1373            reconnect_backoff_factor: Some(2.0),
1374            reconnect_jitter_ms: Some(50),
1375            reconnect_max_attempts: None,
1376            connection_max_retries: None,
1377            certs_dir: None,
1378        };
1379
1380        let client = SocketClient::connect(config, None, None, None)
1381            .await
1382            .expect("Client connect failed unexpectedly");
1383
1384        // Initially, the client should be active
1385        assert!(client.is_active(), "Client should start as active");
1386
1387        // Wait until the client loses connection (i.e. not active),
1388        // then wait until it reconnects (active again).
1389        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1390
1391        client
1392            .send_bytes(b"TestReconnect".into())
1393            .await
1394            .expect("Send failed");
1395
1396        client.close().await;
1397        server_task.abort();
1398    }
1399}
1400
1401#[cfg(test)]
1402#[cfg(not(feature = "turmoil"))]
1403mod rust_tests {
1404    use nautilus_common::testing::wait_until_async;
1405    use rstest::rstest;
1406    use tokio::{
1407        io::{AsyncReadExt, AsyncWriteExt},
1408        net::TcpListener,
1409        task,
1410        time::{Duration, sleep},
1411    };
1412
1413    use super::*;
1414
1415    #[rstest]
1416    #[tokio::test]
1417    async fn test_reconnect_then_close() {
1418        // Bind an ephemeral port
1419        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1420        let port = listener.local_addr().unwrap().port();
1421
1422        // Server task: accept one connection and then drop it
1423        let server = task::spawn(async move {
1424            if let Ok((mut sock, _)) = listener.accept().await {
1425                drop(sock.shutdown());
1426            }
1427            // Keep listener alive briefly to avoid premature exit
1428            sleep(Duration::from_secs(1)).await;
1429        });
1430
1431        // Configure client with a short reconnect backoff
1432        let config = SocketConfig {
1433            url: format!("127.0.0.1:{port}"),
1434            mode: Mode::Plain,
1435            suffix: b"\r\n".to_vec(),
1436            message_handler: None,
1437            heartbeat: None,
1438            reconnect_timeout_ms: Some(1_000),
1439            reconnect_delay_initial_ms: Some(50),
1440            reconnect_delay_max_ms: Some(100),
1441            reconnect_backoff_factor: Some(1.0),
1442            reconnect_jitter_ms: Some(0),
1443            connection_max_retries: Some(1),
1444            reconnect_max_attempts: None,
1445            certs_dir: None,
1446        };
1447
1448        // Connect client (handler=None)
1449        let client = SocketClient::connect(config.clone(), None, None, None)
1450            .await
1451            .unwrap();
1452
1453        // Wait for client to detect dropped connection and enter reconnect state
1454        wait_until_async(
1455            || async { client.is_reconnecting() },
1456            Duration::from_secs(2),
1457        )
1458        .await;
1459
1460        // Now close the client
1461        client.close().await;
1462        assert!(client.is_closed());
1463        server.abort();
1464    }
1465
1466    #[rstest]
1467    #[tokio::test]
1468    async fn test_reconnect_state_flips_when_reader_stops() {
1469        // Bind an ephemeral port and accept a single connection which we immediately close.
1470        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1471        let port = listener.local_addr().unwrap().port();
1472
1473        let server = task::spawn(async move {
1474            if let Ok((sock, _)) = listener.accept().await {
1475                drop(sock);
1476            }
1477            // Give the client a moment to observe the closed connection.
1478            sleep(Duration::from_millis(50)).await;
1479        });
1480
1481        let config = SocketConfig {
1482            url: format!("127.0.0.1:{port}"),
1483            mode: Mode::Plain,
1484            suffix: b"\r\n".to_vec(),
1485            message_handler: None,
1486            heartbeat: None,
1487            reconnect_timeout_ms: Some(1_000),
1488            reconnect_delay_initial_ms: Some(50),
1489            reconnect_delay_max_ms: Some(100),
1490            reconnect_backoff_factor: Some(1.0),
1491            reconnect_jitter_ms: Some(0),
1492            connection_max_retries: Some(1),
1493            reconnect_max_attempts: None,
1494            certs_dir: None,
1495        };
1496
1497        let client = SocketClient::connect(config, None, None, None)
1498            .await
1499            .unwrap();
1500
1501        wait_until_async(
1502            || async { client.is_reconnecting() },
1503            Duration::from_secs(2),
1504        )
1505        .await;
1506
1507        client.close().await;
1508        server.abort();
1509    }
1510
1511    #[rstest]
1512    fn test_parse_socket_url_raw_address() {
1513        // Raw socket address with TLS mode
1514        let (socket_addr, request_url) =
1515            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1516        assert_eq!(socket_addr, "example.com:6130");
1517        assert_eq!(request_url, "wss://example.com:6130");
1518
1519        // Raw socket address with Plain mode
1520        let (socket_addr, request_url) =
1521            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1522        assert_eq!(socket_addr, "localhost:8080");
1523        assert_eq!(request_url, "ws://localhost:8080");
1524    }
1525
1526    #[rstest]
1527    fn test_parse_socket_url_with_scheme() {
1528        // Full URL with wss scheme
1529        let (socket_addr, request_url) =
1530            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1531        assert_eq!(socket_addr, "example.com:443");
1532        assert_eq!(request_url, "wss://example.com:443/path");
1533
1534        // Full URL with ws scheme
1535        let (socket_addr, request_url) =
1536            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1537        assert_eq!(socket_addr, "localhost:8080");
1538        assert_eq!(request_url, "ws://localhost:8080");
1539    }
1540
1541    #[rstest]
1542    fn test_parse_socket_url_default_ports() {
1543        // wss without explicit port defaults to 443
1544        let (socket_addr, _) =
1545            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1546        assert_eq!(socket_addr, "example.com:443");
1547
1548        // ws without explicit port defaults to 80
1549        let (socket_addr, _) =
1550            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1551        assert_eq!(socket_addr, "example.com:80");
1552
1553        // https defaults to 443
1554        let (socket_addr, _) =
1555            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1556        assert_eq!(socket_addr, "example.com:443");
1557
1558        // http defaults to 80
1559        let (socket_addr, _) =
1560            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1561        assert_eq!(socket_addr, "example.com:80");
1562    }
1563
1564    #[rstest]
1565    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1566        // Unknown scheme defaults to mode-based port
1567        let (socket_addr, _) =
1568            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1569        assert_eq!(socket_addr, "example.com:443");
1570
1571        let (socket_addr, _) =
1572            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1573        assert_eq!(socket_addr, "example.com:80");
1574    }
1575
1576    #[rstest]
1577    fn test_parse_socket_url_ipv6() {
1578        // IPv6 address with port
1579        let (socket_addr, request_url) =
1580            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1581        assert_eq!(socket_addr, "[::1]:8080");
1582        assert_eq!(request_url, "ws://[::1]:8080");
1583
1584        // IPv6 in URL
1585        let (socket_addr, _) =
1586            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1587        assert_eq!(socket_addr, "[::1]:8080");
1588    }
1589
1590    #[rstest]
1591    #[tokio::test]
1592    async fn test_url_parsing_raw_socket_address() {
1593        // Test that raw socket addresses (host:port) work correctly
1594        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1595        let port = listener.local_addr().unwrap().port();
1596
1597        let server = task::spawn(async move {
1598            if let Ok((sock, _)) = listener.accept().await {
1599                drop(sock);
1600            }
1601            sleep(Duration::from_millis(50)).await;
1602        });
1603
1604        let config = SocketConfig {
1605            url: format!("127.0.0.1:{port}"), // Raw socket address format
1606            mode: Mode::Plain,
1607            suffix: b"\r\n".to_vec(),
1608            message_handler: None,
1609            heartbeat: None,
1610            reconnect_timeout_ms: Some(1_000),
1611            reconnect_delay_initial_ms: Some(50),
1612            reconnect_delay_max_ms: Some(100),
1613            reconnect_backoff_factor: Some(1.0),
1614            reconnect_jitter_ms: Some(0),
1615            connection_max_retries: Some(1),
1616            reconnect_max_attempts: None,
1617            certs_dir: None,
1618        };
1619
1620        // Should successfully connect with raw socket address
1621        let client = SocketClient::connect(config, None, None, None).await;
1622        assert!(
1623            client.is_ok(),
1624            "Client should connect with raw socket address format"
1625        );
1626
1627        if let Ok(client) = client {
1628            client.close().await;
1629        }
1630        server.abort();
1631    }
1632
1633    #[rstest]
1634    #[tokio::test]
1635    async fn test_url_parsing_with_scheme() {
1636        // Test that URLs with schemes also work
1637        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1638        let port = listener.local_addr().unwrap().port();
1639
1640        let server = task::spawn(async move {
1641            if let Ok((sock, _)) = listener.accept().await {
1642                drop(sock);
1643            }
1644            sleep(Duration::from_millis(50)).await;
1645        });
1646
1647        let config = SocketConfig {
1648            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1649            mode: Mode::Plain,
1650            suffix: b"\r\n".to_vec(),
1651            message_handler: None,
1652            heartbeat: None,
1653            reconnect_timeout_ms: Some(1_000),
1654            reconnect_delay_initial_ms: Some(50),
1655            reconnect_delay_max_ms: Some(100),
1656            reconnect_backoff_factor: Some(1.0),
1657            reconnect_jitter_ms: Some(0),
1658            connection_max_retries: Some(1),
1659            reconnect_max_attempts: None,
1660            certs_dir: None,
1661        };
1662
1663        // Should successfully connect with URL format
1664        let client = SocketClient::connect(config, None, None, None).await;
1665        assert!(
1666            client.is_ok(),
1667            "Client should connect with URL scheme format"
1668        );
1669
1670        if let Ok(client) = client {
1671            client.close().await;
1672        }
1673        server.abort();
1674    }
1675
1676    #[rstest]
1677    fn test_parse_socket_url_ipv6_with_zone() {
1678        // IPv6 with zone ID (link-local address)
1679        let (socket_addr, request_url) =
1680            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1681        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1682        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1683
1684        // Verify zone is preserved in URL format too
1685        let (socket_addr, request_url) =
1686            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1687        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1688        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1689    }
1690
1691    #[rstest]
1692    #[tokio::test]
1693    async fn test_ipv6_loopback_connection() {
1694        // Test IPv6 loopback address connection
1695        // Skip if IPv6 is not available on the system
1696        if TcpListener::bind("[::1]:0").await.is_err() {
1697            eprintln!("IPv6 not available, skipping test");
1698            return;
1699        }
1700
1701        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1702        let port = listener.local_addr().unwrap().port();
1703
1704        let server = task::spawn(async move {
1705            if let Ok((mut sock, _)) = listener.accept().await {
1706                let mut buf = vec![0u8; 1024];
1707                if let Ok(n) = sock.read(&mut buf).await {
1708                    // Echo back
1709                    let _ = sock.write_all(&buf[..n]).await;
1710                }
1711            }
1712            sleep(Duration::from_millis(50)).await;
1713        });
1714
1715        let config = SocketConfig {
1716            url: format!("[::1]:{port}"), // IPv6 loopback
1717            mode: Mode::Plain,
1718            suffix: b"\r\n".to_vec(),
1719            message_handler: None,
1720            heartbeat: None,
1721            reconnect_timeout_ms: Some(1_000),
1722            reconnect_delay_initial_ms: Some(50),
1723            reconnect_delay_max_ms: Some(100),
1724            reconnect_backoff_factor: Some(1.0),
1725            reconnect_jitter_ms: Some(0),
1726            connection_max_retries: Some(1),
1727            reconnect_max_attempts: None,
1728            certs_dir: None,
1729        };
1730
1731        let client = SocketClient::connect(config, None, None, None).await;
1732        assert!(
1733            client.is_ok(),
1734            "Client should connect to IPv6 loopback address"
1735        );
1736
1737        if let Ok(client) = client {
1738            client.close().await;
1739        }
1740        server.abort();
1741    }
1742
1743    #[rstest]
1744    #[tokio::test]
1745    async fn test_send_waits_during_reconnection() {
1746        // Test that send operations wait for reconnection to complete (up to configured timeout)
1747        use nautilus_common::testing::wait_until_async;
1748
1749        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1750        let port = listener.local_addr().unwrap().port();
1751
1752        let server = task::spawn(async move {
1753            // First connection - accept and immediately close
1754            if let Ok((sock, _)) = listener.accept().await {
1755                drop(sock);
1756            }
1757
1758            // Wait before accepting second connection
1759            sleep(Duration::from_millis(500)).await;
1760
1761            // Second connection - accept and keep alive
1762            if let Ok((mut sock, _)) = listener.accept().await {
1763                // Echo messages
1764                let mut buf = vec![0u8; 1024];
1765                while let Ok(n) = sock.read(&mut buf).await {
1766                    if n == 0 {
1767                        break;
1768                    }
1769                    if sock.write_all(&buf[..n]).await.is_err() {
1770                        break;
1771                    }
1772                }
1773            }
1774        });
1775
1776        let config = SocketConfig {
1777            url: format!("127.0.0.1:{port}"),
1778            mode: Mode::Plain,
1779            suffix: b"\r\n".to_vec(),
1780            message_handler: None,
1781            heartbeat: None,
1782            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1783            reconnect_delay_initial_ms: Some(100),
1784            reconnect_delay_max_ms: Some(200),
1785            reconnect_backoff_factor: Some(1.0),
1786            reconnect_jitter_ms: Some(0),
1787            connection_max_retries: Some(1),
1788            reconnect_max_attempts: None,
1789            certs_dir: None,
1790        };
1791
1792        let client = SocketClient::connect(config, None, None, None)
1793            .await
1794            .unwrap();
1795
1796        // Wait for reconnection to trigger
1797        wait_until_async(
1798            || async { client.is_reconnecting() },
1799            Duration::from_secs(2),
1800        )
1801        .await;
1802
1803        // Try to send while reconnecting - should wait and succeed after reconnect
1804        let send_result = tokio::time::timeout(
1805            Duration::from_secs(3),
1806            client.send_bytes(b"test_message".to_vec()),
1807        )
1808        .await;
1809
1810        assert!(
1811            send_result.is_ok() && send_result.unwrap().is_ok(),
1812            "Send should succeed after waiting for reconnection"
1813        );
1814
1815        client.close().await;
1816        server.abort();
1817    }
1818
1819    #[rstest]
1820    #[tokio::test]
1821    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1822        // Test that send_bytes operations respect the configured reconnect_timeout.
1823        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1824        use nautilus_common::testing::wait_until_async;
1825
1826        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1827        let port = listener.local_addr().unwrap().port();
1828
1829        let server = task::spawn(async move {
1830            // Accept first connection and immediately close it
1831            if let Ok((sock, _)) = listener.accept().await {
1832                drop(sock);
1833            }
1834            // Drop listener entirely so reconnection fails completely
1835            drop(listener);
1836            sleep(Duration::from_secs(60)).await;
1837        });
1838
1839        let config = SocketConfig {
1840            url: format!("127.0.0.1:{port}"),
1841            mode: Mode::Plain,
1842            suffix: b"\r\n".to_vec(),
1843            message_handler: None,
1844            heartbeat: None,
1845            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1846            reconnect_delay_initial_ms: Some(200), // Short backoff (but > timeout) to keep client in RECONNECT
1847            reconnect_delay_max_ms: Some(200),
1848            reconnect_backoff_factor: Some(1.0),
1849            reconnect_jitter_ms: Some(0),
1850            connection_max_retries: Some(1),
1851            reconnect_max_attempts: None,
1852            certs_dir: None,
1853        };
1854
1855        let client = SocketClient::connect(config, None, None, None)
1856            .await
1857            .unwrap();
1858
1859        // Wait for client to enter RECONNECT state
1860        wait_until_async(
1861            || async { client.is_reconnecting() },
1862            Duration::from_secs(3),
1863        )
1864        .await;
1865
1866        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1867        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1868        let start = std::time::Instant::now();
1869        let send_result = client.send_bytes(b"test".to_vec()).await;
1870        let elapsed = start.elapsed();
1871
1872        assert!(
1873            send_result.is_err(),
1874            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
1875        );
1876        assert!(
1877            matches!(send_result, Err(crate::error::SendError::Timeout)),
1878            "Send should return Timeout error, was: {send_result:?}"
1879        );
1880        // Verify timeout respects configured value (1s), but don't check upper bound
1881        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1882        assert!(
1883            elapsed >= Duration::from_millis(900),
1884            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
1885        );
1886
1887        client.close().await;
1888        server.abort();
1889    }
1890}