nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    fmt::Debug,
33    path::Path,
34    sync::{
35        Arc,
36        atomic::{AtomicU8, Ordering},
37    },
38    time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_core::CleanDrop;
43use nautilus_cryptography::providers::install_cryptographic_provider;
44use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
45use tokio_tungstenite::{
46    MaybeTlsStream,
47    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
48};
49
50use crate::{
51    backoff::ExponentialBackoff,
52    error::SendError,
53    fix::process_fix_buffer,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
// Connection timing constants
/// Interval (milliseconds) after which the read and write tasks time out their pending I/O to
/// re-check the connection mode; also the polling interval used by `SocketClient::close`.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Delay (milliseconds) applied during a reconnect before the previous read task is aborted.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on a writer shutdown and on waiting for the client to reach CLOSED.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Polling interval (milliseconds) used while a send waits for the client to become ACTIVE.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

/// Write half of the (optionally TLS-wrapped) TCP stream.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of the (optionally TLS-wrapped) TCP stream.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Shared callback invoked by the read task with the bytes of each received message.
pub type TcpMessageHandler = Arc<dyn Fn(&[u8]) + Send + Sync>;
69
/// Configuration for TCP socket connection.
///
/// All `reconnect_*` options are optional; when `None`, the client applies the default
/// noted on the field at connect time.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// The URL to connect to: either a raw `host:port` socket address or a full
    /// `scheme://host:port/path` URL.
    pub url: String,
    /// The connection mode {Plain, TLS}.
    pub mode: Mode,
    /// The sequence of bytes which separates lines. It is appended to every sent message
    /// and used to split the received byte stream.
    pub suffix: Vec<u8>,
    /// The optional function to handle incoming messages.
    pub message_handler: Option<TcpMessageHandler>,
    /// The optional heartbeat as `(interval in seconds, message bytes)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// The timeout (milliseconds) for reconnection attempts (default 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// The path to the certificates directory. When set, a rustls connector is built
    /// from its contents at connect time.
    pub certs_dir: Option<String>,
}
99
100impl Debug for SocketConfig {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct(stringify!(SocketConfig))
103            .field("url", &self.url)
104            .field("mode", &self.mode)
105            .field("suffix", &self.suffix)
106            .field(
107                "message_handler",
108                &self.message_handler.as_ref().map(|_| "<function>"),
109            )
110            .field("heartbeat", &self.heartbeat)
111            .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
112            .field(
113                "reconnect_delay_initial_ms",
114                &self.reconnect_delay_initial_ms,
115            )
116            .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
117            .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
118            .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
119            .field("certs_dir", &self.certs_dir)
120            .finish()
121    }
122}
123
124impl Clone for SocketConfig {
125    fn clone(&self) -> Self {
126        Self {
127            url: self.url.clone(),
128            mode: self.mode,
129            suffix: self.suffix.clone(),
130            message_handler: self.message_handler.clone(),
131            heartbeat: self.heartbeat.clone(),
132            reconnect_timeout_ms: self.reconnect_timeout_ms,
133            reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
134            reconnect_delay_max_ms: self.reconnect_delay_max_ms,
135            reconnect_backoff_factor: self.reconnect_backoff_factor,
136            reconnect_jitter_ms: self.reconnect_jitter_ms,
137            certs_dir: self.certs_dir.clone(),
138        }
139    }
140}
141
/// Represents a command for the writer task.
///
/// Commands are delivered over an unbounded channel and handled one at a time by the
/// task spawned in `SocketClientInner::spawn_write_task`.
#[derive(Debug)]
pub enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    /// The writer task shuts the previous writer down before switching over.
    Update(TcpWriter),
    /// Send data to the server. The writer task appends the configured suffix.
    Send(Bytes),
}
150
/// Internal state of a socket client: the connection configuration plus the handles of
/// the background tasks that service one `TcpStream` with the server.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// The configuration the client was connected with; reused on reconnect.
    config: SocketConfig,
    /// The TLS connector built from `config.certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced with a fresh task on reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task, which owns the write half of the stream.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the channel feeding the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// The shared connection mode, stored as the `u8` form of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff state built from the `reconnect_*` configuration options.
    backoff: ExponentialBackoff,
    /// Copy of `config.message_handler`, handed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
}
182
183impl SocketClientInner {
184    /// Connect to a URL with the specified configuration.
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if connection fails or configuration is invalid.
189    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
190        install_cryptographic_provider();
191
192        let SocketConfig {
193            url,
194            mode,
195            heartbeat,
196            suffix,
197            message_handler,
198            reconnect_timeout_ms,
199            reconnect_delay_initial_ms,
200            reconnect_delay_max_ms,
201            reconnect_backoff_factor,
202            reconnect_jitter_ms,
203            certs_dir,
204        } = &config.clone();
205        let connector = if let Some(dir) = certs_dir {
206            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
207            Some(Connector::Rustls(Arc::new(config)))
208        } else {
209            None
210        };
211
212        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
213        tracing::debug!("Connected");
214
215        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
216
217        let read_task = Arc::new(Self::spawn_read_task(
218            connection_mode.clone(),
219            reader,
220            message_handler.clone(),
221            suffix.clone(),
222        ));
223
224        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
225
226        let write_task =
227            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
228
229        // Optionally spawn a heartbeat task to periodically ping server
230        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
231            Self::spawn_heartbeat_task(
232                connection_mode.clone(),
233                heartbeat.clone(),
234                writer_tx.clone(),
235            )
236        });
237
238        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
239        let backoff = ExponentialBackoff::new(
240            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
241            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
242            reconnect_backoff_factor.unwrap_or(1.5),
243            reconnect_jitter_ms.unwrap_or(100),
244            true, // immediate-first
245        )?;
246
247        Ok(Self {
248            config,
249            connector,
250            read_task,
251            write_task,
252            writer_tx,
253            heartbeat_task,
254            connection_mode,
255            reconnect_timeout,
256            backoff,
257            handler: message_handler.clone(),
258        })
259    }
260
261    /// Parse URL and extract socket address and request URL.
262    ///
263    /// Accepts either:
264    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
265    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
266    ///
267    /// # Errors
268    ///
269    /// Returns an error if the URL is invalid or missing required components.
270    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
271        if url.contains("://") {
272            // URL with scheme (e.g., "wss://host:port/path")
273            let parsed = url.parse::<http::Uri>().map_err(|e| {
274                Error::Io(std::io::Error::new(
275                    std::io::ErrorKind::InvalidInput,
276                    format!("Invalid URL: {e}"),
277                ))
278            })?;
279
280            let host = parsed.host().ok_or_else(|| {
281                Error::Io(std::io::Error::new(
282                    std::io::ErrorKind::InvalidInput,
283                    "URL missing host",
284                ))
285            })?;
286
287            let port = parsed
288                .port_u16()
289                .unwrap_or_else(|| match parsed.scheme_str() {
290                    Some("wss" | "https") => 443,
291                    Some("ws" | "http") => 80,
292                    _ => match mode {
293                        Mode::Tls => 443,
294                        Mode::Plain => 80,
295                    },
296                });
297
298            Ok((format!("{host}:{port}"), url.to_string()))
299        } else {
300            // Raw socket address (e.g., "host:port")
301            // Construct a proper URL for the request based on mode
302            let scheme = match mode {
303                Mode::Tls => "wss",
304                Mode::Plain => "ws",
305            };
306            Ok((url.to_string(), format!("{scheme}://{url}")))
307        }
308    }
309
310    /// Establish a TLS or plain TCP connection with the server.
311    ///
312    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
313    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
314    /// For WebSocket-style connections, include the scheme.
315    ///
316    /// # Errors
317    ///
318    /// Returns an error if the connection cannot be established.
319    pub async fn tls_connect_with_server(
320        url: &str,
321        mode: Mode,
322        connector: Option<Connector>,
323    ) -> Result<(TcpReader, TcpWriter), Error> {
324        tracing::debug!("Connecting to {url}");
325
326        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
327        let tcp_result = TcpStream::connect(&socket_addr).await;
328
329        match tcp_result {
330            Ok(stream) => {
331                tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
332                let request = request_url.into_client_request()?;
333                tcp_tls(&request, mode, stream, connector)
334                    .await
335                    .map(tokio::io::split)
336            }
337            Err(e) => {
338                tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
339                Err(Error::Io(e))
340            }
341        }
342    }
343
    /// Reconnect with server.
    ///
    /// Makes a new connection with server, uses the new read and write halves
    /// to update the reader and writer.
    ///
    /// The whole attempt is bounded by `reconnect_timeout`. If a disconnect is requested
    /// at any checkpoint, the attempt is abandoned and `Ok(())` is returned without
    /// switching the mode back to `Active`.
    ///
    /// # Errors
    ///
    /// Returns an error if the new connection cannot be established, or a `TimedOut`
    /// I/O error if the attempt exceeds `reconnect_timeout`.
    async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        // Checkpoint 1: nothing to do if a disconnect was already requested
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Only `url`, `mode` and `suffix` are needed; the remaining fields are
            // listed explicitly and ignored.
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                certs_dir: _,
            } = &self.config;
            // Create a fresh connection
            let connector = self.connector.clone();
            // Attempt to connect; a failure propagates out of the timed block
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            // Checkpoint 2: the new stream halves are simply dropped on abort
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            tracing::debug!("Connected");

            // Hand the new write half to the writer task; a send failure is only logged
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
                tracing::error!("{e}");
            }

            // Delay before closing connection
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            // Checkpoint 3: re-check after the delay
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the previous read task before a replacement is spawned
            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn new read task
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
            ));

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Outer `Err` means the timeout elapsed; the trailing `?` then yields the inner result
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
436
437    /// Check if the client is still alive.
438    ///
439    /// The client is connected if the read task has not finished. It is expected
440    /// that in case of any failure client or server side. The read task will be
441    /// shutdown. There might be some delay between the connection being closed
442    /// and the client detecting it.
443    #[inline]
444    #[must_use]
445    pub fn is_alive(&self) -> bool {
446        !self.read_task.is_finished()
447    }
448
449    #[must_use]
450    fn spawn_read_task(
451        connection_state: Arc<AtomicU8>,
452        mut reader: TcpReader,
453        handler: Option<TcpMessageHandler>,
454        suffix: Vec<u8>,
455    ) -> tokio::task::JoinHandle<()> {
456        log_task_started("read");
457
458        // Interval between checking the connection mode
459        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
460
461        tokio::task::spawn(async move {
462            let mut buf = Vec::new();
463
464            loop {
465                if !ConnectionMode::from_atomic(&connection_state).is_active() {
466                    break;
467                }
468
469                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
470                    // Connection has been terminated or vector buffer is complete
471                    Ok(Ok(0)) => {
472                        tracing::debug!("Connection closed by server");
473                        break;
474                    }
475                    Ok(Err(e)) => {
476                        tracing::debug!("Connection ended: {e}");
477                        break;
478                    }
479                    // Received bytes of data
480                    Ok(Ok(bytes)) => {
481                        tracing::trace!("Received <binary> {bytes} bytes");
482
483                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
484                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
485
486                        if is_fix && handler.is_some() {
487                            // FIX protocol processing
488                            if let Some(ref handler) = handler {
489                                process_fix_buffer(&mut buf, handler);
490                            }
491                        } else {
492                            // Regular suffix-based message processing
493                            while let Some((i, _)) = &buf
494                                .windows(suffix.len())
495                                .enumerate()
496                                .find(|(_, pair)| pair.eq(&suffix))
497                            {
498                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
499                                data.truncate(data.len() - suffix.len());
500
501                                if let Some(ref handler) = handler {
502                                    handler(&data);
503                                }
504                            }
505                        }
506                    }
507                    Err(_) => {
508                        // Timeout - continue loop and check connection mode
509                        continue;
510                    }
511                }
512            }
513
514            log_task_stopped("read");
515        })
516    }
517
518    fn spawn_write_task(
519        connection_state: Arc<AtomicU8>,
520        writer: TcpWriter,
521        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
522        suffix: Vec<u8>,
523    ) -> tokio::task::JoinHandle<()> {
524        log_task_started("write");
525
526        // Interval between checking the connection mode
527        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
528
529        tokio::task::spawn(async move {
530            let mut active_writer = writer;
531
532            loop {
533                if matches!(
534                    ConnectionMode::from_atomic(&connection_state),
535                    ConnectionMode::Disconnect | ConnectionMode::Closed
536                ) {
537                    break;
538                }
539
540                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
541                    Ok(Some(msg)) => {
542                        // Re-check connection mode after receiving a message
543                        let mode = ConnectionMode::from_atomic(&connection_state);
544                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
545                            break;
546                        }
547
548                        match msg {
549                            WriterCommand::Update(new_writer) => {
550                                tracing::debug!("Received new writer");
551
552                                // Delay before closing connection
553                                tokio::time::sleep(Duration::from_millis(100)).await;
554
555                                // Attempt to shutdown the writer gracefully before updating,
556                                // we ignore any error as the writer may already be closed.
557                                _ = tokio::time::timeout(
558                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
559                                    active_writer.shutdown(),
560                                )
561                                .await;
562
563                                active_writer = new_writer;
564                                tracing::debug!("Updated writer");
565                            }
566                            _ if mode.is_reconnect() => {
567                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
568                                continue;
569                            }
570                            WriterCommand::Send(msg) => {
571                                if let Err(e) = active_writer.write_all(&msg).await {
572                                    tracing::error!("Failed to send message: {e}");
573                                    // Mode is active so trigger reconnection
574                                    tracing::warn!("Writer triggering reconnect");
575                                    connection_state
576                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
577                                    continue;
578                                }
579                                if let Err(e) = active_writer.write_all(&suffix).await {
580                                    tracing::error!("Failed to send suffix: {e}");
581                                    // Mode is active so trigger reconnection
582                                    tracing::warn!("Writer triggering reconnect");
583                                    connection_state
584                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
585                                    continue;
586                                }
587                            }
588                        }
589                    }
590                    Ok(None) => {
591                        // Channel closed - writer task should terminate
592                        tracing::debug!("Writer channel closed, terminating writer task");
593                        break;
594                    }
595                    Err(_) => {
596                        // Timeout - just continue the loop
597                        continue;
598                    }
599                }
600            }
601
602            // Attempt to shutdown the writer gracefully before exiting,
603            // we ignore any error as the writer may already be closed.
604            _ = tokio::time::timeout(
605                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
606                active_writer.shutdown(),
607            )
608            .await;
609
610            log_task_stopped("write");
611        })
612    }
613
614    fn spawn_heartbeat_task(
615        connection_state: Arc<AtomicU8>,
616        heartbeat: (u64, Vec<u8>),
617        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
618    ) -> tokio::task::JoinHandle<()> {
619        log_task_started("heartbeat");
620        let (interval_secs, message) = heartbeat;
621
622        tokio::task::spawn(async move {
623            let interval = Duration::from_secs(interval_secs);
624
625            loop {
626                tokio::time::sleep(interval).await;
627
628                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
629                    ConnectionMode::Active => {
630                        let msg = WriterCommand::Send(message.clone().into());
631
632                        match writer_tx.send(msg) {
633                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
634                            Err(e) => {
635                                tracing::error!("Failed to send heartbeat to writer task: {e}");
636                            }
637                        }
638                    }
639                    ConnectionMode::Reconnect => continue,
640                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
641                }
642            }
643
644            log_task_stopped("heartbeat");
645        })
646    }
647}
648
649impl Drop for SocketClientInner {
650    fn drop(&mut self) {
651        // Delegate to explicit cleanup handler
652        self.clean_drop();
653    }
654}
655
656/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
657impl CleanDrop for SocketClientInner {
658    fn clean_drop(&mut self) {
659        if !self.read_task.is_finished() {
660            self.read_task.abort();
661            log_task_aborted("read");
662        }
663
664        if !self.write_task.is_finished() {
665            self.write_task.abort();
666            log_task_aborted("write");
667        }
668
669        if let Some(ref handle) = self.heartbeat_task.take()
670            && !handle.is_finished()
671        {
672            handle.abort();
673            log_task_aborted("heartbeat");
674        }
675
676        #[cfg(feature = "python")]
677        {
678            // Remove stored handler to break ref cycle
679            self.config.message_handler = None;
680        }
681    }
682}
683
/// Public handle to a socket connection.
///
/// The connection state itself lives in a `SocketClientInner` that is handed to the
/// controller task at connect time; this handle keeps only what is needed to observe
/// the connection mode, queue outgoing messages and request a close.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task spawned at connect time.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// The connection mode shared with the background tasks, stored as a `u8`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Reconnection timeout; also bounds how long a send waits for the client to be active.
    pub(crate) reconnect_timeout: Duration,
    /// Sender side of the channel feeding the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
694
695impl Debug for SocketClient {
696    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
697        f.debug_struct(stringify!(SocketClient)).finish()
698    }
699}
700
701impl SocketClient {
702    /// Connect to the server.
703    ///
704    /// # Errors
705    ///
706    /// Returns any error connecting to the server.
707    pub async fn connect(
708        config: SocketConfig,
709        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
710        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
711        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
712    ) -> anyhow::Result<Self> {
713        let inner = SocketClientInner::connect_url(config).await?;
714        let writer_tx = inner.writer_tx.clone();
715        let connection_mode = inner.connection_mode.clone();
716        let reconnect_timeout = inner.reconnect_timeout;
717
718        let controller_task = Self::spawn_controller_task(
719            inner,
720            connection_mode.clone(),
721            post_reconnection,
722            post_disconnection,
723        );
724
725        if let Some(handler) = post_connection {
726            handler();
727            tracing::debug!("Called `post_connection` handler");
728        }
729
730        Ok(Self {
731            controller_task,
732            connection_mode,
733            reconnect_timeout,
734            writer_tx,
735        })
736    }
737
738    /// Returns the current connection mode.
739    #[must_use]
740    pub fn connection_mode(&self) -> ConnectionMode {
741        ConnectionMode::from_atomic(&self.connection_mode)
742    }
743
744    /// Check if the client connection is active.
745    ///
746    /// Returns `true` if the client is connected and has not been signalled to disconnect.
747    /// The client will automatically retry connection based on its configuration.
748    #[inline]
749    #[must_use]
750    pub fn is_active(&self) -> bool {
751        self.connection_mode().is_active()
752    }
753
754    /// Check if the client is reconnecting.
755    ///
756    /// Returns `true` if the client lost connection and is attempting to reestablish it.
757    /// The client will automatically retry connection based on its configuration.
758    #[inline]
759    #[must_use]
760    pub fn is_reconnecting(&self) -> bool {
761        self.connection_mode().is_reconnect()
762    }
763
764    /// Check if the client is disconnecting.
765    ///
766    /// Returns `true` if the client is in disconnect mode.
767    #[inline]
768    #[must_use]
769    pub fn is_disconnecting(&self) -> bool {
770        self.connection_mode().is_disconnect()
771    }
772
773    /// Check if the client is closed.
774    ///
775    /// Returns `true` if the client has been explicitly disconnected or reached
776    /// maximum reconnection attempts. In this state, the client cannot be reused
777    /// and a new client must be created for further connections.
778    #[inline]
779    #[must_use]
780    pub fn is_closed(&self) -> bool {
781        self.connection_mode().is_closed()
782    }
783
784    /// Close the client.
785    ///
786    /// Controller task will periodically check the disconnect mode
787    /// and shutdown the client if it is not alive.
788    pub async fn close(&self) {
789        self.connection_mode
790            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
791
792        if let Ok(()) =
793            tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
794                while !self.is_closed() {
795                    tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
796                        .await;
797                }
798
799                if !self.controller_task.is_finished() {
800                    self.controller_task.abort();
801                    log_task_aborted("controller");
802                }
803            })
804            .await
805        {
806            log_task_stopped("controller");
807        } else {
808            tracing::error!("Timeout waiting for controller task to finish");
809            if !self.controller_task.is_finished() {
810                self.controller_task.abort();
811                log_task_aborted("controller");
812            }
813        }
814    }
815
816    /// Sends a message of the given `data`.
817    ///
818    /// # Errors
819    ///
820    /// Returns an error if sending fails.
821    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
822        if self.is_closed() {
823            return Err(SendError::Closed);
824        }
825
826        let timeout = self.reconnect_timeout;
827        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
828
829        if !self.is_active() {
830            tracing::debug!("Waiting for client to become ACTIVE before sending...");
831
832            let inner = tokio::time::timeout(timeout, async {
833                loop {
834                    if self.is_active() {
835                        return Ok(());
836                    }
837                    if matches!(
838                        self.connection_mode(),
839                        ConnectionMode::Disconnect | ConnectionMode::Closed
840                    ) {
841                        return Err(());
842                    }
843                    tokio::time::sleep(check_interval).await;
844                }
845            })
846            .await
847            .map_err(|_| SendError::Timeout)?;
848            inner.map_err(|()| SendError::Closed)?;
849        }
850
851        let msg = WriterCommand::Send(data.into());
852        self.writer_tx
853            .send(msg)
854            .map_err(|e| SendError::BrokenPipe(e.to_string()))
855    }
856
    /// Spawns the controller task that drives the client's lifecycle.
    ///
    /// The task wakes every `CONNECTION_STATE_CHECK_INTERVAL_MS`, reads the shared
    /// connection mode and then:
    /// - on disconnect: aborts the read and heartbeat tasks, calls `post_disconnection`
    ///   (if any) and leaves the loop;
    /// - on active with a dead connection (`inner.is_alive()` is false): switches the
    ///   mode to reconnect;
    /// - on reconnect: calls `inner.reconnect()`, invoking `post_reconnection` on success
    ///   (only if the mode is active afterwards) or sleeping for the next backoff
    ///   duration on failure.
    ///
    /// After the loop ends the mode is set to closed.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                // Snapshot of the shared mode for this iteration
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    // NOTE: the handler runs before the mode is stored as closed below
                    if let Some(ref handler) = post_disconnection {
                        handler();
                        tracing::debug!("Called `post_disconnection` handler");
                    }
                    break; // Controller finished
                }

                if mode.is_active() && !inner.is_alive() {
                    // Compare-exchange so a mode changed concurrently (e.g. to disconnect)
                    // is not overwritten with reconnect
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read so the reconnect check below sees the current mode
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    match inner.reconnect().await {
                        Ok(()) => {
                            tracing::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            // Only invoke reconnect handler if still active
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = post_reconnection {
                                    handler();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // NOTE(review): stored through `inner.connection_mode`, presumably the same
            // atomic as the `connection_mode` argument; confirm against the caller
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
957}
958
959// Abort controller task on drop to clean up background tasks
960impl Drop for SocketClient {
961    fn drop(&mut self) {
962        if !self.controller_task.is_finished() {
963            self.controller_task.abort();
964            log_task_aborted("controller");
965        }
966    }
967}
968
969////////////////////////////////////////////////////////////////////////////////
970// Tests
971////////////////////////////////////////////////////////////////////////////////
972
973#[cfg(test)]
974#[cfg(feature = "python")]
975#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
976mod tests {
977    use nautilus_common::testing::wait_until_async;
978    use pyo3::Python;
979    use tokio::{
980        io::{AsyncReadExt, AsyncWriteExt},
981        net::{TcpListener, TcpStream},
982        sync::Mutex,
983        task,
984        time::{Duration, sleep},
985    };
986
987    use super::*;
988
989    async fn bind_test_server() -> (u16, TcpListener) {
990        let listener = TcpListener::bind("127.0.0.1:0")
991            .await
992            .expect("Failed to bind ephemeral port");
993        let port = listener.local_addr().unwrap().port();
994        (port, listener)
995    }
996
997    async fn run_echo_server(mut socket: TcpStream) {
998        let mut buf = Vec::new();
999        loop {
1000            match socket.read_buf(&mut buf).await {
1001                Ok(0) => {
1002                    break;
1003                }
1004                Ok(_n) => {
1005                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1006                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1007                        // Remove trailing \r\n
1008                        line.truncate(line.len() - 2);
1009
1010                        if line == b"close" {
1011                            let _ = socket.shutdown().await;
1012                            return;
1013                        }
1014
1015                        let mut echo_data = line;
1016                        echo_data.extend_from_slice(b"\r\n");
1017                        if socket.write_all(&echo_data).await.is_err() {
1018                            break;
1019                        }
1020                    }
1021                }
1022                Err(e) => {
1023                    eprintln!("Server read error: {e}");
1024                    break;
1025                }
1026            }
1027        }
1028    }
1029
1030    #[tokio::test]
1031    async fn test_basic_send_receive() {
1032        Python::initialize();
1033
1034        let (port, listener) = bind_test_server().await;
1035        let server_task = task::spawn(async move {
1036            let (socket, _) = listener.accept().await.unwrap();
1037            run_echo_server(socket).await;
1038        });
1039
1040        let config = SocketConfig {
1041            url: format!("127.0.0.1:{port}"),
1042            mode: Mode::Plain,
1043            suffix: b"\r\n".to_vec(),
1044            message_handler: None,
1045            heartbeat: None,
1046            reconnect_timeout_ms: None,
1047            reconnect_delay_initial_ms: None,
1048            reconnect_backoff_factor: None,
1049            reconnect_delay_max_ms: None,
1050            reconnect_jitter_ms: None,
1051            certs_dir: None,
1052        };
1053
1054        let client = SocketClient::connect(config, None, None, None)
1055            .await
1056            .expect("Client connect failed unexpectedly");
1057
1058        client.send_bytes(b"Hello".into()).await.unwrap();
1059        client.send_bytes(b"World".into()).await.unwrap();
1060
1061        // Wait a bit for the server to echo them back
1062        sleep(Duration::from_millis(100)).await;
1063
1064        client.send_bytes(b"close".into()).await.unwrap();
1065        server_task.await.unwrap();
1066        assert!(!client.is_closed());
1067    }
1068
1069    #[tokio::test]
1070    async fn test_reconnect_fail_exhausted() {
1071        Python::initialize();
1072
1073        let (port, listener) = bind_test_server().await;
1074        drop(listener); // We drop it immediately -> no server is listening
1075
1076        let config = SocketConfig {
1077            url: format!("127.0.0.1:{port}"),
1078            mode: Mode::Plain,
1079            suffix: b"\r\n".to_vec(),
1080            message_handler: None,
1081            heartbeat: None,
1082            reconnect_timeout_ms: None,
1083            reconnect_delay_initial_ms: None,
1084            reconnect_backoff_factor: None,
1085            reconnect_delay_max_ms: None,
1086            reconnect_jitter_ms: None,
1087            certs_dir: None,
1088        };
1089
1090        let client_res = SocketClient::connect(config, None, None, None).await;
1091        assert!(
1092            client_res.is_err(),
1093            "Should fail quickly with no server listening"
1094        );
1095    }
1096
1097    #[tokio::test]
1098    async fn test_user_disconnect() {
1099        Python::initialize();
1100
1101        let (port, listener) = bind_test_server().await;
1102        let server_task = task::spawn(async move {
1103            let (socket, _) = listener.accept().await.unwrap();
1104            let mut buf = [0u8; 1024];
1105            let _ = socket.try_read(&mut buf);
1106
1107            loop {
1108                sleep(Duration::from_secs(1)).await;
1109            }
1110        });
1111
1112        let config = SocketConfig {
1113            url: format!("127.0.0.1:{port}"),
1114            mode: Mode::Plain,
1115            suffix: b"\r\n".to_vec(),
1116            message_handler: None,
1117            heartbeat: None,
1118            reconnect_timeout_ms: None,
1119            reconnect_delay_initial_ms: None,
1120            reconnect_backoff_factor: None,
1121            reconnect_delay_max_ms: None,
1122            reconnect_jitter_ms: None,
1123            certs_dir: None,
1124        };
1125
1126        let client = SocketClient::connect(config, None, None, None)
1127            .await
1128            .unwrap();
1129
1130        client.close().await;
1131        assert!(client.is_closed());
1132        server_task.abort();
1133    }
1134
1135    #[tokio::test]
1136    async fn test_heartbeat() {
1137        Python::initialize();
1138
1139        let (port, listener) = bind_test_server().await;
1140        let received = Arc::new(Mutex::new(Vec::new()));
1141        let received2 = received.clone();
1142
1143        let server_task = task::spawn(async move {
1144            let (socket, _) = listener.accept().await.unwrap();
1145
1146            let mut buf = Vec::new();
1147            loop {
1148                match socket.try_read_buf(&mut buf) {
1149                    Ok(0) => break,
1150                    Ok(_) => {
1151                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1152                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1153                            line.truncate(line.len() - 2);
1154                            received2.lock().await.push(line);
1155                        }
1156                    }
1157                    Err(_) => {
1158                        tokio::time::sleep(Duration::from_millis(10)).await;
1159                    }
1160                }
1161            }
1162        });
1163
1164        // Heartbeat every 1 second
1165        let heartbeat = Some((1, b"ping".to_vec()));
1166
1167        let config = SocketConfig {
1168            url: format!("127.0.0.1:{port}"),
1169            mode: Mode::Plain,
1170            suffix: b"\r\n".to_vec(),
1171            message_handler: None,
1172            heartbeat,
1173            reconnect_timeout_ms: None,
1174            reconnect_delay_initial_ms: None,
1175            reconnect_backoff_factor: None,
1176            reconnect_delay_max_ms: None,
1177            reconnect_jitter_ms: None,
1178            certs_dir: None,
1179        };
1180
1181        let client = SocketClient::connect(config, None, None, None)
1182            .await
1183            .unwrap();
1184
1185        // Wait ~3 seconds to collect some heartbeats
1186        sleep(Duration::from_secs(3)).await;
1187
1188        {
1189            let lock = received.lock().await;
1190            let pings = lock
1191                .iter()
1192                .filter(|line| line == &&b"ping".to_vec())
1193                .count();
1194            assert!(
1195                pings >= 2,
1196                "Expected at least 2 heartbeat pings; got {pings}"
1197            );
1198        }
1199
1200        client.close().await;
1201        server_task.abort();
1202    }
1203
    /// The server drops the first connection and later accepts a second one; the client
    /// must be able to send on an active connection afterwards.
    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        // Spawn a server task that:
        // 1. Accepts the first connection and then shuts it down after a short delay (simulate disconnect)
        // 2. Waits a bit and then accepts a new connection and runs the echo server
        let server_task = task::spawn(async move {
            // Accept first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Wait briefly and then force-close the connection
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            // Wait for the client's reconnect attempt
            sleep(Duration::from_millis(500)).await;

            // Run the echo server on the new connection
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Wait (up to 10s) until the client reports active.
        // NOTE(review): this does not first wait for the client to leave the active state,
        // so it likely passes immediately on the initial connection without observing the
        // disconnect; the send below then relies on `send_bytes` waiting for ACTIVE if the
        // connection drops in the meantime. Confirm whether that is the intended coverage.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
1262}
1263
1264#[cfg(test)]
1265#[cfg(not(feature = "turmoil"))]
1266mod rust_tests {
1267    use rstest::rstest;
1268    use tokio::{
1269        io::{AsyncReadExt, AsyncWriteExt},
1270        net::TcpListener,
1271        task,
1272        time::{Duration, sleep},
1273    };
1274
1275    use super::*;
1276
1277    #[rstest]
1278    #[tokio::test]
1279    async fn test_reconnect_then_close() {
1280        // Bind an ephemeral port
1281        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1282        let port = listener.local_addr().unwrap().port();
1283
1284        // Server task: accept one connection and then drop it
1285        let server = task::spawn(async move {
1286            if let Ok((mut sock, _)) = listener.accept().await {
1287                drop(sock.shutdown());
1288            }
1289            // Keep listener alive briefly to avoid premature exit
1290            sleep(Duration::from_secs(1)).await;
1291        });
1292
1293        // Configure client with a short reconnect backoff
1294        let config = SocketConfig {
1295            url: format!("127.0.0.1:{port}"),
1296            mode: Mode::Plain,
1297            suffix: b"\r\n".to_vec(),
1298            message_handler: None,
1299            heartbeat: None,
1300            reconnect_timeout_ms: Some(1_000),
1301            reconnect_delay_initial_ms: Some(50),
1302            reconnect_delay_max_ms: Some(100),
1303            reconnect_backoff_factor: Some(1.0),
1304            reconnect_jitter_ms: Some(0),
1305            certs_dir: None,
1306        };
1307
1308        // Connect client (handler=None)
1309        let client = SocketClient::connect(config.clone(), None, None, None)
1310            .await
1311            .unwrap();
1312
1313        // Allow server to drop connection and client to notice
1314        sleep(Duration::from_millis(100)).await;
1315
1316        // Now close the client
1317        client.close().await;
1318        assert!(client.is_closed());
1319        server.abort();
1320    }
1321
1322    #[rstest]
1323    #[tokio::test]
1324    async fn test_reconnect_state_flips_when_reader_stops() {
1325        // Bind an ephemeral port and accept a single connection which we immediately close.
1326        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1327        let port = listener.local_addr().unwrap().port();
1328
1329        let server = task::spawn(async move {
1330            if let Ok((sock, _)) = listener.accept().await {
1331                drop(sock);
1332            }
1333            // Give the client a moment to observe the closed connection.
1334            sleep(Duration::from_millis(50)).await;
1335        });
1336
1337        let config = SocketConfig {
1338            url: format!("127.0.0.1:{port}"),
1339            mode: Mode::Plain,
1340            suffix: b"\r\n".to_vec(),
1341            message_handler: None,
1342            heartbeat: None,
1343            reconnect_timeout_ms: Some(1_000),
1344            reconnect_delay_initial_ms: Some(50),
1345            reconnect_delay_max_ms: Some(100),
1346            reconnect_backoff_factor: Some(1.0),
1347            reconnect_jitter_ms: Some(0),
1348            certs_dir: None,
1349        };
1350
1351        let client = SocketClient::connect(config, None, None, None)
1352            .await
1353            .unwrap();
1354
1355        tokio::time::timeout(Duration::from_secs(2), async {
1356            loop {
1357                if client.is_reconnecting() {
1358                    break;
1359                }
1360                tokio::time::sleep(Duration::from_millis(10)).await;
1361            }
1362        })
1363        .await
1364        .expect("client did not enter RECONNECT state");
1365
1366        client.close().await;
1367        server.abort();
1368    }
1369
1370    #[rstest]
1371    fn test_parse_socket_url_raw_address() {
1372        // Raw socket address with TLS mode
1373        let (socket_addr, request_url) =
1374            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1375        assert_eq!(socket_addr, "example.com:6130");
1376        assert_eq!(request_url, "wss://example.com:6130");
1377
1378        // Raw socket address with Plain mode
1379        let (socket_addr, request_url) =
1380            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1381        assert_eq!(socket_addr, "localhost:8080");
1382        assert_eq!(request_url, "ws://localhost:8080");
1383    }
1384
1385    #[rstest]
1386    fn test_parse_socket_url_with_scheme() {
1387        // Full URL with wss scheme
1388        let (socket_addr, request_url) =
1389            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1390        assert_eq!(socket_addr, "example.com:443");
1391        assert_eq!(request_url, "wss://example.com:443/path");
1392
1393        // Full URL with ws scheme
1394        let (socket_addr, request_url) =
1395            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1396        assert_eq!(socket_addr, "localhost:8080");
1397        assert_eq!(request_url, "ws://localhost:8080");
1398    }
1399
1400    #[rstest]
1401    fn test_parse_socket_url_default_ports() {
1402        // wss without explicit port defaults to 443
1403        let (socket_addr, _) =
1404            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1405        assert_eq!(socket_addr, "example.com:443");
1406
1407        // ws without explicit port defaults to 80
1408        let (socket_addr, _) =
1409            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1410        assert_eq!(socket_addr, "example.com:80");
1411
1412        // https defaults to 443
1413        let (socket_addr, _) =
1414            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1415        assert_eq!(socket_addr, "example.com:443");
1416
1417        // http defaults to 80
1418        let (socket_addr, _) =
1419            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1420        assert_eq!(socket_addr, "example.com:80");
1421    }
1422
1423    #[rstest]
1424    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1425        // Unknown scheme defaults to mode-based port
1426        let (socket_addr, _) =
1427            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1428        assert_eq!(socket_addr, "example.com:443");
1429
1430        let (socket_addr, _) =
1431            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1432        assert_eq!(socket_addr, "example.com:80");
1433    }
1434
1435    #[rstest]
1436    fn test_parse_socket_url_ipv6() {
1437        // IPv6 address with port
1438        let (socket_addr, request_url) =
1439            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1440        assert_eq!(socket_addr, "[::1]:8080");
1441        assert_eq!(request_url, "ws://[::1]:8080");
1442
1443        // IPv6 in URL
1444        let (socket_addr, _) =
1445            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1446        assert_eq!(socket_addr, "[::1]:8080");
1447    }
1448
1449    #[rstest]
1450    #[tokio::test]
1451    async fn test_url_parsing_raw_socket_address() {
1452        // Test that raw socket addresses (host:port) work correctly
1453        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1454        let port = listener.local_addr().unwrap().port();
1455
1456        let server = task::spawn(async move {
1457            if let Ok((sock, _)) = listener.accept().await {
1458                drop(sock);
1459            }
1460            sleep(Duration::from_millis(50)).await;
1461        });
1462
1463        let config = SocketConfig {
1464            url: format!("127.0.0.1:{port}"), // Raw socket address format
1465            mode: Mode::Plain,
1466            suffix: b"\r\n".to_vec(),
1467            message_handler: None,
1468            heartbeat: None,
1469            reconnect_timeout_ms: Some(1_000),
1470            reconnect_delay_initial_ms: Some(50),
1471            reconnect_delay_max_ms: Some(100),
1472            reconnect_backoff_factor: Some(1.0),
1473            reconnect_jitter_ms: Some(0),
1474            certs_dir: None,
1475        };
1476
1477        // Should successfully connect with raw socket address
1478        let client = SocketClient::connect(config, None, None, None).await;
1479        assert!(
1480            client.is_ok(),
1481            "Client should connect with raw socket address format"
1482        );
1483
1484        if let Ok(client) = client {
1485            client.close().await;
1486        }
1487        server.abort();
1488    }
1489
1490    #[rstest]
1491    #[tokio::test]
1492    async fn test_url_parsing_with_scheme() {
1493        // Test that URLs with schemes also work
1494        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1495        let port = listener.local_addr().unwrap().port();
1496
1497        let server = task::spawn(async move {
1498            if let Ok((sock, _)) = listener.accept().await {
1499                drop(sock);
1500            }
1501            sleep(Duration::from_millis(50)).await;
1502        });
1503
1504        let config = SocketConfig {
1505            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1506            mode: Mode::Plain,
1507            suffix: b"\r\n".to_vec(),
1508            message_handler: None,
1509            heartbeat: None,
1510            reconnect_timeout_ms: Some(1_000),
1511            reconnect_delay_initial_ms: Some(50),
1512            reconnect_delay_max_ms: Some(100),
1513            reconnect_backoff_factor: Some(1.0),
1514            reconnect_jitter_ms: Some(0),
1515            certs_dir: None,
1516        };
1517
1518        // Should successfully connect with URL format
1519        let client = SocketClient::connect(config, None, None, None).await;
1520        assert!(
1521            client.is_ok(),
1522            "Client should connect with URL scheme format"
1523        );
1524
1525        if let Ok(client) = client {
1526            client.close().await;
1527        }
1528        server.abort();
1529    }
1530
1531    #[rstest]
1532    fn test_parse_socket_url_ipv6_with_zone() {
1533        // IPv6 with zone ID (link-local address)
1534        let (socket_addr, request_url) =
1535            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1536        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1537        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1538
1539        // Verify zone is preserved in URL format too
1540        let (socket_addr, request_url) =
1541            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1542        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1543        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1544    }
1545
1546    #[rstest]
1547    #[tokio::test]
1548    async fn test_ipv6_loopback_connection() {
1549        // Test IPv6 loopback address connection
1550        // Skip if IPv6 is not available on the system
1551        if TcpListener::bind("[::1]:0").await.is_err() {
1552            eprintln!("IPv6 not available, skipping test");
1553            return;
1554        }
1555
1556        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1557        let port = listener.local_addr().unwrap().port();
1558
1559        let server = task::spawn(async move {
1560            if let Ok((mut sock, _)) = listener.accept().await {
1561                let mut buf = vec![0u8; 1024];
1562                if let Ok(n) = sock.read(&mut buf).await {
1563                    // Echo back
1564                    let _ = sock.write_all(&buf[..n]).await;
1565                }
1566            }
1567            sleep(Duration::from_millis(50)).await;
1568        });
1569
1570        let config = SocketConfig {
1571            url: format!("[::1]:{port}"), // IPv6 loopback
1572            mode: Mode::Plain,
1573            suffix: b"\r\n".to_vec(),
1574            message_handler: None,
1575            heartbeat: None,
1576            reconnect_timeout_ms: Some(1_000),
1577            reconnect_delay_initial_ms: Some(50),
1578            reconnect_delay_max_ms: Some(100),
1579            reconnect_backoff_factor: Some(1.0),
1580            reconnect_jitter_ms: Some(0),
1581            certs_dir: None,
1582        };
1583
1584        let client = SocketClient::connect(config, None, None, None).await;
1585        assert!(
1586            client.is_ok(),
1587            "Client should connect to IPv6 loopback address"
1588        );
1589
1590        if let Ok(client) = client {
1591            client.close().await;
1592        }
1593        server.abort();
1594    }
1595
1596    #[rstest]
1597    #[tokio::test]
1598    async fn test_send_waits_during_reconnection() {
1599        // Test that send operations wait for reconnection to complete (up to configured timeout)
1600        use nautilus_common::testing::wait_until_async;
1601
1602        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1603        let port = listener.local_addr().unwrap().port();
1604
1605        let server = task::spawn(async move {
1606            // First connection - accept and immediately close
1607            if let Ok((sock, _)) = listener.accept().await {
1608                drop(sock);
1609            }
1610
1611            // Wait before accepting second connection
1612            sleep(Duration::from_millis(500)).await;
1613
1614            // Second connection - accept and keep alive
1615            if let Ok((mut sock, _)) = listener.accept().await {
1616                // Echo messages
1617                let mut buf = vec![0u8; 1024];
1618                while let Ok(n) = sock.read(&mut buf).await {
1619                    if n == 0 {
1620                        break;
1621                    }
1622                    if sock.write_all(&buf[..n]).await.is_err() {
1623                        break;
1624                    }
1625                }
1626            }
1627        });
1628
1629        let config = SocketConfig {
1630            url: format!("127.0.0.1:{port}"),
1631            mode: Mode::Plain,
1632            suffix: b"\r\n".to_vec(),
1633            message_handler: None,
1634            heartbeat: None,
1635            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1636            reconnect_delay_initial_ms: Some(100),
1637            reconnect_delay_max_ms: Some(200),
1638            reconnect_backoff_factor: Some(1.0),
1639            reconnect_jitter_ms: Some(0),
1640            certs_dir: None,
1641        };
1642
1643        let client = SocketClient::connect(config, None, None, None)
1644            .await
1645            .unwrap();
1646
1647        // Wait for reconnection to trigger
1648        wait_until_async(
1649            || async { client.is_reconnecting() },
1650            Duration::from_secs(2),
1651        )
1652        .await;
1653
1654        // Try to send while reconnecting - should wait and succeed after reconnect
1655        let send_result = tokio::time::timeout(
1656            Duration::from_secs(3),
1657            client.send_bytes(b"test_message".to_vec()),
1658        )
1659        .await;
1660
1661        assert!(
1662            send_result.is_ok() && send_result.unwrap().is_ok(),
1663            "Send should succeed after waiting for reconnection"
1664        );
1665
1666        client.close().await;
1667        server.abort();
1668    }
1669
1670    #[rstest]
1671    #[tokio::test]
1672    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1673        // Test that send_bytes operations respect the configured reconnect_timeout.
1674        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1675        use nautilus_common::testing::wait_until_async;
1676
1677        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1678        let port = listener.local_addr().unwrap().port();
1679
1680        let server = task::spawn(async move {
1681            // Accept first connection and immediately close it
1682            if let Ok((sock, _)) = listener.accept().await {
1683                drop(sock);
1684            }
1685            // Drop listener entirely so reconnection fails completely
1686            drop(listener);
1687            sleep(Duration::from_secs(60)).await;
1688        });
1689
1690        let config = SocketConfig {
1691            url: format!("127.0.0.1:{port}"),
1692            mode: Mode::Plain,
1693            suffix: b"\r\n".to_vec(),
1694            message_handler: None,
1695            heartbeat: None,
1696            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1697            reconnect_delay_initial_ms: Some(5_000), // Long backoff to keep client in RECONNECT
1698            reconnect_delay_max_ms: Some(5_000),
1699            reconnect_backoff_factor: Some(1.0),
1700            reconnect_jitter_ms: Some(0),
1701            certs_dir: None,
1702        };
1703
1704        let client = SocketClient::connect(config, None, None, None)
1705            .await
1706            .unwrap();
1707
1708        // Wait for client to enter RECONNECT state
1709        wait_until_async(
1710            || async { client.is_reconnecting() },
1711            Duration::from_secs(3),
1712        )
1713        .await;
1714
1715        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1716        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1717        let start = std::time::Instant::now();
1718        let send_result = client.send_bytes(b"test".to_vec()).await;
1719        let elapsed = start.elapsed();
1720
1721        assert!(
1722            send_result.is_err(),
1723            "Send should fail when client stuck in RECONNECT, got: {:?}",
1724            send_result
1725        );
1726        assert!(
1727            matches!(send_result, Err(crate::error::SendError::Timeout)),
1728            "Send should return Timeout error, got: {:?}",
1729            send_result
1730        );
1731        // Verify timeout respects configured value (1s), but don't check upper bound
1732        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1733        assert!(
1734            elapsed >= Duration::from_millis(900),
1735            "Send should timeout after at least 1s (configured timeout), took {:?}",
1736            elapsed
1737        );
1738
1739        client.close().await;
1740        server.abort();
1741    }
1742}