nautilus_network/socket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18//!
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    collections::VecDeque,
33    fmt::Debug,
34    path::Path,
35    sync::{
36        Arc,
37        atomic::{AtomicU8, Ordering},
38    },
39    time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49    SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52    backoff::ExponentialBackoff,
53    error::SendError,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
/// Interval (milliseconds) at which the read and write tasks wake up to re-check the
/// connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) observed during a reconnect before the previous connection is torn down.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on waiting for a writer half to shut down cleanly.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
// NOTE(review): not referenced in the visible part of this file; presumably used by the
// public send path further down - confirm.
/// Polling interval (milliseconds) for send operations.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

/// Largest number of unprocessed bytes (10 MiB) the read task keeps buffered before it
/// closes the connection.
const MAX_READ_BUFFER_BYTES: usize = 10 * 1024 * 1024;
68
69/// Creates a `TcpStream` with the server.
70///
71/// The stream can be encrypted with TLS or Plain. The stream is split into
72/// read and write ends:
73/// - The read end is passed to the task that keeps receiving
74///   messages from the server and passing them to a handler.
75/// - The write end is passed to a task which receives messages over a channel
76///   to send to the server.
77///
78/// The heartbeat is optional and can be configured with an interval and data to
79/// send.
80///
81/// The client uses a suffix to separate messages on the byte stream. It is
82/// appended to all sent messages and heartbeats. It is also used to split
83/// the received byte stream.
84#[cfg_attr(
85    feature = "python",
86    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
87)]
88struct SocketClientInner {
89    config: SocketConfig,
90    connector: Option<Connector>,
91    read_task: Arc<tokio::task::JoinHandle<()>>,
92    write_task: tokio::task::JoinHandle<()>,
93    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
94    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
95    connection_mode: Arc<AtomicU8>,
96    reconnect_timeout: Duration,
97    backoff: ExponentialBackoff,
98    handler: Option<TcpMessageHandler>,
99    reconnect_max_attempts: Option<u32>,
100    reconnect_attempt_count: u32,
101}
102
103impl SocketClientInner {
104    /// Connect to a URL with the specified configuration.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if connection fails or configuration is invalid.
109    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110        install_cryptographic_provider();
111
112        let SocketConfig {
113            url,
114            mode,
115            heartbeat,
116            suffix,
117            message_handler,
118            reconnect_timeout_ms,
119            reconnect_delay_initial_ms,
120            reconnect_delay_max_ms,
121            reconnect_backoff_factor,
122            reconnect_jitter_ms,
123            connection_max_retries,
124            reconnect_max_attempts,
125            certs_dir,
126        } = &config.clone();
127        let connector = if let Some(dir) = certs_dir {
128            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
129            Some(Connector::Rustls(Arc::new(config)))
130        } else {
131            None
132        };
133
134        // Retry initial connection with exponential backoff to handle transient DNS/network issues
135        const CONNECTION_TIMEOUT_SECS: u64 = 10;
136        let max_retries = connection_max_retries.unwrap_or(5);
137
138        let mut backoff = ExponentialBackoff::new(
139            Duration::from_millis(500),
140            Duration::from_millis(5000),
141            2.0,
142            250,
143            false,
144        )?;
145
146        #[allow(unused_assignments)]
147        let mut last_error = String::new();
148        let mut attempt = 0;
149        let (reader, writer) = loop {
150            attempt += 1;
151
152            match tokio::time::timeout(
153                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
154                Self::tls_connect_with_server(url, *mode, connector.clone()),
155            )
156            .await
157            {
158                Ok(Ok(result)) => {
159                    if attempt > 1 {
160                        tracing::info!("Socket connection established after {attempt} attempts");
161                    }
162                    break result;
163                }
164                Ok(Err(e)) => {
165                    last_error = e.to_string();
166                    tracing::warn!(
167                        attempt,
168                        max_retries,
169                        url = %url,
170                        error = %last_error,
171                        "Socket connection attempt failed"
172                    );
173                }
174                Err(_) => {
175                    last_error = format!(
176                        "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
177                    );
178                    tracing::warn!(
179                        attempt,
180                        max_retries,
181                        url = %url,
182                        "Socket connection attempt timed out"
183                    );
184                }
185            }
186
187            if attempt >= max_retries {
188                anyhow::bail!(
189                    "Failed to connect to {} after {} attempts: {}. \
190                    If this is a DNS error, check your network configuration and DNS settings.",
191                    url,
192                    max_retries,
193                    if last_error.is_empty() {
194                        "unknown error"
195                    } else {
196                        &last_error
197                    }
198                );
199            }
200
201            let delay = backoff.next_duration();
202            tracing::debug!(
203                "Retrying in {delay:?} (attempt {}/{})",
204                attempt + 1,
205                max_retries
206            );
207            tokio::time::sleep(delay).await;
208        };
209
210        tracing::debug!("Connected");
211
212        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
213
214        let read_task = Arc::new(Self::spawn_read_task(
215            connection_mode.clone(),
216            reader,
217            message_handler.clone(),
218            suffix.clone(),
219        ));
220
221        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
222
223        let write_task =
224            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
225
226        // Optionally spawn a heartbeat task to periodically ping server
227        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
228            Self::spawn_heartbeat_task(
229                connection_mode.clone(),
230                heartbeat.clone(),
231                writer_tx.clone(),
232            )
233        });
234
235        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
236        let backoff = ExponentialBackoff::new(
237            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
238            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
239            reconnect_backoff_factor.unwrap_or(1.5),
240            reconnect_jitter_ms.unwrap_or(100),
241            true, // immediate-first
242        )?;
243
244        Ok(Self {
245            config,
246            connector,
247            read_task,
248            write_task,
249            writer_tx,
250            heartbeat_task,
251            connection_mode,
252            reconnect_timeout,
253            backoff,
254            handler: message_handler.clone(),
255            reconnect_max_attempts: *reconnect_max_attempts,
256            reconnect_attempt_count: 0,
257        })
258    }
259
260    /// Parse URL and extract socket address and request URL.
261    ///
262    /// Accepts either:
263    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
264    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if the URL is invalid or missing required components.
269    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
270        if url.contains("://") {
271            // URL with scheme (e.g., "wss://host:port/path")
272            let parsed = url.parse::<http::Uri>().map_err(|e| {
273                Error::Io(std::io::Error::new(
274                    std::io::ErrorKind::InvalidInput,
275                    format!("Invalid URL: {e}"),
276                ))
277            })?;
278
279            let host = parsed.host().ok_or_else(|| {
280                Error::Io(std::io::Error::new(
281                    std::io::ErrorKind::InvalidInput,
282                    "URL missing host",
283                ))
284            })?;
285
286            let port = parsed
287                .port_u16()
288                .unwrap_or_else(|| match parsed.scheme_str() {
289                    Some("wss" | "https") => 443,
290                    Some("ws" | "http") => 80,
291                    _ => match mode {
292                        Mode::Tls => 443,
293                        Mode::Plain => 80,
294                    },
295                });
296
297            Ok((format!("{host}:{port}"), url.to_string()))
298        } else {
299            // Raw socket address (e.g., "host:port")
300            // Construct a proper URL for the request based on mode
301            let scheme = match mode {
302                Mode::Tls => "wss",
303                Mode::Plain => "ws",
304            };
305            Ok((url.to_string(), format!("{scheme}://{url}")))
306        }
307    }
308
309    /// Establish a TLS or plain TCP connection with the server.
310    ///
311    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
312    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
313    /// For WebSocket-style connections, include the scheme.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the connection cannot be established.
318    pub async fn tls_connect_with_server(
319        url: &str,
320        mode: Mode,
321        connector: Option<Connector>,
322    ) -> Result<(TcpReader, TcpWriter), Error> {
323        tracing::debug!("Connecting to {url}");
324
325        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
326        let tcp_result = TcpStream::connect(&socket_addr).await;
327
328        match tcp_result {
329            Ok(stream) => {
330                tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
331                if let Err(e) = stream.set_nodelay(true) {
332                    tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
333                }
334                let request = request_url.into_client_request()?;
335                tcp_tls(&request, mode, stream, connector)
336                    .await
337                    .map(tokio::io::split)
338            }
339            Err(e) => {
340                tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
341                Err(Error::Io(e))
342            }
343        }
344    }
345
346    /// Reconnect with server.
347    ///
348    /// Makes a new connection with server, uses the new read and write halves
349    /// to update the reader and writer.
350    async fn reconnect(&mut self) -> Result<(), Error> {
351        tracing::debug!("Reconnecting");
352
353        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
354            tracing::debug!("Reconnect aborted due to disconnect state");
355            return Ok(());
356        }
357
358        tokio::time::timeout(self.reconnect_timeout, async {
359            let SocketConfig {
360                url,
361                mode,
362                heartbeat: _,
363                suffix,
364                message_handler: _,
365                reconnect_timeout_ms: _,
366                reconnect_delay_initial_ms: _,
367                reconnect_backoff_factor: _,
368                reconnect_delay_max_ms: _,
369                reconnect_jitter_ms: _,
370                connection_max_retries: _,
371                reconnect_max_attempts: _,
372                certs_dir: _,
373            } = &self.config;
374            // Create a fresh connection
375            let connector = self.connector.clone();
376            // Attempt to connect; abort early if a disconnect was requested
377            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
378
379            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
380                tracing::debug!("Reconnect aborted mid-flight (after connect)");
381                return Ok(());
382            }
383            tracing::debug!("Connected");
384
385            // Use a oneshot channel to synchronize with the writer task.
386            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
387            // to prevent silent message loss if the new connection drops immediately.
388            let (tx, rx) = tokio::sync::oneshot::channel();
389            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
390                tracing::error!("{e}");
391                return Err(Error::Io(std::io::Error::new(
392                    std::io::ErrorKind::BrokenPipe,
393                    format!("Failed to send update command: {e}"),
394                )));
395            }
396
397            // Wait for writer to confirm it has drained the buffer
398            match rx.await {
399                Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
400                Ok(false) => {
401                    tracing::warn!("Writer failed to drain buffer, aborting reconnect");
402                    // Return error to trigger retry logic in controller
403                    return Err(Error::Io(std::io::Error::other(
404                        "Failed to drain reconnection buffer",
405                    )));
406                }
407                Err(e) => {
408                    tracing::error!("Writer dropped update channel: {e}");
409                    return Err(Error::Io(std::io::Error::new(
410                        std::io::ErrorKind::BrokenPipe,
411                        "Writer task dropped response channel",
412                    )));
413                }
414            }
415
416            // Delay before closing connection
417            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
418
419            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
420                tracing::debug!("Reconnect aborted mid-flight (after delay)");
421                return Ok(());
422            }
423
424            if !self.read_task.is_finished() {
425                self.read_task.abort();
426                log_task_aborted("read");
427            }
428
429            // Atomically transition from Reconnect to Active
430            // This prevents race condition where disconnect could be requested between check and store
431            if self
432                .connection_mode
433                .compare_exchange(
434                    ConnectionMode::Reconnect.as_u8(),
435                    ConnectionMode::Active.as_u8(),
436                    Ordering::SeqCst,
437                    Ordering::SeqCst,
438                )
439                .is_err()
440            {
441                tracing::debug!("Reconnect aborted (state changed during reconnect)");
442                return Ok(());
443            }
444
445            // Spawn new read task
446            self.read_task = Arc::new(Self::spawn_read_task(
447                self.connection_mode.clone(),
448                reader,
449                self.handler.clone(),
450                suffix.clone(),
451            ));
452
453            tracing::debug!("Reconnect succeeded");
454            Ok(())
455        })
456        .await
457        .map_err(|_| {
458            Error::Io(std::io::Error::new(
459                std::io::ErrorKind::TimedOut,
460                format!(
461                    "reconnection timed out after {}s",
462                    self.reconnect_timeout.as_secs_f64()
463                ),
464            ))
465        })?
466    }
467
468    /// Check if the client is still alive.
469    ///
470    /// The client is connected if the read task has not finished. It is expected
471    /// that in case of any failure client or server side. The read task will be
472    /// shutdown. There might be some delay between the connection being closed
473    /// and the client detecting it.
474    #[inline]
475    #[must_use]
476    pub fn is_alive(&self) -> bool {
477        !self.read_task.is_finished()
478    }
479
480    #[must_use]
481    fn spawn_read_task(
482        connection_state: Arc<AtomicU8>,
483        mut reader: TcpReader,
484        handler: Option<TcpMessageHandler>,
485        suffix: Vec<u8>,
486    ) -> tokio::task::JoinHandle<()> {
487        log_task_started("read");
488
489        // Interval between checking the connection mode
490        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
491
492        tokio::task::spawn(async move {
493            let mut buf = Vec::new();
494
495            loop {
496                if !ConnectionMode::from_atomic(&connection_state).is_active() {
497                    break;
498                }
499
500                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
501                    // Connection has been terminated or vector buffer is complete
502                    Ok(Ok(0)) => {
503                        tracing::debug!("Connection closed by server");
504                        break;
505                    }
506                    Ok(Err(e)) => {
507                        tracing::debug!("Connection ended: {e}");
508                        break;
509                    }
510                    // Received bytes of data
511                    Ok(Ok(bytes)) => {
512                        tracing::trace!("Received <binary> {bytes} bytes");
513
514                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
515                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
516
517                        if is_fix && handler.is_some() {
518                            // FIX protocol processing
519                            if let Some(ref handler) = handler {
520                                process_fix_buffer(&mut buf, handler);
521                            }
522                        } else {
523                            // Regular suffix-based message processing
524                            while let Some((i, _)) = &buf
525                                .windows(suffix.len())
526                                .enumerate()
527                                .find(|(_, pair)| pair.eq(&suffix))
528                            {
529                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
530                                data.truncate(data.len() - suffix.len());
531
532                                if let Some(ref handler) = handler {
533                                    handler(&data);
534                                }
535                            }
536                        }
537
538                        if buf.len() > MAX_READ_BUFFER_BYTES {
539                            tracing::error!(
540                                "Read buffer exceeded maximum size ({} bytes), closing connection",
541                                MAX_READ_BUFFER_BYTES
542                            );
543                            break;
544                        }
545                    }
546                    Err(_) => {
547                        // Timeout - continue loop and check connection mode
548                        continue;
549                    }
550                }
551            }
552
553            log_task_stopped("read");
554        })
555    }
556
557    /// Drains buffered messages after reconnection completes.
558    ///
559    /// Attempts to send all buffered messages that were queued during reconnection.
560    /// Uses a peek-and-pop pattern to preserve messages if sending fails midway through the buffer.
561    ///
562    /// # Returns
563    ///
564    /// Returns `true` if a send error occurred (buffer may still contain unsent messages),
565    /// `false` if all messages were sent successfully (buffer is empty).
566    async fn drain_reconnect_buffer(
567        buffer: &mut VecDeque<Bytes>,
568        writer: &mut TcpWriter,
569        suffix: &[u8],
570    ) -> bool {
571        if buffer.is_empty() {
572            return false;
573        }
574
575        let initial_buffer_len = buffer.len();
576        tracing::info!(
577            "Sending {} buffered messages after reconnection",
578            initial_buffer_len
579        );
580
581        let mut send_error_occurred = false;
582
583        while let Some(buffered_msg) = buffer.front() {
584            let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
585            combined_msg.extend_from_slice(buffered_msg);
586            combined_msg.extend_from_slice(suffix);
587
588            if let Err(e) = writer.write_all(&combined_msg).await {
589                tracing::error!(
590                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
591                    buffer.len()
592                );
593                send_error_occurred = true;
594                break;
595            }
596
597            buffer.pop_front();
598        }
599
600        if buffer.is_empty() {
601            tracing::info!(
602                "Successfully sent all {} buffered messages",
603                initial_buffer_len
604            );
605        }
606
607        send_error_occurred
608    }
609
610    fn spawn_write_task(
611        connection_state: Arc<AtomicU8>,
612        writer: TcpWriter,
613        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
614        suffix: Vec<u8>,
615    ) -> tokio::task::JoinHandle<()> {
616        log_task_started("write");
617
618        // Interval between checking the connection mode
619        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
620
621        tokio::task::spawn(async move {
622            let mut active_writer = writer;
623            let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
624
625            loop {
626                if matches!(
627                    ConnectionMode::from_atomic(&connection_state),
628                    ConnectionMode::Disconnect | ConnectionMode::Closed
629                ) {
630                    break;
631                }
632
633                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
634                    Ok(Some(msg)) => {
635                        // Re-check connection mode after receiving a message
636                        let mode = ConnectionMode::from_atomic(&connection_state);
637                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
638                            break;
639                        }
640
641                        match msg {
642                            WriterCommand::Update(new_writer, tx) => {
643                                tracing::debug!("Received new writer");
644
645                                // Delay before closing connection
646                                tokio::time::sleep(Duration::from_millis(100)).await;
647
648                                // Attempt to shutdown the writer gracefully before updating,
649                                // we ignore any error as the writer may already be closed.
650                                _ = tokio::time::timeout(
651                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
652                                    active_writer.shutdown(),
653                                )
654                                .await;
655
656                                active_writer = new_writer;
657                                tracing::debug!("Updated writer");
658
659                                let send_error = Self::drain_reconnect_buffer(
660                                    &mut reconnect_buffer,
661                                    &mut active_writer,
662                                    &suffix,
663                                )
664                                .await;
665
666                                if let Err(e) = tx.send(!send_error) {
667                                    tracing::error!(
668                                        "Failed to report drain status to controller: {e:?}"
669                                    );
670                                }
671                            }
672                            _ if mode.is_reconnect() => {
673                                if let WriterCommand::Send(data) = msg {
674                                    tracing::debug!(
675                                        "Buffering message while reconnecting ({} bytes)",
676                                        data.len()
677                                    );
678                                    reconnect_buffer.push_back(data);
679                                }
680                                continue;
681                            }
682                            WriterCommand::Send(msg) => {
683                                if let Err(e) = active_writer.write_all(&msg).await {
684                                    tracing::error!("Failed to send message: {e}");
685                                    tracing::warn!("Writer triggering reconnect");
686                                    reconnect_buffer.push_back(msg);
687                                    connection_state
688                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
689                                    continue;
690                                }
691                                if let Err(e) = active_writer.write_all(&suffix).await {
692                                    tracing::error!("Failed to send suffix: {e}");
693                                    tracing::warn!("Writer triggering reconnect");
694                                    // Buffer this message before triggering reconnect since suffix failed
695                                    reconnect_buffer.push_back(msg);
696                                    connection_state
697                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
698                                    continue;
699                                }
700                            }
701                        }
702                    }
703                    Ok(None) => {
704                        // Channel closed - writer task should terminate
705                        tracing::debug!("Writer channel closed, terminating writer task");
706                        break;
707                    }
708                    Err(_) => {
709                        // Timeout - just continue the loop
710                        continue;
711                    }
712                }
713            }
714
715            // Attempt to shutdown the writer gracefully before exiting,
716            // we ignore any error as the writer may already be closed.
717            _ = tokio::time::timeout(
718                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
719                active_writer.shutdown(),
720            )
721            .await;
722
723            log_task_stopped("write");
724        })
725    }
726
727    fn spawn_heartbeat_task(
728        connection_state: Arc<AtomicU8>,
729        heartbeat: (u64, Vec<u8>),
730        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
731    ) -> tokio::task::JoinHandle<()> {
732        log_task_started("heartbeat");
733        let (interval_secs, message) = heartbeat;
734
735        tokio::task::spawn(async move {
736            let interval = Duration::from_secs(interval_secs);
737
738            loop {
739                tokio::time::sleep(interval).await;
740
741                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
742                    ConnectionMode::Active => {
743                        let msg = WriterCommand::Send(message.clone().into());
744
745                        match writer_tx.send(msg) {
746                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
747                            Err(e) => {
748                                tracing::error!("Failed to send heartbeat to writer task: {e}");
749                            }
750                        }
751                    }
752                    ConnectionMode::Reconnect => continue,
753                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
754                }
755            }
756
757            log_task_stopped("heartbeat");
758        })
759    }
760}
761
762impl Drop for SocketClientInner {
763    fn drop(&mut self) {
764        // Delegate to explicit cleanup handler
765        self.clean_drop();
766    }
767}
768
769/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
770impl CleanDrop for SocketClientInner {
771    fn clean_drop(&mut self) {
772        if !self.read_task.is_finished() {
773            self.read_task.abort();
774            log_task_aborted("read");
775        }
776
777        if !self.write_task.is_finished() {
778            self.write_task.abort();
779            log_task_aborted("write");
780        }
781
782        if let Some(ref handle) = self.heartbeat_task.take()
783            && !handle.is_finished()
784        {
785            handle.abort();
786            log_task_aborted("heartbeat");
787        }
788
789        #[cfg(feature = "python")]
790        {
791            // Remove stored handler to break ref cycle
792            self.config.message_handler = None;
793        }
794    }
795}
796
/// Client handle for a socket connection whose lifecycle is driven by a background
/// controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle to the controller task spawned in `connect`; aborted by `close` and on drop.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller task, stored as the mode's `u8` value.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time `send_bytes` waits for the client to become ACTIVE before giving up.
    pub(crate) reconnect_timeout: Duration,
    /// Sender used to queue `WriterCommand`s (such as outgoing bytes) for the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
807
808impl Debug for SocketClient {
809    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
810        f.debug_struct(stringify!(SocketClient)).finish()
811    }
812}
813
814impl SocketClient {
815    /// Connect to the server.
816    ///
817    /// # Errors
818    ///
819    /// Returns any error connecting to the server.
820    pub async fn connect(
821        config: SocketConfig,
822        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
823        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
824        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
825    ) -> anyhow::Result<Self> {
826        let inner = SocketClientInner::connect_url(config).await?;
827        let writer_tx = inner.writer_tx.clone();
828        let connection_mode = inner.connection_mode.clone();
829        let reconnect_timeout = inner.reconnect_timeout;
830
831        let controller_task = Self::spawn_controller_task(
832            inner,
833            connection_mode.clone(),
834            post_reconnection,
835            post_disconnection,
836        );
837
838        if let Some(handler) = post_connection {
839            handler();
840            tracing::debug!("Called `post_connection` handler");
841        }
842
843        Ok(Self {
844            controller_task,
845            connection_mode,
846            reconnect_timeout,
847            writer_tx,
848        })
849    }
850
851    /// Returns the current connection mode.
852    #[must_use]
853    pub fn connection_mode(&self) -> ConnectionMode {
854        ConnectionMode::from_atomic(&self.connection_mode)
855    }
856
857    /// Check if the client connection is active.
858    ///
859    /// Returns `true` if the client is connected and has not been signalled to disconnect.
860    /// The client will automatically retry connection based on its configuration.
861    #[inline]
862    #[must_use]
863    pub fn is_active(&self) -> bool {
864        self.connection_mode().is_active()
865    }
866
867    /// Check if the client is reconnecting.
868    ///
869    /// Returns `true` if the client lost connection and is attempting to reestablish it.
870    /// The client will automatically retry connection based on its configuration.
871    #[inline]
872    #[must_use]
873    pub fn is_reconnecting(&self) -> bool {
874        self.connection_mode().is_reconnect()
875    }
876
877    /// Check if the client is disconnecting.
878    ///
879    /// Returns `true` if the client is in disconnect mode.
880    #[inline]
881    #[must_use]
882    pub fn is_disconnecting(&self) -> bool {
883        self.connection_mode().is_disconnect()
884    }
885
886    /// Check if the client is closed.
887    ///
888    /// Returns `true` if the client has been explicitly disconnected or reached
889    /// maximum reconnection attempts. In this state, the client cannot be reused
890    /// and a new client must be created for further connections.
891    #[inline]
892    #[must_use]
893    pub fn is_closed(&self) -> bool {
894        self.connection_mode().is_closed()
895    }
896
897    /// Close the client.
898    ///
899    /// Controller task will periodically check the disconnect mode
900    /// and shutdown the client if it is not alive.
901    pub async fn close(&self) {
902        self.connection_mode
903            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
904
905        if let Ok(()) =
906            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
907                while !self.is_closed() {
908                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
909                        .await;
910                }
911
912                if !self.controller_task.is_finished() {
913                    self.controller_task.abort();
914                    log_task_aborted("controller");
915                }
916            })
917            .await
918        {
919            log_task_stopped("controller");
920        } else {
921            tracing::error!("Timeout waiting for controller task to finish");
922            if !self.controller_task.is_finished() {
923                self.controller_task.abort();
924                log_task_aborted("controller");
925            }
926        }
927    }
928
929    /// Sends a message of the given `data`.
930    ///
931    /// # Errors
932    ///
933    /// Returns an error if sending fails.
934    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
935        // Check connection state to fail fast
936        if self.is_closed() || self.is_disconnecting() {
937            return Err(SendError::Closed);
938        }
939
940        let timeout = self.reconnect_timeout;
941        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
942
943        if !self.is_active() {
944            tracing::debug!("Waiting for client to become ACTIVE before sending...");
945
946            let inner = tokio::time::timeout(timeout, async {
947                loop {
948                    if self.is_active() {
949                        return Ok(());
950                    }
951                    if matches!(
952                        self.connection_mode(),
953                        ConnectionMode::Disconnect | ConnectionMode::Closed
954                    ) {
955                        return Err(());
956                    }
957                    tokio::time::sleep(check_interval).await;
958                }
959            })
960            .await
961            .map_err(|_| SendError::Timeout)?;
962            inner.map_err(|()| SendError::Closed)?;
963        }
964
965        let msg = WriterCommand::Send(data.into());
966        self.writer_tx
967            .send(msg)
968            .map_err(|e| SendError::BrokenPipe(e.to_string()))
969    }
970
    /// Spawns the controller task, which takes ownership of `inner` and drives the
    /// connection lifecycle by polling the shared `connection_mode`.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task reads the mode and then:
    /// - DISCONNECT: aborts the read and heartbeat tasks after a short delay, calls
    ///   `post_disconnection` (if any) and exits.
    /// - CLOSED: exits.
    /// - ACTIVE but `inner.is_alive()` is false: switches the mode to RECONNECT.
    /// - RECONNECT: attempts `inner.reconnect()`, calling `post_reconnection` (if any) on
    ///   success while still ACTIVE, or sleeping for the next backoff duration on failure.
    ///   Once `reconnect_max_attempts` is reached the mode is set to CLOSED and the task
    ///   exits without calling `post_disconnection`.
    ///
    /// On exit the mode is always set to CLOSED.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                        // The write task is not aborted here; it is aborted by `clean_drop`
                        // when `inner` is dropped at the end of this task.
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    if let Some(ref handler) = post_disconnection {
                        handler();
                        tracing::debug!("Called `post_disconnection` handler");
                    }
                    break; // Controller finished
                }

                if mode.is_closed() {
                    tracing::debug!("Connection closed");
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Compare-exchange so a concurrent change away from ACTIVE (for example a
                    // DISCONNECT request from `close`) is never overwritten with RECONNECT.
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read the mode: it is RECONNECT on success, or whatever another
                    // party stored if the exchange failed.
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    // Check max reconnection attempts before attempting reconnect
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnect_attempt_count >= max_attempts
                    {
                        tracing::error!(
                            "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                            max_attempts
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    inner.reconnect_attempt_count += 1;
                    match inner.reconnect().await {
                        Ok(()) => {
                            tracing::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            inner.reconnect_attempt_count = 0; // Reset counter on success
                            // Only invoke reconnect handler if still active
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = post_reconnection {
                                    handler();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnect_attempt_count
                            );
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            // NOTE(review): the mode is not checked during this sleep, so a
                            // DISCONNECT request is only seen on the next loop iteration.
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Whatever path left the loop, publish CLOSED so pollers such as `close` see it
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1093}
1094
1095// Abort controller task on drop to clean up background tasks
1096impl Drop for SocketClient {
1097    fn drop(&mut self) {
1098        if !self.controller_task.is_finished() {
1099            self.controller_task.abort();
1100            log_task_aborted("controller");
1101        }
1102    }
1103}
1104
1105#[cfg(test)]
1106#[cfg(feature = "python")]
1107#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1108mod tests {
1109    use nautilus_common::testing::wait_until_async;
1110    use pyo3::Python;
1111    use tokio::{
1112        io::{AsyncReadExt, AsyncWriteExt},
1113        net::{TcpListener, TcpStream},
1114        sync::Mutex,
1115        task,
1116        time::{Duration, sleep},
1117    };
1118
1119    use super::*;
1120
1121    async fn bind_test_server() -> (u16, TcpListener) {
1122        let listener = TcpListener::bind("127.0.0.1:0")
1123            .await
1124            .expect("Failed to bind ephemeral port");
1125        let port = listener.local_addr().unwrap().port();
1126        (port, listener)
1127    }
1128
1129    async fn run_echo_server(mut socket: TcpStream) {
1130        let mut buf = Vec::new();
1131        loop {
1132            match socket.read_buf(&mut buf).await {
1133                Ok(0) => {
1134                    break;
1135                }
1136                Ok(_n) => {
1137                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1138                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1139                        // Remove trailing \r\n
1140                        line.truncate(line.len() - 2);
1141
1142                        if line == b"close" {
1143                            let _ = socket.shutdown().await;
1144                            return;
1145                        }
1146
1147                        let mut echo_data = line;
1148                        echo_data.extend_from_slice(b"\r\n");
1149                        if socket.write_all(&echo_data).await.is_err() {
1150                            break;
1151                        }
1152                    }
1153                }
1154                Err(e) => {
1155                    eprintln!("Server read error: {e}");
1156                    break;
1157                }
1158            }
1159        }
1160    }
1161
1162    #[tokio::test]
1163    async fn test_basic_send_receive() {
1164        Python::initialize();
1165
1166        let (port, listener) = bind_test_server().await;
1167        let server_task = task::spawn(async move {
1168            let (socket, _) = listener.accept().await.unwrap();
1169            run_echo_server(socket).await;
1170        });
1171
1172        let config = SocketConfig {
1173            url: format!("127.0.0.1:{port}"),
1174            mode: Mode::Plain,
1175            suffix: b"\r\n".to_vec(),
1176            message_handler: None,
1177            heartbeat: None,
1178            reconnect_timeout_ms: None,
1179            reconnect_delay_initial_ms: None,
1180            reconnect_backoff_factor: None,
1181            reconnect_delay_max_ms: None,
1182            reconnect_jitter_ms: None,
1183            reconnect_max_attempts: None,
1184            connection_max_retries: None,
1185            certs_dir: None,
1186        };
1187
1188        let client = SocketClient::connect(config, None, None, None)
1189            .await
1190            .expect("Client connect failed unexpectedly");
1191
1192        client.send_bytes(b"Hello".into()).await.unwrap();
1193        client.send_bytes(b"World".into()).await.unwrap();
1194
1195        // Wait a bit for the server to echo them back
1196        sleep(Duration::from_millis(100)).await;
1197
1198        client.send_bytes(b"close".into()).await.unwrap();
1199        server_task.await.unwrap();
1200        assert!(!client.is_closed());
1201    }
1202
1203    #[tokio::test]
1204    async fn test_reconnect_fail_exhausted() {
1205        Python::initialize();
1206
1207        let (port, listener) = bind_test_server().await;
1208        drop(listener); // We drop it immediately -> no server is listening
1209
1210        // Wait until port is truly unavailable (OS has released it)
1211        wait_until_async(
1212            || async {
1213                TcpStream::connect(format!("127.0.0.1:{port}"))
1214                    .await
1215                    .is_err()
1216            },
1217            Duration::from_secs(2),
1218        )
1219        .await;
1220
1221        let config = SocketConfig {
1222            url: format!("127.0.0.1:{port}"),
1223            mode: Mode::Plain,
1224            suffix: b"\r\n".to_vec(),
1225            message_handler: None,
1226            heartbeat: None,
1227            reconnect_timeout_ms: Some(100),
1228            reconnect_delay_initial_ms: Some(50),
1229            reconnect_backoff_factor: Some(1.0),
1230            reconnect_delay_max_ms: Some(50),
1231            reconnect_jitter_ms: Some(0),
1232            connection_max_retries: Some(1),
1233            reconnect_max_attempts: None,
1234            certs_dir: None,
1235        };
1236
1237        let client_res = SocketClient::connect(config, None, None, None).await;
1238        assert!(
1239            client_res.is_err(),
1240            "Should fail quickly with no server listening"
1241        );
1242    }
1243
1244    #[tokio::test]
1245    async fn test_user_disconnect() {
1246        Python::initialize();
1247
1248        let (port, listener) = bind_test_server().await;
1249        let server_task = task::spawn(async move {
1250            let (socket, _) = listener.accept().await.unwrap();
1251            let mut buf = [0u8; 1024];
1252            let _ = socket.try_read(&mut buf);
1253
1254            loop {
1255                sleep(Duration::from_secs(1)).await;
1256            }
1257        });
1258
1259        let config = SocketConfig {
1260            url: format!("127.0.0.1:{port}"),
1261            mode: Mode::Plain,
1262            suffix: b"\r\n".to_vec(),
1263            message_handler: None,
1264            heartbeat: None,
1265            reconnect_timeout_ms: None,
1266            reconnect_delay_initial_ms: None,
1267            reconnect_backoff_factor: None,
1268            reconnect_delay_max_ms: None,
1269            reconnect_jitter_ms: None,
1270            reconnect_max_attempts: None,
1271            connection_max_retries: None,
1272            certs_dir: None,
1273        };
1274
1275        let client = SocketClient::connect(config, None, None, None)
1276            .await
1277            .unwrap();
1278
1279        client.close().await;
1280        assert!(client.is_closed());
1281        server_task.abort();
1282    }
1283
1284    #[tokio::test]
1285    async fn test_heartbeat() {
1286        Python::initialize();
1287
1288        let (port, listener) = bind_test_server().await;
1289        let received = Arc::new(Mutex::new(Vec::new()));
1290        let received2 = received.clone();
1291
1292        let server_task = task::spawn(async move {
1293            let (socket, _) = listener.accept().await.unwrap();
1294
1295            let mut buf = Vec::new();
1296            loop {
1297                match socket.try_read_buf(&mut buf) {
1298                    Ok(0) => break,
1299                    Ok(_) => {
1300                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1301                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1302                            line.truncate(line.len() - 2);
1303                            received2.lock().await.push(line);
1304                        }
1305                    }
1306                    Err(_) => {
1307                        tokio::time::sleep(Duration::from_millis(10)).await;
1308                    }
1309                }
1310            }
1311        });
1312
1313        // Heartbeat every 1 second
1314        let heartbeat = Some((1, b"ping".to_vec()));
1315
1316        let config = SocketConfig {
1317            url: format!("127.0.0.1:{port}"),
1318            mode: Mode::Plain,
1319            suffix: b"\r\n".to_vec(),
1320            message_handler: None,
1321            heartbeat,
1322            reconnect_timeout_ms: None,
1323            reconnect_delay_initial_ms: None,
1324            reconnect_backoff_factor: None,
1325            reconnect_delay_max_ms: None,
1326            reconnect_jitter_ms: None,
1327            reconnect_max_attempts: None,
1328            connection_max_retries: None,
1329            certs_dir: None,
1330        };
1331
1332        let client = SocketClient::connect(config, None, None, None)
1333            .await
1334            .unwrap();
1335
1336        // Wait ~3 seconds to collect some heartbeats
1337        sleep(Duration::from_secs(3)).await;
1338
1339        {
1340            let lock = received.lock().await;
1341            let pings = lock
1342                .iter()
1343                .filter(|line| line == &&b"ping".to_vec())
1344                .count();
1345            assert!(
1346                pings >= 2,
1347                "Expected at least 2 heartbeat pings; got {pings}"
1348            );
1349        }
1350
1351        client.close().await;
1352        server_task.abort();
1353    }
1354
1355    #[tokio::test]
1356    async fn test_reconnect_success() {
1357        Python::initialize();
1358
1359        let (port, listener) = bind_test_server().await;
1360
1361        // Spawn a server task that:
1362        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1363        // 2. Waits a bit and then accepts a new connection and runs the echo server
1364        let server_task = task::spawn(async move {
1365            // Accept first connection
1366            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1367
1368            // Wait briefly and then force-close the connection
1369            sleep(Duration::from_millis(500)).await;
1370            let _ = socket.shutdown().await;
1371
1372            // Wait for the client's reconnect attempt
1373            sleep(Duration::from_millis(500)).await;
1374
1375            // Run the echo server on the new connection
1376            let (socket, _) = listener.accept().await.expect("Second accept failed");
1377            run_echo_server(socket).await;
1378        });
1379
1380        let config = SocketConfig {
1381            url: format!("127.0.0.1:{port}"),
1382            mode: Mode::Plain,
1383            suffix: b"\r\n".to_vec(),
1384            message_handler: None,
1385            heartbeat: None,
1386            reconnect_timeout_ms: Some(5_000),
1387            reconnect_delay_initial_ms: Some(500),
1388            reconnect_delay_max_ms: Some(5_000),
1389            reconnect_backoff_factor: Some(2.0),
1390            reconnect_jitter_ms: Some(50),
1391            reconnect_max_attempts: None,
1392            connection_max_retries: None,
1393            certs_dir: None,
1394        };
1395
1396        let client = SocketClient::connect(config, None, None, None)
1397            .await
1398            .expect("Client connect failed unexpectedly");
1399
1400        // Initially, the client should be active
1401        assert!(client.is_active(), "Client should start as active");
1402
1403        // Wait until the client loses connection (i.e. not active),
1404        // then wait until it reconnects (active again).
1405        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1406
1407        client
1408            .send_bytes(b"TestReconnect".into())
1409            .await
1410            .expect("Send failed");
1411
1412        client.close().await;
1413        server_task.abort();
1414    }
1415}
1416
1417#[cfg(test)]
1418#[cfg(not(feature = "turmoil"))]
1419mod rust_tests {
1420    use nautilus_common::testing::wait_until_async;
1421    use rstest::rstest;
1422    use tokio::{
1423        io::{AsyncReadExt, AsyncWriteExt},
1424        net::TcpListener,
1425        task,
1426        time::{Duration, sleep},
1427    };
1428
1429    use super::*;
1430
1431    #[rstest]
1432    #[tokio::test]
1433    async fn test_reconnect_then_close() {
1434        // Bind an ephemeral port
1435        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1436        let port = listener.local_addr().unwrap().port();
1437
1438        // Server task: accept one connection and then drop it
1439        let server = task::spawn(async move {
1440            if let Ok((mut sock, _)) = listener.accept().await {
1441                drop(sock.shutdown());
1442            }
1443            // Keep listener alive briefly to avoid premature exit
1444            sleep(Duration::from_secs(1)).await;
1445        });
1446
1447        // Configure client with a short reconnect backoff
1448        let config = SocketConfig {
1449            url: format!("127.0.0.1:{port}"),
1450            mode: Mode::Plain,
1451            suffix: b"\r\n".to_vec(),
1452            message_handler: None,
1453            heartbeat: None,
1454            reconnect_timeout_ms: Some(1_000),
1455            reconnect_delay_initial_ms: Some(50),
1456            reconnect_delay_max_ms: Some(100),
1457            reconnect_backoff_factor: Some(1.0),
1458            reconnect_jitter_ms: Some(0),
1459            connection_max_retries: Some(1),
1460            reconnect_max_attempts: None,
1461            certs_dir: None,
1462        };
1463
1464        // Connect client (handler=None)
1465        let client = SocketClient::connect(config.clone(), None, None, None)
1466            .await
1467            .unwrap();
1468
1469        // Wait for client to detect dropped connection and enter reconnect state
1470        wait_until_async(
1471            || async { client.is_reconnecting() },
1472            Duration::from_secs(2),
1473        )
1474        .await;
1475
1476        // Now close the client
1477        client.close().await;
1478        assert!(client.is_closed());
1479        server.abort();
1480    }
1481
1482    #[rstest]
1483    #[tokio::test]
1484    async fn test_reconnect_state_flips_when_reader_stops() {
1485        // Bind an ephemeral port and accept a single connection which we immediately close.
1486        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1487        let port = listener.local_addr().unwrap().port();
1488
1489        let server = task::spawn(async move {
1490            if let Ok((sock, _)) = listener.accept().await {
1491                drop(sock);
1492            }
1493            // Give the client a moment to observe the closed connection.
1494            sleep(Duration::from_millis(50)).await;
1495        });
1496
1497        let config = SocketConfig {
1498            url: format!("127.0.0.1:{port}"),
1499            mode: Mode::Plain,
1500            suffix: b"\r\n".to_vec(),
1501            message_handler: None,
1502            heartbeat: None,
1503            reconnect_timeout_ms: Some(1_000),
1504            reconnect_delay_initial_ms: Some(50),
1505            reconnect_delay_max_ms: Some(100),
1506            reconnect_backoff_factor: Some(1.0),
1507            reconnect_jitter_ms: Some(0),
1508            connection_max_retries: Some(1),
1509            reconnect_max_attempts: None,
1510            certs_dir: None,
1511        };
1512
1513        let client = SocketClient::connect(config, None, None, None)
1514            .await
1515            .unwrap();
1516
1517        wait_until_async(
1518            || async { client.is_reconnecting() },
1519            Duration::from_secs(2),
1520        )
1521        .await;
1522
1523        client.close().await;
1524        server.abort();
1525    }
1526
1527    #[rstest]
1528    fn test_parse_socket_url_raw_address() {
1529        // Raw socket address with TLS mode
1530        let (socket_addr, request_url) =
1531            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1532        assert_eq!(socket_addr, "example.com:6130");
1533        assert_eq!(request_url, "wss://example.com:6130");
1534
1535        // Raw socket address with Plain mode
1536        let (socket_addr, request_url) =
1537            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1538        assert_eq!(socket_addr, "localhost:8080");
1539        assert_eq!(request_url, "ws://localhost:8080");
1540    }
1541
1542    #[rstest]
1543    fn test_parse_socket_url_with_scheme() {
1544        // Full URL with wss scheme
1545        let (socket_addr, request_url) =
1546            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1547        assert_eq!(socket_addr, "example.com:443");
1548        assert_eq!(request_url, "wss://example.com:443/path");
1549
1550        // Full URL with ws scheme
1551        let (socket_addr, request_url) =
1552            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1553        assert_eq!(socket_addr, "localhost:8080");
1554        assert_eq!(request_url, "ws://localhost:8080");
1555    }
1556
1557    #[rstest]
1558    fn test_parse_socket_url_default_ports() {
1559        // wss without explicit port defaults to 443
1560        let (socket_addr, _) =
1561            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1562        assert_eq!(socket_addr, "example.com:443");
1563
1564        // ws without explicit port defaults to 80
1565        let (socket_addr, _) =
1566            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1567        assert_eq!(socket_addr, "example.com:80");
1568
1569        // https defaults to 443
1570        let (socket_addr, _) =
1571            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1572        assert_eq!(socket_addr, "example.com:443");
1573
1574        // http defaults to 80
1575        let (socket_addr, _) =
1576            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1577        assert_eq!(socket_addr, "example.com:80");
1578    }
1579
1580    #[rstest]
1581    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1582        // Unknown scheme defaults to mode-based port
1583        let (socket_addr, _) =
1584            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1585        assert_eq!(socket_addr, "example.com:443");
1586
1587        let (socket_addr, _) =
1588            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1589        assert_eq!(socket_addr, "example.com:80");
1590    }
1591
1592    #[rstest]
1593    fn test_parse_socket_url_ipv6() {
1594        // IPv6 address with port
1595        let (socket_addr, request_url) =
1596            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1597        assert_eq!(socket_addr, "[::1]:8080");
1598        assert_eq!(request_url, "ws://[::1]:8080");
1599
1600        // IPv6 in URL
1601        let (socket_addr, _) =
1602            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1603        assert_eq!(socket_addr, "[::1]:8080");
1604    }
1605
1606    #[rstest]
1607    #[tokio::test]
1608    async fn test_url_parsing_raw_socket_address() {
1609        // Test that raw socket addresses (host:port) work correctly
1610        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1611        let port = listener.local_addr().unwrap().port();
1612
1613        let server = task::spawn(async move {
1614            if let Ok((sock, _)) = listener.accept().await {
1615                drop(sock);
1616            }
1617            sleep(Duration::from_millis(50)).await;
1618        });
1619
1620        let config = SocketConfig {
1621            url: format!("127.0.0.1:{port}"), // Raw socket address format
1622            mode: Mode::Plain,
1623            suffix: b"\r\n".to_vec(),
1624            message_handler: None,
1625            heartbeat: None,
1626            reconnect_timeout_ms: Some(1_000),
1627            reconnect_delay_initial_ms: Some(50),
1628            reconnect_delay_max_ms: Some(100),
1629            reconnect_backoff_factor: Some(1.0),
1630            reconnect_jitter_ms: Some(0),
1631            connection_max_retries: Some(1),
1632            reconnect_max_attempts: None,
1633            certs_dir: None,
1634        };
1635
1636        // Should successfully connect with raw socket address
1637        let client = SocketClient::connect(config, None, None, None).await;
1638        assert!(
1639            client.is_ok(),
1640            "Client should connect with raw socket address format"
1641        );
1642
1643        if let Ok(client) = client {
1644            client.close().await;
1645        }
1646        server.abort();
1647    }
1648
1649    #[rstest]
1650    #[tokio::test]
1651    async fn test_url_parsing_with_scheme() {
1652        // Test that URLs with schemes also work
1653        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1654        let port = listener.local_addr().unwrap().port();
1655
1656        let server = task::spawn(async move {
1657            if let Ok((sock, _)) = listener.accept().await {
1658                drop(sock);
1659            }
1660            sleep(Duration::from_millis(50)).await;
1661        });
1662
1663        let config = SocketConfig {
1664            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1665            mode: Mode::Plain,
1666            suffix: b"\r\n".to_vec(),
1667            message_handler: None,
1668            heartbeat: None,
1669            reconnect_timeout_ms: Some(1_000),
1670            reconnect_delay_initial_ms: Some(50),
1671            reconnect_delay_max_ms: Some(100),
1672            reconnect_backoff_factor: Some(1.0),
1673            reconnect_jitter_ms: Some(0),
1674            connection_max_retries: Some(1),
1675            reconnect_max_attempts: None,
1676            certs_dir: None,
1677        };
1678
1679        // Should successfully connect with URL format
1680        let client = SocketClient::connect(config, None, None, None).await;
1681        assert!(
1682            client.is_ok(),
1683            "Client should connect with URL scheme format"
1684        );
1685
1686        if let Ok(client) = client {
1687            client.close().await;
1688        }
1689        server.abort();
1690    }
1691
1692    #[rstest]
1693    fn test_parse_socket_url_ipv6_with_zone() {
1694        // IPv6 with zone ID (link-local address)
1695        let (socket_addr, request_url) =
1696            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1697        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1698        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1699
1700        // Verify zone is preserved in URL format too
1701        let (socket_addr, request_url) =
1702            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1703        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1704        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1705    }
1706
1707    #[rstest]
1708    #[tokio::test]
1709    async fn test_ipv6_loopback_connection() {
1710        // Test IPv6 loopback address connection
1711        // Skip if IPv6 is not available on the system
1712        if TcpListener::bind("[::1]:0").await.is_err() {
1713            eprintln!("IPv6 not available, skipping test");
1714            return;
1715        }
1716
1717        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1718        let port = listener.local_addr().unwrap().port();
1719
1720        let server = task::spawn(async move {
1721            if let Ok((mut sock, _)) = listener.accept().await {
1722                let mut buf = vec![0u8; 1024];
1723                if let Ok(n) = sock.read(&mut buf).await {
1724                    // Echo back
1725                    let _ = sock.write_all(&buf[..n]).await;
1726                }
1727            }
1728            sleep(Duration::from_millis(50)).await;
1729        });
1730
1731        let config = SocketConfig {
1732            url: format!("[::1]:{port}"), // IPv6 loopback
1733            mode: Mode::Plain,
1734            suffix: b"\r\n".to_vec(),
1735            message_handler: None,
1736            heartbeat: None,
1737            reconnect_timeout_ms: Some(1_000),
1738            reconnect_delay_initial_ms: Some(50),
1739            reconnect_delay_max_ms: Some(100),
1740            reconnect_backoff_factor: Some(1.0),
1741            reconnect_jitter_ms: Some(0),
1742            connection_max_retries: Some(1),
1743            reconnect_max_attempts: None,
1744            certs_dir: None,
1745        };
1746
1747        let client = SocketClient::connect(config, None, None, None).await;
1748        assert!(
1749            client.is_ok(),
1750            "Client should connect to IPv6 loopback address"
1751        );
1752
1753        if let Ok(client) = client {
1754            client.close().await;
1755        }
1756        server.abort();
1757    }
1758
1759    #[rstest]
1760    #[tokio::test]
1761    async fn test_send_waits_during_reconnection() {
1762        // Test that send operations wait for reconnection to complete (up to configured timeout)
1763        use nautilus_common::testing::wait_until_async;
1764
1765        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1766        let port = listener.local_addr().unwrap().port();
1767
1768        let server = task::spawn(async move {
1769            // First connection - accept and immediately close
1770            if let Ok((sock, _)) = listener.accept().await {
1771                drop(sock);
1772            }
1773
1774            // Wait before accepting second connection
1775            sleep(Duration::from_millis(500)).await;
1776
1777            // Second connection - accept and keep alive
1778            if let Ok((mut sock, _)) = listener.accept().await {
1779                // Echo messages
1780                let mut buf = vec![0u8; 1024];
1781                while let Ok(n) = sock.read(&mut buf).await {
1782                    if n == 0 {
1783                        break;
1784                    }
1785                    if sock.write_all(&buf[..n]).await.is_err() {
1786                        break;
1787                    }
1788                }
1789            }
1790        });
1791
1792        let config = SocketConfig {
1793            url: format!("127.0.0.1:{port}"),
1794            mode: Mode::Plain,
1795            suffix: b"\r\n".to_vec(),
1796            message_handler: None,
1797            heartbeat: None,
1798            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1799            reconnect_delay_initial_ms: Some(100),
1800            reconnect_delay_max_ms: Some(200),
1801            reconnect_backoff_factor: Some(1.0),
1802            reconnect_jitter_ms: Some(0),
1803            connection_max_retries: Some(1),
1804            reconnect_max_attempts: None,
1805            certs_dir: None,
1806        };
1807
1808        let client = SocketClient::connect(config, None, None, None)
1809            .await
1810            .unwrap();
1811
1812        // Wait for reconnection to trigger
1813        wait_until_async(
1814            || async { client.is_reconnecting() },
1815            Duration::from_secs(2),
1816        )
1817        .await;
1818
1819        // Try to send while reconnecting - should wait and succeed after reconnect
1820        let send_result = tokio::time::timeout(
1821            Duration::from_secs(3),
1822            client.send_bytes(b"test_message".to_vec()),
1823        )
1824        .await;
1825
1826        assert!(
1827            send_result.is_ok() && send_result.unwrap().is_ok(),
1828            "Send should succeed after waiting for reconnection"
1829        );
1830
1831        client.close().await;
1832        server.abort();
1833    }
1834
1835    #[rstest]
1836    #[tokio::test]
1837    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1838        // Test that send_bytes operations respect the configured reconnect_timeout.
1839        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1840        use nautilus_common::testing::wait_until_async;
1841
1842        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1843        let port = listener.local_addr().unwrap().port();
1844
1845        let server = task::spawn(async move {
1846            // Accept first connection and immediately close it
1847            if let Ok((sock, _)) = listener.accept().await {
1848                drop(sock);
1849            }
1850            // Drop listener entirely so reconnection fails completely
1851            drop(listener);
1852            sleep(Duration::from_secs(60)).await;
1853        });
1854
1855        let config = SocketConfig {
1856            url: format!("127.0.0.1:{port}"),
1857            mode: Mode::Plain,
1858            suffix: b"\r\n".to_vec(),
1859            message_handler: None,
1860            heartbeat: None,
1861            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1862            reconnect_delay_initial_ms: Some(200), // Short backoff (but > timeout) to keep client in RECONNECT
1863            reconnect_delay_max_ms: Some(200),
1864            reconnect_backoff_factor: Some(1.0),
1865            reconnect_jitter_ms: Some(0),
1866            connection_max_retries: Some(1),
1867            reconnect_max_attempts: None,
1868            certs_dir: None,
1869        };
1870
1871        let client = SocketClient::connect(config, None, None, None)
1872            .await
1873            .unwrap();
1874
1875        // Wait for client to enter RECONNECT state
1876        wait_until_async(
1877            || async { client.is_reconnecting() },
1878            Duration::from_secs(3),
1879        )
1880        .await;
1881
1882        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1883        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1884        let start = std::time::Instant::now();
1885        let send_result = client.send_bytes(b"test".to_vec()).await;
1886        let elapsed = start.elapsed();
1887
1888        assert!(
1889            send_result.is_err(),
1890            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
1891        );
1892        assert!(
1893            matches!(send_result, Err(crate::error::SendError::Timeout)),
1894            "Send should return Timeout error, was: {send_result:?}"
1895        );
1896        // Verify timeout respects configured value (1s), but don't check upper bound
1897        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1898        assert!(
1899            elapsed >= Duration::from_millis(900),
1900            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
1901        );
1902
1903        client.close().await;
1904        server.abort();
1905    }
1906}