nautilus_network/socket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18//!
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    collections::VecDeque,
33    fmt::Debug,
34    path::Path,
35    sync::{
36        Arc,
37        atomic::{AtomicU8, Ordering},
38    },
39    time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49    SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52    backoff::ExponentialBackoff,
53    error::SendError,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
60// Connection timing constants
61const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
62const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
63const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
64const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
65
66// Maximum buffer size for read operations (10 MB)
67const MAX_READ_BUFFER_BYTES: usize = 10 * 1024 * 1024;
68
/// Internal state of a socket client: one TCP (optionally TLS) connection plus the
/// background tasks that service it.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; read again on reconnect for the
    /// URL, mode and suffix.
    config: SocketConfig,
    /// TLS connector built from `certs_dir`, or `None` when no directory was configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced with a new task on reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the command channel consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode`, stored as `u8` and shared with the background tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff built from the `reconnect_*` settings.
    // NOTE(review): not consumed in this part of the file — presumably paces
    // reconnect attempts in the controller; confirm.
    backoff: ExponentialBackoff,
    /// Copy of the configured message handler, handed to each new read task.
    handler: Option<TcpMessageHandler>,
    /// Copied from the configuration.
    // NOTE(review): presumably the cap on reconnect attempts (`None` = unlimited) —
    // confirm against the controller.
    reconnect_max_attempts: Option<u32>,
    /// Starts at zero.
    // NOTE(review): presumably counts reconnect attempts; the update site is not
    // visible here.
    reconnect_attempt_count: u32,
}
102
103impl SocketClientInner {
104    /// Connect to a URL with the specified configuration.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if connection fails or configuration is invalid.
109    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110        const CONNECTION_TIMEOUT_SECS: u64 = 10;
111
112        install_cryptographic_provider();
113
114        // Validate suffix is non-empty to prevent panic in read loop (windows(0) panics)
115        if config.suffix.is_empty() {
116            anyhow::bail!("Socket suffix cannot be empty: suffix is required for message framing");
117        }
118
119        let SocketConfig {
120            url,
121            mode,
122            heartbeat,
123            suffix,
124            message_handler,
125            reconnect_timeout_ms,
126            reconnect_delay_initial_ms,
127            reconnect_delay_max_ms,
128            reconnect_backoff_factor,
129            reconnect_jitter_ms,
130            connection_max_retries,
131            reconnect_max_attempts,
132            certs_dir,
133        } = &config.clone();
134        let connector = if let Some(dir) = certs_dir {
135            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
136            Some(Connector::Rustls(Arc::new(config)))
137        } else {
138            None
139        };
140
141        // Retry initial connection with exponential backoff to handle transient DNS/network issues
142        let max_retries = connection_max_retries.unwrap_or(5);
143
144        let mut backoff = ExponentialBackoff::new(
145            Duration::from_millis(500),
146            Duration::from_millis(5000),
147            2.0,
148            250,
149            false,
150        )?;
151
152        #[allow(unused_assignments)]
153        let mut last_error = String::new();
154        let mut attempt = 0;
155        let (reader, writer) = loop {
156            attempt += 1;
157
158            match tokio::time::timeout(
159                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
160                Self::tls_connect_with_server(url, *mode, connector.clone()),
161            )
162            .await
163            {
164                Ok(Ok(result)) => {
165                    if attempt > 1 {
166                        log::info!("Socket connection established after {attempt} attempts");
167                    }
168                    break result;
169                }
170                Ok(Err(e)) => {
171                    last_error = e.to_string();
172                    log::warn!(
173                        "Socket connection attempt {attempt}/{max_retries} to {url} failed: {last_error}"
174                    );
175                }
176                Err(_) => {
177                    last_error = format!(
178                        "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
179                    );
180                    log::warn!(
181                        "Socket connection attempt {attempt}/{max_retries} to {url} timed out"
182                    );
183                }
184            }
185
186            if attempt >= max_retries {
187                anyhow::bail!(
188                    "Failed to connect to {} after {} attempts: {}. \
189                    If this is a DNS error, check your network configuration and DNS settings.",
190                    url,
191                    max_retries,
192                    if last_error.is_empty() {
193                        "unknown error"
194                    } else {
195                        &last_error
196                    }
197                );
198            }
199
200            let delay = backoff.next_duration();
201            log::debug!(
202                "Retrying in {delay:?} (attempt {}/{})",
203                attempt + 1,
204                max_retries
205            );
206            tokio::time::sleep(delay).await;
207        };
208
209        log::debug!("Connected");
210
211        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
212
213        let read_task = Arc::new(Self::spawn_read_task(
214            connection_mode.clone(),
215            reader,
216            message_handler.clone(),
217            suffix.clone(),
218        ));
219
220        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
221
222        let write_task =
223            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
224
225        // Optionally spawn a heartbeat task to periodically ping server
226        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
227            Self::spawn_heartbeat_task(
228                connection_mode.clone(),
229                heartbeat.clone(),
230                writer_tx.clone(),
231            )
232        });
233
234        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
235        let backoff = ExponentialBackoff::new(
236            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
237            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
238            reconnect_backoff_factor.unwrap_or(1.5),
239            reconnect_jitter_ms.unwrap_or(100),
240            true, // immediate-first
241        )?;
242
243        Ok(Self {
244            config,
245            connector,
246            read_task,
247            write_task,
248            writer_tx,
249            heartbeat_task,
250            connection_mode,
251            reconnect_timeout,
252            backoff,
253            handler: message_handler.clone(),
254            reconnect_max_attempts: *reconnect_max_attempts,
255            reconnect_attempt_count: 0,
256        })
257    }
258
259    /// Parse URL and extract socket address and request URL.
260    ///
261    /// Accepts either:
262    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
263    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if the URL is invalid or missing required components.
268    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
269        if url.contains("://") {
270            // URL with scheme (e.g., "wss://host:port/path")
271            let parsed = url.parse::<http::Uri>().map_err(|e| {
272                Error::Io(std::io::Error::new(
273                    std::io::ErrorKind::InvalidInput,
274                    format!("Invalid URL: {e}"),
275                ))
276            })?;
277
278            let host = parsed.host().ok_or_else(|| {
279                Error::Io(std::io::Error::new(
280                    std::io::ErrorKind::InvalidInput,
281                    "URL missing host",
282                ))
283            })?;
284
285            let port = parsed
286                .port_u16()
287                .unwrap_or_else(|| match parsed.scheme_str() {
288                    Some("wss" | "https") => 443,
289                    Some("ws" | "http") => 80,
290                    _ => match mode {
291                        Mode::Tls => 443,
292                        Mode::Plain => 80,
293                    },
294                });
295
296            Ok((format!("{host}:{port}"), url.to_string()))
297        } else {
298            // Raw socket address (e.g., "host:port")
299            // Construct a proper URL for the request based on mode
300            let scheme = match mode {
301                Mode::Tls => "wss",
302                Mode::Plain => "ws",
303            };
304            Ok((url.to_string(), format!("{scheme}://{url}")))
305        }
306    }
307
308    /// Establish a TLS or plain TCP connection with the server.
309    ///
310    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
311    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
312    /// For WebSocket-style connections, include the scheme.
313    ///
314    /// # Errors
315    ///
316    /// Returns an error if the connection cannot be established.
317    pub async fn tls_connect_with_server(
318        url: &str,
319        mode: Mode,
320        connector: Option<Connector>,
321    ) -> Result<(TcpReader, TcpWriter), Error> {
322        log::debug!("Connecting to {url}");
323
324        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
325        let tcp_result = TcpStream::connect(&socket_addr).await;
326
327        match tcp_result {
328            Ok(stream) => {
329                log::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
330                if let Err(e) = stream.set_nodelay(true) {
331                    log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
332                }
333                let request = request_url.into_client_request()?;
334                tcp_tls(&request, mode, stream, connector)
335                    .await
336                    .map(tokio::io::split)
337            }
338            Err(e) => {
339                log::error!("TCP connection failed to {socket_addr}: {e:?}");
340                Err(Error::Io(e))
341            }
342        }
343    }
344
    /// Reconnect with server.
    ///
    /// Makes a new connection with server, uses the new read and write halves
    /// to update the reader and writer.
    ///
    /// The whole sequence is bounded by `self.reconnect_timeout`:
    /// 1. Open a fresh connection using the stored URL, mode and connector.
    /// 2. Hand the new write half to the writer task via `WriterCommand::Update` and
    ///    wait for it to report whether its reconnect buffer was drained.
    /// 3. After a short delay, abort the old read task if it is still running.
    /// 4. Switch the connection mode from `Reconnect` to `Active` and spawn a new
    ///    read task on the new read half.
    ///
    /// The connection mode is re-checked between the steps. If a disconnect was
    /// requested, or the mode is no longer `Reconnect` at step 4, the attempt is
    /// abandoned and `Ok(())` is returned without spawning a new read task.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established, the writer task cannot
    /// be reached or fails to drain its buffer, or the timeout elapses.
    async fn reconnect(&mut self) -> Result<(), Error> {
        log::debug!("Reconnecting");

        // Nothing to do when a disconnect is already in progress
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Only the URL, mode and suffix are needed here; the remaining fields are
            // listed explicitly so the pattern stays exhaustive.
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                connection_max_retries: _,
                reconnect_max_attempts: _,
                certs_dir: _,
            } = &self.config;
            // Create a fresh connection
            let connector = self.connector.clone();
            // Attempt to connect; the disconnect check follows once the connection is up
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            log::debug!("Connected");

            // Use a oneshot channel to synchronize with the writer task.
            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
            // to prevent silent message loss if the new connection drops immediately.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it has drained the buffer
            match rx.await {
                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    log::warn!("Writer failed to drain buffer, aborting reconnect");
                    // Return error to trigger retry logic in controller
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    log::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Delay before closing connection
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The old read task still owns the previous read half; stop it before the
            // replacement is spawned
            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn new read task
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
            ));

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // The outer `Result` carries the timeout; `?` unwraps it so that the inner
        // result of the reconnect sequence becomes the return value
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
466
467    /// Check if the client is still alive.
468    ///
469    /// The client is connected if the read task has not finished. It is expected
470    /// that in case of any failure client or server side. The read task will be
471    /// shutdown. There might be some delay between the connection being closed
472    /// and the client detecting it.
473    #[inline]
474    #[must_use]
475    pub fn is_alive(&self) -> bool {
476        !self.read_task.is_finished()
477    }
478
479    #[must_use]
480    fn spawn_read_task(
481        connection_state: Arc<AtomicU8>,
482        mut reader: TcpReader,
483        handler: Option<TcpMessageHandler>,
484        suffix: Vec<u8>,
485    ) -> tokio::task::JoinHandle<()> {
486        log_task_started("read");
487
488        // Interval between checking the connection mode
489        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
490
491        tokio::task::spawn(async move {
492            let mut buf = Vec::new();
493
494            loop {
495                if !ConnectionMode::from_atomic(&connection_state).is_active() {
496                    break;
497                }
498
499                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
500                    // Connection has been terminated or vector buffer is complete
501                    Ok(Ok(0)) => {
502                        log::debug!("Connection closed by server");
503                        break;
504                    }
505                    Ok(Err(e)) => {
506                        log::debug!("Connection ended: {e}");
507                        break;
508                    }
509                    // Received bytes of data
510                    Ok(Ok(bytes)) => {
511                        log::trace!("Received <binary> {bytes} bytes");
512
513                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
514                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
515
516                        if is_fix && handler.is_some() {
517                            // FIX protocol processing
518                            if let Some(ref handler) = handler {
519                                process_fix_buffer(&mut buf, handler);
520                            }
521                        } else {
522                            // Regular suffix-based message processing
523                            while let Some((i, _)) = &buf
524                                .windows(suffix.len())
525                                .enumerate()
526                                .find(|(_, pair)| pair.eq(&suffix))
527                            {
528                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
529                                data.truncate(data.len() - suffix.len());
530
531                                if let Some(ref handler) = handler {
532                                    handler(&data);
533                                }
534                            }
535                        }
536
537                        if buf.len() > MAX_READ_BUFFER_BYTES {
538                            log::error!(
539                                "Read buffer exceeded maximum size ({MAX_READ_BUFFER_BYTES} bytes), closing connection"
540                            );
541                            break;
542                        }
543                    }
544                    Err(_) => {
545                        // Timeout - continue loop and check connection mode
546                        continue;
547                    }
548                }
549            }
550
551            log_task_stopped("read");
552        })
553    }
554
555    /// Drains buffered messages after reconnection completes.
556    ///
557    /// Attempts to send all buffered messages that were queued during reconnection.
558    /// Uses a peek-and-pop pattern to preserve messages if sending fails midway through the buffer.
559    ///
560    /// # Returns
561    ///
562    /// Returns `true` if a send error occurred (buffer may still contain unsent messages),
563    /// `false` if all messages were sent successfully (buffer is empty).
564    async fn drain_reconnect_buffer(
565        buffer: &mut VecDeque<Bytes>,
566        writer: &mut TcpWriter,
567        suffix: &[u8],
568    ) -> bool {
569        if buffer.is_empty() {
570            return false;
571        }
572
573        let initial_buffer_len = buffer.len();
574        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
575
576        let mut send_error_occurred = false;
577
578        while let Some(buffered_msg) = buffer.front() {
579            let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
580            combined_msg.extend_from_slice(buffered_msg);
581            combined_msg.extend_from_slice(suffix);
582
583            if let Err(e) = writer.write_all(&combined_msg).await {
584                log::error!(
585                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
586                    buffer.len()
587                );
588                send_error_occurred = true;
589                break;
590            }
591
592            buffer.pop_front();
593        }
594
595        if buffer.is_empty() {
596            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
597        }
598
599        send_error_occurred
600    }
601
602    fn spawn_write_task(
603        connection_state: Arc<AtomicU8>,
604        writer: TcpWriter,
605        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
606        suffix: Vec<u8>,
607    ) -> tokio::task::JoinHandle<()> {
608        log_task_started("write");
609
610        // Interval between checking the connection mode
611        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
612
613        tokio::task::spawn(async move {
614            let mut active_writer = writer;
615            let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
616
617            loop {
618                if matches!(
619                    ConnectionMode::from_atomic(&connection_state),
620                    ConnectionMode::Disconnect | ConnectionMode::Closed
621                ) {
622                    break;
623                }
624
625                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
626                    Ok(Some(msg)) => {
627                        // Re-check connection mode after receiving a message
628                        let mode = ConnectionMode::from_atomic(&connection_state);
629                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
630                            break;
631                        }
632
633                        match msg {
634                            WriterCommand::Update(new_writer, tx) => {
635                                log::debug!("Received new writer");
636
637                                // Delay before closing connection
638                                tokio::time::sleep(Duration::from_millis(100)).await;
639
640                                // Attempt to shutdown the writer gracefully before updating,
641                                // we ignore any error as the writer may already be closed.
642                                _ = tokio::time::timeout(
643                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
644                                    active_writer.shutdown(),
645                                )
646                                .await;
647
648                                active_writer = new_writer;
649                                log::debug!("Updated writer");
650
651                                let send_error = Self::drain_reconnect_buffer(
652                                    &mut reconnect_buffer,
653                                    &mut active_writer,
654                                    &suffix,
655                                )
656                                .await;
657
658                                if let Err(e) = tx.send(!send_error) {
659                                    log::error!(
660                                        "Failed to report drain status to controller: {e:?}"
661                                    );
662                                }
663                            }
664                            _ if mode.is_reconnect() => {
665                                if let WriterCommand::Send(data) = msg {
666                                    log::debug!(
667                                        "Buffering message while reconnecting ({} bytes)",
668                                        data.len()
669                                    );
670                                    reconnect_buffer.push_back(data);
671                                }
672                                continue;
673                            }
674                            WriterCommand::Send(msg) => {
675                                if let Err(e) = active_writer.write_all(&msg).await {
676                                    log::error!("Failed to send message: {e}");
677                                    log::warn!("Writer triggering reconnect");
678                                    reconnect_buffer.push_back(msg);
679                                    connection_state
680                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
681                                    continue;
682                                }
683                                if let Err(e) = active_writer.write_all(&suffix).await {
684                                    log::error!("Failed to send suffix: {e}");
685                                    log::warn!("Writer triggering reconnect");
686                                    // Buffer this message before triggering reconnect since suffix failed
687                                    reconnect_buffer.push_back(msg);
688                                    connection_state
689                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
690                                    continue;
691                                }
692                            }
693                        }
694                    }
695                    Ok(None) => {
696                        // Channel closed - writer task should terminate
697                        log::debug!("Writer channel closed, terminating writer task");
698                        break;
699                    }
700                    Err(_) => {
701                        // Timeout - just continue the loop
702                        continue;
703                    }
704                }
705            }
706
707            // Attempt to shutdown the writer gracefully before exiting,
708            // we ignore any error as the writer may already be closed.
709            _ = tokio::time::timeout(
710                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
711                active_writer.shutdown(),
712            )
713            .await;
714
715            log_task_stopped("write");
716        })
717    }
718
719    fn spawn_heartbeat_task(
720        connection_state: Arc<AtomicU8>,
721        heartbeat: (u64, Vec<u8>),
722        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
723    ) -> tokio::task::JoinHandle<()> {
724        log_task_started("heartbeat");
725        let (interval_secs, message) = heartbeat;
726
727        tokio::task::spawn(async move {
728            let interval = Duration::from_secs(interval_secs);
729
730            loop {
731                tokio::time::sleep(interval).await;
732
733                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
734                    ConnectionMode::Active => {
735                        let msg = WriterCommand::Send(message.clone().into());
736
737                        match writer_tx.send(msg) {
738                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
739                            Err(e) => {
740                                log::error!("Failed to send heartbeat to writer task: {e}");
741                            }
742                        }
743                    }
744                    ConnectionMode::Reconnect => continue,
745                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
746                }
747            }
748
749            log_task_stopped("heartbeat");
750        })
751    }
752}
753
754impl Drop for SocketClientInner {
755    fn drop(&mut self) {
756        // Delegate to explicit cleanup handler
757        self.clean_drop();
758    }
759}
760
761/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
762impl CleanDrop for SocketClientInner {
763    fn clean_drop(&mut self) {
764        if !self.read_task.is_finished() {
765            self.read_task.abort();
766            log_task_aborted("read");
767        }
768
769        if !self.write_task.is_finished() {
770            self.write_task.abort();
771            log_task_aborted("write");
772        }
773
774        if let Some(ref handle) = self.heartbeat_task.take()
775            && !handle.is_finished()
776        {
777            handle.abort();
778            log_task_aborted("heartbeat");
779        }
780
781        #[cfg(feature = "python")]
782        {
783            // Remove stored handler to break ref cycle
784            self.config.message_handler = None;
785        }
786    }
787}
788
/// Handle to a socket connection that is supervised by a background controller task.
///
/// Created with [`SocketClient::connect`]. With the `python` feature enabled the type is
/// also exposed as a `pyo3` class.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task spawned in `connect`; aborted on `close` timeout and on drop.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode stored as its `u8` representation, shared with the controller task.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long `send_bytes` waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Channel used to queue `WriterCommand`s for the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
799
800impl Debug for SocketClient {
801    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
802        f.debug_struct(stringify!(SocketClient)).finish()
803    }
804}
805
806impl SocketClient {
807    /// Connect to the server.
808    ///
809    /// # Errors
810    ///
811    /// Returns any error connecting to the server.
812    pub async fn connect(
813        config: SocketConfig,
814        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
815        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
816        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
817    ) -> anyhow::Result<Self> {
818        let inner = SocketClientInner::connect_url(config).await?;
819        let writer_tx = inner.writer_tx.clone();
820        let connection_mode = inner.connection_mode.clone();
821        let reconnect_timeout = inner.reconnect_timeout;
822
823        let controller_task = Self::spawn_controller_task(
824            inner,
825            connection_mode.clone(),
826            post_reconnection,
827            post_disconnection,
828        );
829
830        if let Some(handler) = post_connection {
831            handler();
832            log::debug!("Called `post_connection` handler");
833        }
834
835        Ok(Self {
836            controller_task,
837            connection_mode,
838            reconnect_timeout,
839            writer_tx,
840        })
841    }
842
843    /// Returns the current connection mode.
844    #[must_use]
845    pub fn connection_mode(&self) -> ConnectionMode {
846        ConnectionMode::from_atomic(&self.connection_mode)
847    }
848
849    /// Check if the client connection is active.
850    ///
851    /// Returns `true` if the client is connected and has not been signalled to disconnect.
852    /// The client will automatically retry connection based on its configuration.
853    #[inline]
854    #[must_use]
855    pub fn is_active(&self) -> bool {
856        self.connection_mode().is_active()
857    }
858
859    /// Check if the client is reconnecting.
860    ///
861    /// Returns `true` if the client lost connection and is attempting to reestablish it.
862    /// The client will automatically retry connection based on its configuration.
863    #[inline]
864    #[must_use]
865    pub fn is_reconnecting(&self) -> bool {
866        self.connection_mode().is_reconnect()
867    }
868
869    /// Check if the client is disconnecting.
870    ///
871    /// Returns `true` if the client is in disconnect mode.
872    #[inline]
873    #[must_use]
874    pub fn is_disconnecting(&self) -> bool {
875        self.connection_mode().is_disconnect()
876    }
877
878    /// Check if the client is closed.
879    ///
880    /// Returns `true` if the client has been explicitly disconnected or reached
881    /// maximum reconnection attempts. In this state, the client cannot be reused
882    /// and a new client must be created for further connections.
883    #[inline]
884    #[must_use]
885    pub fn is_closed(&self) -> bool {
886        self.connection_mode().is_closed()
887    }
888
889    /// Close the client.
890    ///
891    /// Controller task will periodically check the disconnect mode
892    /// and shutdown the client if it is not alive.
893    pub async fn close(&self) {
894        self.connection_mode
895            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
896
897        if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
898            while !self.is_closed() {
899                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
900            }
901
902            if !self.controller_task.is_finished() {
903                self.controller_task.abort();
904                log_task_aborted("controller");
905            }
906        })
907        .await
908            == Ok(())
909        {
910            log_task_stopped("controller");
911        } else {
912            log::error!("Timeout waiting for controller task to finish");
913            if !self.controller_task.is_finished() {
914                self.controller_task.abort();
915                log_task_aborted("controller");
916            }
917        }
918    }
919
920    /// Sends a message of the given `data`.
921    ///
922    /// # Errors
923    ///
924    /// Returns an error if sending fails.
925    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
926        // Check connection state to fail fast
927        if self.is_closed() || self.is_disconnecting() {
928            return Err(SendError::Closed);
929        }
930
931        let timeout = self.reconnect_timeout;
932        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
933
934        if !self.is_active() {
935            log::debug!("Waiting for client to become ACTIVE before sending...");
936
937            let inner = tokio::time::timeout(timeout, async {
938                loop {
939                    if self.is_active() {
940                        return Ok(());
941                    }
942                    if matches!(
943                        self.connection_mode(),
944                        ConnectionMode::Disconnect | ConnectionMode::Closed
945                    ) {
946                        return Err(());
947                    }
948                    tokio::time::sleep(check_interval).await;
949                }
950            })
951            .await
952            .map_err(|_| SendError::Timeout)?;
953            inner.map_err(|()| SendError::Closed)?;
954        }
955
956        let msg = WriterCommand::Send(data.into());
957        self.writer_tx
958            .send(msg)
959            .map_err(|e| SendError::BrokenPipe(e.to_string()))
960    }
961
    /// Spawns the controller task that owns `inner` and drives the connection lifecycle.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task reads `connection_mode` and:
    /// - on DISCONNECT: waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat
    ///   tasks if they are still running, calls `post_disconnection` and stops;
    /// - on CLOSED: stops;
    /// - on ACTIVE with `inner` no longer alive: switches the mode to RECONNECT;
    /// - on RECONNECT: stops with CLOSED once `reconnect_max_attempts` is reached, otherwise
    ///   attempts a reconnect, calling `post_reconnection` on success (only if the mode is
    ///   active afterwards) or sleeping for the backoff duration on failure.
    ///
    /// When the loop ends the mode is always stored as CLOSED.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                // Snapshot of the mode for this iteration; refreshed below after a state change
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");

                    if let Some(ref handler) = post_disconnection {
                        handler();
                        log::debug!("Called `post_disconnection` handler");
                    }
                    break; // Controller finished
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Compare-and-swap so a mode stored concurrently (anything other than
                    // ACTIVE) is not overwritten by the transition to RECONNECT
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        log::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read so the reconnect branch below acts on the actual current mode
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    // Check max reconnection attempts before attempting reconnect
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnect_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    inner.reconnect_attempt_count += 1;
                    match inner.reconnect().await {
                        Ok(()) => {
                            log::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            inner.reconnect_attempt_count = 0; // Reset counter on success
                            // Only invoke reconnect handler if still active
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = post_reconnection {
                                    handler();
                                    log::debug!("Called `post_reconnection` handler");
                                }
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Failed attempt: wait out the backoff before the next loop iteration
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnect_attempt_count
                            );
                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Every exit path from the loop ends with the mode stored as CLOSED
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1083}
1084
1085// Abort controller task on drop to clean up background tasks
1086impl Drop for SocketClient {
1087    fn drop(&mut self) {
1088        if !self.controller_task.is_finished() {
1089            self.controller_task.abort();
1090            log_task_aborted("controller");
1091        }
1092    }
1093}
1094
1095#[cfg(test)]
1096#[cfg(feature = "python")]
1097#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1098mod tests {
1099    use nautilus_common::testing::wait_until_async;
1100    use pyo3::Python;
1101    use tokio::{
1102        io::{AsyncReadExt, AsyncWriteExt},
1103        net::{TcpListener, TcpStream},
1104        sync::Mutex,
1105        task,
1106        time::{Duration, sleep},
1107    };
1108
1109    use super::*;
1110
1111    async fn bind_test_server() -> (u16, TcpListener) {
1112        let listener = TcpListener::bind("127.0.0.1:0")
1113            .await
1114            .expect("Failed to bind ephemeral port");
1115        let port = listener.local_addr().unwrap().port();
1116        (port, listener)
1117    }
1118
1119    async fn run_echo_server(mut socket: TcpStream) {
1120        let mut buf = Vec::new();
1121        loop {
1122            match socket.read_buf(&mut buf).await {
1123                Ok(0) => {
1124                    break;
1125                }
1126                Ok(_n) => {
1127                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1128                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1129                        // Remove trailing \r\n
1130                        line.truncate(line.len() - 2);
1131
1132                        if line == b"close" {
1133                            let _ = socket.shutdown().await;
1134                            return;
1135                        }
1136
1137                        let mut echo_data = line;
1138                        echo_data.extend_from_slice(b"\r\n");
1139                        if socket.write_all(&echo_data).await.is_err() {
1140                            break;
1141                        }
1142                    }
1143                }
1144                Err(e) => {
1145                    eprintln!("Server read error: {e}");
1146                    break;
1147                }
1148            }
1149        }
1150    }
1151
1152    #[tokio::test]
1153    async fn test_basic_send_receive() {
1154        Python::initialize();
1155
1156        let (port, listener) = bind_test_server().await;
1157        let server_task = task::spawn(async move {
1158            let (socket, _) = listener.accept().await.unwrap();
1159            run_echo_server(socket).await;
1160        });
1161
1162        let config = SocketConfig {
1163            url: format!("127.0.0.1:{port}"),
1164            mode: Mode::Plain,
1165            suffix: b"\r\n".to_vec(),
1166            message_handler: None,
1167            heartbeat: None,
1168            reconnect_timeout_ms: None,
1169            reconnect_delay_initial_ms: None,
1170            reconnect_backoff_factor: None,
1171            reconnect_delay_max_ms: None,
1172            reconnect_jitter_ms: None,
1173            reconnect_max_attempts: None,
1174            connection_max_retries: None,
1175            certs_dir: None,
1176        };
1177
1178        let client = SocketClient::connect(config, None, None, None)
1179            .await
1180            .expect("Client connect failed unexpectedly");
1181
1182        client.send_bytes(b"Hello".into()).await.unwrap();
1183        client.send_bytes(b"World".into()).await.unwrap();
1184
1185        // Wait a bit for the server to echo them back
1186        sleep(Duration::from_millis(100)).await;
1187
1188        client.send_bytes(b"close".into()).await.unwrap();
1189        server_task.await.unwrap();
1190        assert!(!client.is_closed());
1191    }
1192
1193    #[tokio::test]
1194    async fn test_reconnect_fail_exhausted() {
1195        Python::initialize();
1196
1197        let (port, listener) = bind_test_server().await;
1198        drop(listener); // We drop it immediately -> no server is listening
1199
1200        // Wait until port is truly unavailable (OS has released it)
1201        wait_until_async(
1202            || async {
1203                TcpStream::connect(format!("127.0.0.1:{port}"))
1204                    .await
1205                    .is_err()
1206            },
1207            Duration::from_secs(2),
1208        )
1209        .await;
1210
1211        let config = SocketConfig {
1212            url: format!("127.0.0.1:{port}"),
1213            mode: Mode::Plain,
1214            suffix: b"\r\n".to_vec(),
1215            message_handler: None,
1216            heartbeat: None,
1217            reconnect_timeout_ms: Some(100),
1218            reconnect_delay_initial_ms: Some(50),
1219            reconnect_backoff_factor: Some(1.0),
1220            reconnect_delay_max_ms: Some(50),
1221            reconnect_jitter_ms: Some(0),
1222            connection_max_retries: Some(1),
1223            reconnect_max_attempts: None,
1224            certs_dir: None,
1225        };
1226
1227        let client_res = SocketClient::connect(config, None, None, None).await;
1228        assert!(
1229            client_res.is_err(),
1230            "Should fail quickly with no server listening"
1231        );
1232    }
1233
1234    #[tokio::test]
1235    async fn test_user_disconnect() {
1236        Python::initialize();
1237
1238        let (port, listener) = bind_test_server().await;
1239        let server_task = task::spawn(async move {
1240            let (socket, _) = listener.accept().await.unwrap();
1241            let mut buf = [0u8; 1024];
1242            let _ = socket.try_read(&mut buf);
1243
1244            loop {
1245                sleep(Duration::from_secs(1)).await;
1246            }
1247        });
1248
1249        let config = SocketConfig {
1250            url: format!("127.0.0.1:{port}"),
1251            mode: Mode::Plain,
1252            suffix: b"\r\n".to_vec(),
1253            message_handler: None,
1254            heartbeat: None,
1255            reconnect_timeout_ms: None,
1256            reconnect_delay_initial_ms: None,
1257            reconnect_backoff_factor: None,
1258            reconnect_delay_max_ms: None,
1259            reconnect_jitter_ms: None,
1260            reconnect_max_attempts: None,
1261            connection_max_retries: None,
1262            certs_dir: None,
1263        };
1264
1265        let client = SocketClient::connect(config, None, None, None)
1266            .await
1267            .unwrap();
1268
1269        client.close().await;
1270        assert!(client.is_closed());
1271        server_task.abort();
1272    }
1273
1274    #[tokio::test]
1275    async fn test_heartbeat() {
1276        Python::initialize();
1277
1278        let (port, listener) = bind_test_server().await;
1279        let received = Arc::new(Mutex::new(Vec::new()));
1280        let received2 = received.clone();
1281
1282        let server_task = task::spawn(async move {
1283            let (socket, _) = listener.accept().await.unwrap();
1284
1285            let mut buf = Vec::new();
1286            loop {
1287                match socket.try_read_buf(&mut buf) {
1288                    Ok(0) => break,
1289                    Ok(_) => {
1290                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1291                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1292                            line.truncate(line.len() - 2);
1293                            received2.lock().await.push(line);
1294                        }
1295                    }
1296                    Err(_) => {
1297                        tokio::time::sleep(Duration::from_millis(10)).await;
1298                    }
1299                }
1300            }
1301        });
1302
1303        // Heartbeat every 1 second
1304        let heartbeat = Some((1, b"ping".to_vec()));
1305
1306        let config = SocketConfig {
1307            url: format!("127.0.0.1:{port}"),
1308            mode: Mode::Plain,
1309            suffix: b"\r\n".to_vec(),
1310            message_handler: None,
1311            heartbeat,
1312            reconnect_timeout_ms: None,
1313            reconnect_delay_initial_ms: None,
1314            reconnect_backoff_factor: None,
1315            reconnect_delay_max_ms: None,
1316            reconnect_jitter_ms: None,
1317            reconnect_max_attempts: None,
1318            connection_max_retries: None,
1319            certs_dir: None,
1320        };
1321
1322        let client = SocketClient::connect(config, None, None, None)
1323            .await
1324            .unwrap();
1325
1326        // Wait ~3 seconds to collect some heartbeats
1327        sleep(Duration::from_secs(3)).await;
1328
1329        {
1330            let lock = received.lock().await;
1331            let pings = lock
1332                .iter()
1333                .filter(|line| line == &&b"ping".to_vec())
1334                .count();
1335            assert!(
1336                pings >= 2,
1337                "Expected at least 2 heartbeat pings; got {pings}"
1338            );
1339        }
1340
1341        client.close().await;
1342        server_task.abort();
1343    }
1344
1345    #[tokio::test]
1346    async fn test_reconnect_success() {
1347        Python::initialize();
1348
1349        let (port, listener) = bind_test_server().await;
1350
1351        // Spawn a server task that:
1352        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1353        // 2. Waits a bit and then accepts a new connection and runs the echo server
1354        let server_task = task::spawn(async move {
1355            // Accept first connection
1356            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1357
1358            // Wait briefly and then force-close the connection
1359            sleep(Duration::from_millis(500)).await;
1360            let _ = socket.shutdown().await;
1361
1362            // Wait for the client's reconnect attempt
1363            sleep(Duration::from_millis(500)).await;
1364
1365            // Run the echo server on the new connection
1366            let (socket, _) = listener.accept().await.expect("Second accept failed");
1367            run_echo_server(socket).await;
1368        });
1369
1370        let config = SocketConfig {
1371            url: format!("127.0.0.1:{port}"),
1372            mode: Mode::Plain,
1373            suffix: b"\r\n".to_vec(),
1374            message_handler: None,
1375            heartbeat: None,
1376            reconnect_timeout_ms: Some(5_000),
1377            reconnect_delay_initial_ms: Some(500),
1378            reconnect_delay_max_ms: Some(5_000),
1379            reconnect_backoff_factor: Some(2.0),
1380            reconnect_jitter_ms: Some(50),
1381            reconnect_max_attempts: None,
1382            connection_max_retries: None,
1383            certs_dir: None,
1384        };
1385
1386        let client = SocketClient::connect(config, None, None, None)
1387            .await
1388            .expect("Client connect failed unexpectedly");
1389
1390        // Initially, the client should be active
1391        assert!(client.is_active(), "Client should start as active");
1392
1393        // Wait until the client loses connection (i.e. not active),
1394        // then wait until it reconnects (active again).
1395        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1396
1397        client
1398            .send_bytes(b"TestReconnect".into())
1399            .await
1400            .expect("Send failed");
1401
1402        client.close().await;
1403        server_task.abort();
1404    }
1405}
1406
1407#[cfg(test)]
1408#[cfg(not(feature = "turmoil"))]
1409mod rust_tests {
1410    use nautilus_common::testing::wait_until_async;
1411    use rstest::rstest;
1412    use tokio::{
1413        io::{AsyncReadExt, AsyncWriteExt},
1414        net::TcpListener,
1415        task,
1416        time::{Duration, sleep},
1417    };
1418
1419    use super::*;
1420
1421    #[rstest]
1422    #[tokio::test]
1423    async fn test_reconnect_then_close() {
1424        // Bind an ephemeral port
1425        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1426        let port = listener.local_addr().unwrap().port();
1427
1428        // Server task: accept one connection and then drop it
1429        let server = task::spawn(async move {
1430            if let Ok((mut sock, _)) = listener.accept().await {
1431                drop(sock.shutdown());
1432            }
1433            // Keep listener alive briefly to avoid premature exit
1434            sleep(Duration::from_secs(1)).await;
1435        });
1436
1437        // Configure client with a short reconnect backoff
1438        let config = SocketConfig {
1439            url: format!("127.0.0.1:{port}"),
1440            mode: Mode::Plain,
1441            suffix: b"\r\n".to_vec(),
1442            message_handler: None,
1443            heartbeat: None,
1444            reconnect_timeout_ms: Some(1_000),
1445            reconnect_delay_initial_ms: Some(50),
1446            reconnect_delay_max_ms: Some(100),
1447            reconnect_backoff_factor: Some(1.0),
1448            reconnect_jitter_ms: Some(0),
1449            connection_max_retries: Some(1),
1450            reconnect_max_attempts: None,
1451            certs_dir: None,
1452        };
1453
1454        // Connect client (handler=None)
1455        let client = SocketClient::connect(config.clone(), None, None, None)
1456            .await
1457            .unwrap();
1458
1459        // Wait for client to detect dropped connection and enter reconnect state
1460        wait_until_async(
1461            || async { client.is_reconnecting() },
1462            Duration::from_secs(2),
1463        )
1464        .await;
1465
1466        // Now close the client
1467        client.close().await;
1468        assert!(client.is_closed());
1469        server.abort();
1470    }
1471
1472    #[rstest]
1473    #[tokio::test]
1474    async fn test_reconnect_state_flips_when_reader_stops() {
1475        // Bind an ephemeral port and accept a single connection which we immediately close.
1476        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1477        let port = listener.local_addr().unwrap().port();
1478
1479        let server = task::spawn(async move {
1480            if let Ok((sock, _)) = listener.accept().await {
1481                drop(sock);
1482            }
1483            // Give the client a moment to observe the closed connection.
1484            sleep(Duration::from_millis(50)).await;
1485        });
1486
1487        let config = SocketConfig {
1488            url: format!("127.0.0.1:{port}"),
1489            mode: Mode::Plain,
1490            suffix: b"\r\n".to_vec(),
1491            message_handler: None,
1492            heartbeat: None,
1493            reconnect_timeout_ms: Some(1_000),
1494            reconnect_delay_initial_ms: Some(50),
1495            reconnect_delay_max_ms: Some(100),
1496            reconnect_backoff_factor: Some(1.0),
1497            reconnect_jitter_ms: Some(0),
1498            connection_max_retries: Some(1),
1499            reconnect_max_attempts: None,
1500            certs_dir: None,
1501        };
1502
1503        let client = SocketClient::connect(config, None, None, None)
1504            .await
1505            .unwrap();
1506
1507        wait_until_async(
1508            || async { client.is_reconnecting() },
1509            Duration::from_secs(2),
1510        )
1511        .await;
1512
1513        client.close().await;
1514        server.abort();
1515    }
1516
1517    #[rstest]
1518    fn test_parse_socket_url_raw_address() {
1519        // Raw socket address with TLS mode
1520        let (socket_addr, request_url) =
1521            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1522        assert_eq!(socket_addr, "example.com:6130");
1523        assert_eq!(request_url, "wss://example.com:6130");
1524
1525        // Raw socket address with Plain mode
1526        let (socket_addr, request_url) =
1527            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1528        assert_eq!(socket_addr, "localhost:8080");
1529        assert_eq!(request_url, "ws://localhost:8080");
1530    }
1531
1532    #[rstest]
1533    fn test_parse_socket_url_with_scheme() {
1534        // Full URL with wss scheme
1535        let (socket_addr, request_url) =
1536            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1537        assert_eq!(socket_addr, "example.com:443");
1538        assert_eq!(request_url, "wss://example.com:443/path");
1539
1540        // Full URL with ws scheme
1541        let (socket_addr, request_url) =
1542            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1543        assert_eq!(socket_addr, "localhost:8080");
1544        assert_eq!(request_url, "ws://localhost:8080");
1545    }
1546
1547    #[rstest]
1548    fn test_parse_socket_url_default_ports() {
1549        // wss without explicit port defaults to 443
1550        let (socket_addr, _) =
1551            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1552        assert_eq!(socket_addr, "example.com:443");
1553
1554        // ws without explicit port defaults to 80
1555        let (socket_addr, _) =
1556            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1557        assert_eq!(socket_addr, "example.com:80");
1558
1559        // https defaults to 443
1560        let (socket_addr, _) =
1561            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1562        assert_eq!(socket_addr, "example.com:443");
1563
1564        // http defaults to 80
1565        let (socket_addr, _) =
1566            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1567        assert_eq!(socket_addr, "example.com:80");
1568    }
1569
1570    #[rstest]
1571    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1572        // Unknown scheme defaults to mode-based port
1573        let (socket_addr, _) =
1574            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1575        assert_eq!(socket_addr, "example.com:443");
1576
1577        let (socket_addr, _) =
1578            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1579        assert_eq!(socket_addr, "example.com:80");
1580    }
1581
1582    #[rstest]
1583    fn test_parse_socket_url_ipv6() {
1584        // IPv6 address with port
1585        let (socket_addr, request_url) =
1586            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1587        assert_eq!(socket_addr, "[::1]:8080");
1588        assert_eq!(request_url, "ws://[::1]:8080");
1589
1590        // IPv6 in URL
1591        let (socket_addr, _) =
1592            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1593        assert_eq!(socket_addr, "[::1]:8080");
1594    }
1595
1596    #[rstest]
1597    #[tokio::test]
1598    async fn test_url_parsing_raw_socket_address() {
1599        // Test that raw socket addresses (host:port) work correctly
1600        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1601        let port = listener.local_addr().unwrap().port();
1602
1603        let server = task::spawn(async move {
1604            if let Ok((sock, _)) = listener.accept().await {
1605                drop(sock);
1606            }
1607            sleep(Duration::from_millis(50)).await;
1608        });
1609
1610        let config = SocketConfig {
1611            url: format!("127.0.0.1:{port}"), // Raw socket address format
1612            mode: Mode::Plain,
1613            suffix: b"\r\n".to_vec(),
1614            message_handler: None,
1615            heartbeat: None,
1616            reconnect_timeout_ms: Some(1_000),
1617            reconnect_delay_initial_ms: Some(50),
1618            reconnect_delay_max_ms: Some(100),
1619            reconnect_backoff_factor: Some(1.0),
1620            reconnect_jitter_ms: Some(0),
1621            connection_max_retries: Some(1),
1622            reconnect_max_attempts: None,
1623            certs_dir: None,
1624        };
1625
1626        // Should successfully connect with raw socket address
1627        let client = SocketClient::connect(config, None, None, None).await;
1628        assert!(
1629            client.is_ok(),
1630            "Client should connect with raw socket address format"
1631        );
1632
1633        if let Ok(client) = client {
1634            client.close().await;
1635        }
1636        server.abort();
1637    }
1638
1639    #[rstest]
1640    #[tokio::test]
1641    async fn test_url_parsing_with_scheme() {
1642        // Test that URLs with schemes also work
1643        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1644        let port = listener.local_addr().unwrap().port();
1645
1646        let server = task::spawn(async move {
1647            if let Ok((sock, _)) = listener.accept().await {
1648                drop(sock);
1649            }
1650            sleep(Duration::from_millis(50)).await;
1651        });
1652
1653        let config = SocketConfig {
1654            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1655            mode: Mode::Plain,
1656            suffix: b"\r\n".to_vec(),
1657            message_handler: None,
1658            heartbeat: None,
1659            reconnect_timeout_ms: Some(1_000),
1660            reconnect_delay_initial_ms: Some(50),
1661            reconnect_delay_max_ms: Some(100),
1662            reconnect_backoff_factor: Some(1.0),
1663            reconnect_jitter_ms: Some(0),
1664            connection_max_retries: Some(1),
1665            reconnect_max_attempts: None,
1666            certs_dir: None,
1667        };
1668
1669        // Should successfully connect with URL format
1670        let client = SocketClient::connect(config, None, None, None).await;
1671        assert!(
1672            client.is_ok(),
1673            "Client should connect with URL scheme format"
1674        );
1675
1676        if let Ok(client) = client {
1677            client.close().await;
1678        }
1679        server.abort();
1680    }
1681
1682    #[rstest]
1683    fn test_parse_socket_url_ipv6_with_zone() {
1684        // IPv6 with zone ID (link-local address)
1685        let (socket_addr, request_url) =
1686            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1687        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1688        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1689
1690        // Verify zone is preserved in URL format too
1691        let (socket_addr, request_url) =
1692            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1693        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1694        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1695    }
1696
1697    #[rstest]
1698    #[tokio::test]
1699    async fn test_ipv6_loopback_connection() {
1700        // Test IPv6 loopback address connection
1701        // Skip if IPv6 is not available on the system
1702        if TcpListener::bind("[::1]:0").await.is_err() {
1703            eprintln!("IPv6 not available, skipping test");
1704            return;
1705        }
1706
1707        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1708        let port = listener.local_addr().unwrap().port();
1709
1710        let server = task::spawn(async move {
1711            if let Ok((mut sock, _)) = listener.accept().await {
1712                let mut buf = vec![0u8; 1024];
1713                if let Ok(n) = sock.read(&mut buf).await {
1714                    // Echo back
1715                    let _ = sock.write_all(&buf[..n]).await;
1716                }
1717            }
1718            sleep(Duration::from_millis(50)).await;
1719        });
1720
1721        let config = SocketConfig {
1722            url: format!("[::1]:{port}"), // IPv6 loopback
1723            mode: Mode::Plain,
1724            suffix: b"\r\n".to_vec(),
1725            message_handler: None,
1726            heartbeat: None,
1727            reconnect_timeout_ms: Some(1_000),
1728            reconnect_delay_initial_ms: Some(50),
1729            reconnect_delay_max_ms: Some(100),
1730            reconnect_backoff_factor: Some(1.0),
1731            reconnect_jitter_ms: Some(0),
1732            connection_max_retries: Some(1),
1733            reconnect_max_attempts: None,
1734            certs_dir: None,
1735        };
1736
1737        let client = SocketClient::connect(config, None, None, None).await;
1738        assert!(
1739            client.is_ok(),
1740            "Client should connect to IPv6 loopback address"
1741        );
1742
1743        if let Ok(client) = client {
1744            client.close().await;
1745        }
1746        server.abort();
1747    }
1748
1749    #[rstest]
1750    #[tokio::test]
1751    async fn test_send_waits_during_reconnection() {
1752        // Test that send operations wait for reconnection to complete (up to configured timeout)
1753        use nautilus_common::testing::wait_until_async;
1754
1755        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1756        let port = listener.local_addr().unwrap().port();
1757
1758        let server = task::spawn(async move {
1759            // First connection - accept and immediately close
1760            if let Ok((sock, _)) = listener.accept().await {
1761                drop(sock);
1762            }
1763
1764            // Wait before accepting second connection
1765            sleep(Duration::from_millis(500)).await;
1766
1767            // Second connection - accept and keep alive
1768            if let Ok((mut sock, _)) = listener.accept().await {
1769                // Echo messages
1770                let mut buf = vec![0u8; 1024];
1771                while let Ok(n) = sock.read(&mut buf).await {
1772                    if n == 0 {
1773                        break;
1774                    }
1775                    if sock.write_all(&buf[..n]).await.is_err() {
1776                        break;
1777                    }
1778                }
1779            }
1780        });
1781
1782        let config = SocketConfig {
1783            url: format!("127.0.0.1:{port}"),
1784            mode: Mode::Plain,
1785            suffix: b"\r\n".to_vec(),
1786            message_handler: None,
1787            heartbeat: None,
1788            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1789            reconnect_delay_initial_ms: Some(100),
1790            reconnect_delay_max_ms: Some(200),
1791            reconnect_backoff_factor: Some(1.0),
1792            reconnect_jitter_ms: Some(0),
1793            connection_max_retries: Some(1),
1794            reconnect_max_attempts: None,
1795            certs_dir: None,
1796        };
1797
1798        let client = SocketClient::connect(config, None, None, None)
1799            .await
1800            .unwrap();
1801
1802        // Wait for reconnection to trigger
1803        wait_until_async(
1804            || async { client.is_reconnecting() },
1805            Duration::from_secs(2),
1806        )
1807        .await;
1808
1809        // Try to send while reconnecting - should wait and succeed after reconnect
1810        let send_result = tokio::time::timeout(
1811            Duration::from_secs(3),
1812            client.send_bytes(b"test_message".to_vec()),
1813        )
1814        .await;
1815
1816        assert!(
1817            send_result.is_ok() && send_result.unwrap().is_ok(),
1818            "Send should succeed after waiting for reconnection"
1819        );
1820
1821        client.close().await;
1822        server.abort();
1823    }
1824
1825    #[rstest]
1826    #[tokio::test]
1827    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1828        // Test that send_bytes operations respect the configured reconnect_timeout.
1829        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1830        use nautilus_common::testing::wait_until_async;
1831
1832        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1833        let port = listener.local_addr().unwrap().port();
1834
1835        let server = task::spawn(async move {
1836            // Accept first connection and immediately close it
1837            if let Ok((sock, _)) = listener.accept().await {
1838                drop(sock);
1839            }
1840            // Drop listener entirely so reconnection fails completely
1841            drop(listener);
1842            sleep(Duration::from_secs(60)).await;
1843        });
1844
1845        let config = SocketConfig {
1846            url: format!("127.0.0.1:{port}"),
1847            mode: Mode::Plain,
1848            suffix: b"\r\n".to_vec(),
1849            message_handler: None,
1850            heartbeat: None,
1851            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1852            reconnect_delay_initial_ms: Some(200), // Short backoff (but > timeout) to keep client in RECONNECT
1853            reconnect_delay_max_ms: Some(200),
1854            reconnect_backoff_factor: Some(1.0),
1855            reconnect_jitter_ms: Some(0),
1856            connection_max_retries: Some(1),
1857            reconnect_max_attempts: None,
1858            certs_dir: None,
1859        };
1860
1861        let client = SocketClient::connect(config, None, None, None)
1862            .await
1863            .unwrap();
1864
1865        // Wait for client to enter RECONNECT state
1866        wait_until_async(
1867            || async { client.is_reconnecting() },
1868            Duration::from_secs(3),
1869        )
1870        .await;
1871
1872        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1873        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1874        let start = std::time::Instant::now();
1875        let send_result = client.send_bytes(b"test".to_vec()).await;
1876        let elapsed = start.elapsed();
1877
1878        assert!(
1879            send_result.is_err(),
1880            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
1881        );
1882        assert!(
1883            matches!(send_result, Err(crate::error::SendError::Timeout)),
1884            "Send should return Timeout error, was: {send_result:?}"
1885        );
1886        // Verify timeout respects configured value (1s), but don't check upper bound
1887        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1888        assert!(
1889            elapsed >= Duration::from_millis(900),
1890            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
1891        );
1892
1893        client.close().await;
1894        server.abort();
1895    }
1896
1897    #[rstest]
1898    #[tokio::test]
1899    async fn test_empty_suffix_rejected() {
1900        let config = SocketConfig {
1901            url: "127.0.0.1:9999".to_string(),
1902            mode: Mode::Plain,
1903            suffix: vec![],
1904            message_handler: None,
1905            heartbeat: None,
1906            reconnect_timeout_ms: None,
1907            reconnect_delay_initial_ms: None,
1908            reconnect_delay_max_ms: None,
1909            reconnect_backoff_factor: None,
1910            reconnect_jitter_ms: None,
1911            reconnect_max_attempts: None,
1912            connection_max_retries: Some(1),
1913            certs_dir: None,
1914        };
1915
1916        let result = SocketClient::connect(config, None, None, None).await;
1917
1918        assert!(
1919            result.is_err(),
1920            "Empty suffix should cause connection to fail"
1921        );
1922        let err_msg = result.unwrap_err().to_string();
1923        assert!(
1924            err_msg.contains("suffix cannot be empty"),
1925            "Error should mention empty suffix, got: {err_msg}"
1926        );
1927    }
1928}