nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30
31use std::{
32    fmt::Debug,
33    path::Path,
34    sync::{
35        Arc,
36        atomic::{AtomicU8, Ordering},
37    },
38    time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_core::CleanDrop;
43use nautilus_cryptography::providers::install_cryptographic_provider;
44use tokio::{
45    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
46    net::TcpStream,
47};
48use tokio_tungstenite::{
49    MaybeTlsStream,
50    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
51};
52
53use crate::{
54    backoff::ExponentialBackoff,
55    error::SendError,
56    fix::process_fix_buffer,
57    logging::{log_task_aborted, log_task_started, log_task_stopped},
58    mode::ConnectionMode,
59    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
60};
61
// Connection timing constants

/// Period (milliseconds) at which background loops re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Period (milliseconds) at which `send_bytes` polls for the connection to become active.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
/// Pause (milliseconds) after handing a new writer over during a reconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) for shutting down a writer half or waiting in `close`.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
67
68type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
69type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
70pub type TcpMessageHandler = Arc<dyn Fn(&[u8]) + Send + Sync>;
71
/// Configuration for TCP socket connection.
///
/// All `reconnect_*` fields are optional; unset values fall back to the defaults
/// applied when the client connects (listed on each field below).
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// The address to connect to: either a raw `host:port` or a full URL with a scheme
    /// (e.g. `wss://host:port/path`).
    pub url: String,
    /// The connection mode {Plain, TLS}.
    pub mode: Mode,
    /// The sequence of bytes which separates messages. It is appended to every sent
    /// message and heartbeat, and used to split the received byte stream.
    pub suffix: Vec<u8>,
    /// The optional function called with each complete received message.
    pub message_handler: Option<TcpMessageHandler>,
    /// The optional heartbeat as `(interval_secs, message)`: `message` is queued every
    /// `interval_secs` seconds while the connection is active.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// The timeout (milliseconds) for reconnection attempts (default 10 000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2 000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30 000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// The path to the certificates directory; when set, a rustls connector is built
    /// from the certificates it contains.
    pub certs_dir: Option<String>,
}
101
102impl Debug for SocketConfig {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct(stringify!(SocketConfig))
105            .field("url", &self.url)
106            .field("mode", &self.mode)
107            .field("suffix", &self.suffix)
108            .field(
109                "message_handler",
110                &self.message_handler.as_ref().map(|_| "<function>"),
111            )
112            .field("heartbeat", &self.heartbeat)
113            .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
114            .field(
115                "reconnect_delay_initial_ms",
116                &self.reconnect_delay_initial_ms,
117            )
118            .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
119            .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
120            .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
121            .field("certs_dir", &self.certs_dir)
122            .finish()
123    }
124}
125
126impl Clone for SocketConfig {
127    fn clone(&self) -> Self {
128        Self {
129            url: self.url.clone(),
130            mode: self.mode,
131            suffix: self.suffix.clone(),
132            message_handler: self.message_handler.clone(),
133            heartbeat: self.heartbeat.clone(),
134            reconnect_timeout_ms: self.reconnect_timeout_ms,
135            reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
136            reconnect_delay_max_ms: self.reconnect_delay_max_ms,
137            reconnect_backoff_factor: self.reconnect_backoff_factor,
138            reconnect_jitter_ms: self.reconnect_jitter_ms,
139            certs_dir: self.certs_dir.clone(),
140        }
141    }
142}
143
/// Represents a command for the writer task.
///
/// Commands are sent over an unbounded channel to the single task that owns the
/// write half of the stream.
#[derive(Debug)]
pub enum WriterCommand {
    /// Update the writer reference with a new one after reconnection. The previous
    /// write half is shut down (best effort) before being replaced.
    Update(TcpWriter),
    /// Send data to the server. The writer task appends the configured suffix; the
    /// command is skipped with a warning while the connection is reconnecting.
    Send(Bytes),
}
152
/// Connection state behind a [`SocketClient`]: a `TcpStream` with the server plus the
/// tasks driving it.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    // Configuration the client was created with; reused when reconnecting.
    config: SocketConfig,
    // Optional rustls connector built from `config.certs_dir`.
    connector: Option<Connector>,
    // Handle to the current read task; replaced on every successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    // Handle to the write task; it lives across reconnects and receives new writers.
    write_task: tokio::task::JoinHandle<()>,
    // Sender side of the writer channel (shared with the heartbeat task).
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    // Handle to the heartbeat task, if a heartbeat was configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    // Shared `ConnectionMode`, stored as its `u8` encoding.
    connection_mode: Arc<AtomicU8>,
    // Upper bound for a single reconnect attempt.
    reconnect_timeout: Duration,
    // NOTE(review): not read in this impl; presumably used by the controller task to
    // pace reconnect attempts — confirm.
    backoff: ExponentialBackoff,
    // Clone of `config.message_handler`, handed to each new read task.
    handler: Option<TcpMessageHandler>,
}
184
185impl SocketClientInner {
186    /// Connect to a URL with the specified configuration.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if connection fails or configuration is invalid.
191    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
192        install_cryptographic_provider();
193
194        let SocketConfig {
195            url,
196            mode,
197            heartbeat,
198            suffix,
199            message_handler,
200            reconnect_timeout_ms,
201            reconnect_delay_initial_ms,
202            reconnect_delay_max_ms,
203            reconnect_backoff_factor,
204            reconnect_jitter_ms,
205            certs_dir,
206        } = &config.clone();
207        let connector = if let Some(dir) = certs_dir {
208            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
209            Some(Connector::Rustls(Arc::new(config)))
210        } else {
211            None
212        };
213
214        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
215        tracing::debug!("Connected");
216
217        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
218
219        let read_task = Arc::new(Self::spawn_read_task(
220            connection_mode.clone(),
221            reader,
222            message_handler.clone(),
223            suffix.clone(),
224        ));
225
226        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
227
228        let write_task =
229            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
230
231        // Optionally spawn a heartbeat task to periodically ping server
232        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233            Self::spawn_heartbeat_task(
234                connection_mode.clone(),
235                heartbeat.clone(),
236                writer_tx.clone(),
237            )
238        });
239
240        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
241        let backoff = ExponentialBackoff::new(
242            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
243            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
244            reconnect_backoff_factor.unwrap_or(1.5),
245            reconnect_jitter_ms.unwrap_or(100),
246            true, // immediate-first
247        )?;
248
249        Ok(Self {
250            config,
251            connector,
252            read_task,
253            write_task,
254            writer_tx,
255            heartbeat_task,
256            connection_mode,
257            reconnect_timeout,
258            backoff,
259            handler: message_handler.clone(),
260        })
261    }
262
263    /// Parse URL and extract socket address and request URL.
264    ///
265    /// Accepts either:
266    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
267    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if the URL is invalid or missing required components.
272    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
273        if url.contains("://") {
274            // URL with scheme (e.g., "wss://host:port/path")
275            let parsed = url.parse::<http::Uri>().map_err(|e| {
276                Error::Io(std::io::Error::new(
277                    std::io::ErrorKind::InvalidInput,
278                    format!("Invalid URL: {e}"),
279                ))
280            })?;
281
282            let host = parsed.host().ok_or_else(|| {
283                Error::Io(std::io::Error::new(
284                    std::io::ErrorKind::InvalidInput,
285                    "URL missing host",
286                ))
287            })?;
288
289            let port = parsed
290                .port_u16()
291                .unwrap_or_else(|| match parsed.scheme_str() {
292                    Some("wss") | Some("https") => 443,
293                    Some("ws") | Some("http") => 80,
294                    _ => match mode {
295                        Mode::Tls => 443,
296                        Mode::Plain => 80,
297                    },
298                });
299
300            Ok((format!("{host}:{port}"), url.to_string()))
301        } else {
302            // Raw socket address (e.g., "host:port")
303            // Construct a proper URL for the request based on mode
304            let scheme = match mode {
305                Mode::Tls => "wss",
306                Mode::Plain => "ws",
307            };
308            Ok((url.to_string(), format!("{scheme}://{url}")))
309        }
310    }
311
312    /// Establish a TLS or plain TCP connection with the server.
313    ///
314    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
315    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
316    /// For WebSocket-style connections, include the scheme.
317    ///
318    /// # Errors
319    ///
320    /// Returns an error if the connection cannot be established.
321    pub async fn tls_connect_with_server(
322        url: &str,
323        mode: Mode,
324        connector: Option<Connector>,
325    ) -> Result<(TcpReader, TcpWriter), Error> {
326        tracing::debug!("Connecting to {url}");
327
328        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
329        let tcp_result = TcpStream::connect(&socket_addr).await;
330
331        match tcp_result {
332            Ok(stream) => {
333                tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
334                let request = request_url.into_client_request()?;
335                tcp_tls(&request, mode, stream, connector)
336                    .await
337                    .map(tokio::io::split)
338            }
339            Err(e) => {
340                tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
341                Err(Error::Io(e))
342            }
343        }
344    }
345
346    /// Reconnect with server.
347    ///
348    /// Makes a new connection with server, uses the new read and write halves
349    /// to update the reader and writer.
350    async fn reconnect(&mut self) -> Result<(), Error> {
351        tracing::debug!("Reconnecting");
352
353        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
354            tracing::debug!("Reconnect aborted due to disconnect state");
355            return Ok(());
356        }
357
358        tokio::time::timeout(self.reconnect_timeout, async {
359            let SocketConfig {
360                url,
361                mode,
362                heartbeat: _,
363                suffix,
364                message_handler: _,
365                reconnect_timeout_ms: _,
366                reconnect_delay_initial_ms: _,
367                reconnect_backoff_factor: _,
368                reconnect_delay_max_ms: _,
369                reconnect_jitter_ms: _,
370                certs_dir: _,
371            } = &self.config;
372            // Create a fresh connection
373            let connector = self.connector.clone();
374            // Attempt to connect; abort early if a disconnect was requested
375            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
376
377            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
378                tracing::debug!("Reconnect aborted mid-flight (after connect)");
379                return Ok(());
380            }
381            tracing::debug!("Connected");
382
383            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
384                tracing::error!("{e}");
385            }
386
387            // Delay before closing connection
388            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
389
390            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
391                tracing::debug!("Reconnect aborted mid-flight (after delay)");
392                return Ok(());
393            }
394
395            if !self.read_task.is_finished() {
396                self.read_task.abort();
397                log_task_aborted("read");
398            }
399
400            // Atomically transition from Reconnect to Active
401            // This prevents race condition where disconnect could be requested between check and store
402            if self
403                .connection_mode
404                .compare_exchange(
405                    ConnectionMode::Reconnect.as_u8(),
406                    ConnectionMode::Active.as_u8(),
407                    Ordering::SeqCst,
408                    Ordering::SeqCst,
409                )
410                .is_err()
411            {
412                tracing::debug!("Reconnect aborted (state changed during reconnect)");
413                return Ok(());
414            }
415
416            // Spawn new read task
417            self.read_task = Arc::new(Self::spawn_read_task(
418                self.connection_mode.clone(),
419                reader,
420                self.handler.clone(),
421                suffix.clone(),
422            ));
423
424            tracing::debug!("Reconnect succeeded");
425            Ok(())
426        })
427        .await
428        .map_err(|_| {
429            Error::Io(std::io::Error::new(
430                std::io::ErrorKind::TimedOut,
431                format!(
432                    "reconnection timed out after {}s",
433                    self.reconnect_timeout.as_secs_f64()
434                ),
435            ))
436        })?
437    }
438
439    /// Check if the client is still alive.
440    ///
441    /// The client is connected if the read task has not finished. It is expected
442    /// that in case of any failure client or server side. The read task will be
443    /// shutdown. There might be some delay between the connection being closed
444    /// and the client detecting it.
445    #[inline]
446    #[must_use]
447    pub fn is_alive(&self) -> bool {
448        !self.read_task.is_finished()
449    }
450
451    #[must_use]
452    fn spawn_read_task(
453        connection_state: Arc<AtomicU8>,
454        mut reader: TcpReader,
455        handler: Option<TcpMessageHandler>,
456        suffix: Vec<u8>,
457    ) -> tokio::task::JoinHandle<()> {
458        log_task_started("read");
459
460        // Interval between checking the connection mode
461        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
462
463        tokio::task::spawn(async move {
464            let mut buf = Vec::new();
465
466            loop {
467                if !ConnectionMode::from_atomic(&connection_state).is_active() {
468                    break;
469                }
470
471                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
472                    // Connection has been terminated or vector buffer is complete
473                    Ok(Ok(0)) => {
474                        tracing::debug!("Connection closed by server");
475                        break;
476                    }
477                    Ok(Err(e)) => {
478                        tracing::debug!("Connection ended: {e}");
479                        break;
480                    }
481                    // Received bytes of data
482                    Ok(Ok(bytes)) => {
483                        tracing::trace!("Received <binary> {bytes} bytes");
484
485                        // Check if buffer contains FIX protocol messages (starts with "8=FIX")
486                        let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
487
488                        if is_fix && handler.is_some() {
489                            // FIX protocol processing
490                            if let Some(ref handler) = handler {
491                                process_fix_buffer(&mut buf, handler);
492                            }
493                        } else {
494                            // Regular suffix-based message processing
495                            while let Some((i, _)) = &buf
496                                .windows(suffix.len())
497                                .enumerate()
498                                .find(|(_, pair)| pair.eq(&suffix))
499                            {
500                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
501                                data.truncate(data.len() - suffix.len());
502
503                                if let Some(ref handler) = handler {
504                                    handler(&data);
505                                }
506                            }
507                        }
508                    }
509                    Err(_) => {
510                        // Timeout - continue loop and check connection mode
511                        continue;
512                    }
513                }
514            }
515
516            log_task_stopped("read");
517        })
518    }
519
520    fn spawn_write_task(
521        connection_state: Arc<AtomicU8>,
522        writer: TcpWriter,
523        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
524        suffix: Vec<u8>,
525    ) -> tokio::task::JoinHandle<()> {
526        log_task_started("write");
527
528        // Interval between checking the connection mode
529        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
530
531        tokio::task::spawn(async move {
532            let mut active_writer = writer;
533
534            loop {
535                if matches!(
536                    ConnectionMode::from_atomic(&connection_state),
537                    ConnectionMode::Disconnect | ConnectionMode::Closed
538                ) {
539                    break;
540                }
541
542                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
543                    Ok(Some(msg)) => {
544                        // Re-check connection mode after receiving a message
545                        let mode = ConnectionMode::from_atomic(&connection_state);
546                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
547                            break;
548                        }
549
550                        match msg {
551                            WriterCommand::Update(new_writer) => {
552                                tracing::debug!("Received new writer");
553
554                                // Delay before closing connection
555                                tokio::time::sleep(Duration::from_millis(100)).await;
556
557                                // Attempt to shutdown the writer gracefully before updating,
558                                // we ignore any error as the writer may already be closed.
559                                _ = tokio::time::timeout(
560                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
561                                    active_writer.shutdown(),
562                                )
563                                .await;
564
565                                active_writer = new_writer;
566                                tracing::debug!("Updated writer");
567                            }
568                            _ if mode.is_reconnect() => {
569                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
570                                continue;
571                            }
572                            WriterCommand::Send(msg) => {
573                                if let Err(e) = active_writer.write_all(&msg).await {
574                                    tracing::error!("Failed to send message: {e}");
575                                    // Mode is active so trigger reconnection
576                                    tracing::warn!("Writer triggering reconnect");
577                                    connection_state
578                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
579                                    continue;
580                                }
581                                if let Err(e) = active_writer.write_all(&suffix).await {
582                                    tracing::error!("Failed to send suffix: {e}");
583                                    // Mode is active so trigger reconnection
584                                    tracing::warn!("Writer triggering reconnect");
585                                    connection_state
586                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
587                                    continue;
588                                }
589                            }
590                        }
591                    }
592                    Ok(None) => {
593                        // Channel closed - writer task should terminate
594                        tracing::debug!("Writer channel closed, terminating writer task");
595                        break;
596                    }
597                    Err(_) => {
598                        // Timeout - just continue the loop
599                        continue;
600                    }
601                }
602            }
603
604            // Attempt to shutdown the writer gracefully before exiting,
605            // we ignore any error as the writer may already be closed.
606            _ = tokio::time::timeout(
607                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
608                active_writer.shutdown(),
609            )
610            .await;
611
612            log_task_stopped("write");
613        })
614    }
615
616    fn spawn_heartbeat_task(
617        connection_state: Arc<AtomicU8>,
618        heartbeat: (u64, Vec<u8>),
619        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
620    ) -> tokio::task::JoinHandle<()> {
621        log_task_started("heartbeat");
622        let (interval_secs, message) = heartbeat;
623
624        tokio::task::spawn(async move {
625            let interval = Duration::from_secs(interval_secs);
626
627            loop {
628                tokio::time::sleep(interval).await;
629
630                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
631                    ConnectionMode::Active => {
632                        let msg = WriterCommand::Send(message.clone().into());
633
634                        match writer_tx.send(msg) {
635                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
636                            Err(e) => {
637                                tracing::error!("Failed to send heartbeat to writer task: {e}");
638                            }
639                        }
640                    }
641                    ConnectionMode::Reconnect => continue,
642                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
643                }
644            }
645
646            log_task_stopped("heartbeat");
647        })
648    }
649}
650
651impl Drop for SocketClientInner {
652    fn drop(&mut self) {
653        // Delegate to explicit cleanup handler
654        self.clean_drop();
655    }
656}
657
658/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
659impl CleanDrop for SocketClientInner {
660    fn clean_drop(&mut self) {
661        if !self.read_task.is_finished() {
662            self.read_task.abort();
663            log_task_aborted("read");
664        }
665
666        if !self.write_task.is_finished() {
667            self.write_task.abort();
668            log_task_aborted("write");
669        }
670
671        if let Some(ref handle) = self.heartbeat_task.take()
672            && !handle.is_finished()
673        {
674            handle.abort();
675            log_task_aborted("heartbeat");
676        }
677
678        #[cfg(feature = "python")]
679        {
680            // Remove stored handler to break ref cycle
681            self.config.message_handler = None;
682        }
683    }
684}
685
/// Public handle to a raw TCP socket client.
///
/// The connection itself is owned by a controller task; this handle shares the
/// connection mode and the writer channel with the background tasks.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    // Task that owns the connection state and manages its lifecycle.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    // Connection mode shared with the background tasks, stored as its `u8` encoding.
    pub(crate) connection_mode: Arc<AtomicU8>,
    // Maximum time `send_bytes` waits for the connection to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Channel to the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
696
697impl Debug for SocketClient {
698    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
699        f.debug_struct(stringify!(SocketClient)).finish()
700    }
701}
702
703impl SocketClient {
704    /// Connect to the server.
705    ///
706    /// # Errors
707    ///
708    /// Returns any error connecting to the server.
709    pub async fn connect(
710        config: SocketConfig,
711        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
712        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
713        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
714    ) -> anyhow::Result<Self> {
715        let inner = SocketClientInner::connect_url(config).await?;
716        let writer_tx = inner.writer_tx.clone();
717        let connection_mode = inner.connection_mode.clone();
718        let reconnect_timeout = inner.reconnect_timeout;
719
720        let controller_task = Self::spawn_controller_task(
721            inner,
722            connection_mode.clone(),
723            post_reconnection,
724            post_disconnection,
725        );
726
727        if let Some(handler) = post_connection {
728            handler();
729            tracing::debug!("Called `post_connection` handler");
730        }
731
732        Ok(Self {
733            controller_task,
734            connection_mode,
735            reconnect_timeout,
736            writer_tx,
737        })
738    }
739
740    /// Returns the current connection mode.
741    #[must_use]
742    pub fn connection_mode(&self) -> ConnectionMode {
743        ConnectionMode::from_atomic(&self.connection_mode)
744    }
745
746    /// Check if the client connection is active.
747    ///
748    /// Returns `true` if the client is connected and has not been signalled to disconnect.
749    /// The client will automatically retry connection based on its configuration.
750    #[inline]
751    #[must_use]
752    pub fn is_active(&self) -> bool {
753        self.connection_mode().is_active()
754    }
755
756    /// Check if the client is reconnecting.
757    ///
758    /// Returns `true` if the client lost connection and is attempting to reestablish it.
759    /// The client will automatically retry connection based on its configuration.
760    #[inline]
761    #[must_use]
762    pub fn is_reconnecting(&self) -> bool {
763        self.connection_mode().is_reconnect()
764    }
765
766    /// Check if the client is disconnecting.
767    ///
768    /// Returns `true` if the client is in disconnect mode.
769    #[inline]
770    #[must_use]
771    pub fn is_disconnecting(&self) -> bool {
772        self.connection_mode().is_disconnect()
773    }
774
775    /// Check if the client is closed.
776    ///
777    /// Returns `true` if the client has been explicitly disconnected or reached
778    /// maximum reconnection attempts. In this state, the client cannot be reused
779    /// and a new client must be created for further connections.
780    #[inline]
781    #[must_use]
782    pub fn is_closed(&self) -> bool {
783        self.connection_mode().is_closed()
784    }
785
786    /// Close the client.
787    ///
788    /// Controller task will periodically check the disconnect mode
789    /// and shutdown the client if it is not alive.
790    pub async fn close(&self) {
791        self.connection_mode
792            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
793
794        match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
795            while !self.is_closed() {
796                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
797            }
798
799            if !self.controller_task.is_finished() {
800                self.controller_task.abort();
801                log_task_aborted("controller");
802            }
803        })
804        .await
805        {
806            Ok(()) => {
807                log_task_stopped("controller");
808            }
809            Err(_) => {
810                tracing::error!("Timeout waiting for controller task to finish");
811                if !self.controller_task.is_finished() {
812                    self.controller_task.abort();
813                    log_task_aborted("controller");
814                }
815            }
816        }
817    }
818
819    /// Sends a message of the given `data`.
820    ///
821    /// # Errors
822    ///
823    /// Returns an error if sending fails.
824    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
825        if self.is_closed() {
826            return Err(SendError::Closed);
827        }
828
829        let timeout = self.reconnect_timeout;
830        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
831
832        if !self.is_active() {
833            tracing::debug!("Waiting for client to become ACTIVE before sending...");
834
835            let inner = tokio::time::timeout(timeout, async {
836                loop {
837                    if self.is_active() {
838                        return Ok(());
839                    }
840                    if matches!(
841                        self.connection_mode(),
842                        ConnectionMode::Disconnect | ConnectionMode::Closed
843                    ) {
844                        return Err(());
845                    }
846                    tokio::time::sleep(check_interval).await;
847                }
848            })
849            .await
850            .map_err(|_| SendError::Timeout)?;
851            inner.map_err(|()| SendError::Closed)?;
852        }
853
854        let msg = WriterCommand::Send(data.into());
855        self.writer_tx
856            .send(msg)
857            .map_err(|e| SendError::BrokenPipe(e.to_string()))
858    }
859
860    fn spawn_controller_task(
861        mut inner: SocketClientInner,
862        connection_mode: Arc<AtomicU8>,
863        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
864        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
865    ) -> tokio::task::JoinHandle<()> {
866        tokio::task::spawn(async move {
867            log_task_started("controller");
868
869            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
870
871            loop {
872                tokio::time::sleep(check_interval).await;
873                let mut mode = ConnectionMode::from_atomic(&connection_mode);
874
875                if mode.is_disconnect() {
876                    tracing::debug!("Disconnecting");
877
878                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
879                    if tokio::time::timeout(timeout, async {
880                        // Delay awaiting graceful shutdown
881                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
882
883                        if !inner.read_task.is_finished() {
884                            inner.read_task.abort();
885                            log_task_aborted("read");
886                        }
887
888                        if let Some(task) = &inner.heartbeat_task
889                            && !task.is_finished()
890                        {
891                            task.abort();
892                            log_task_aborted("heartbeat");
893                        }
894                    })
895                    .await
896                    .is_err()
897                    {
898                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
899                    }
900
901                    tracing::debug!("Closed");
902
903                    if let Some(ref handler) = post_disconnection {
904                        handler();
905                        tracing::debug!("Called `post_disconnection` handler");
906                    }
907                    break; // Controller finished
908                }
909
910                if mode.is_active() && !inner.is_alive() {
911                    if connection_mode
912                        .compare_exchange(
913                            ConnectionMode::Active.as_u8(),
914                            ConnectionMode::Reconnect.as_u8(),
915                            Ordering::SeqCst,
916                            Ordering::SeqCst,
917                        )
918                        .is_ok()
919                    {
920                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
921                    }
922                    mode = ConnectionMode::from_atomic(&connection_mode);
923                }
924
925                if mode.is_reconnect() {
926                    match inner.reconnect().await {
927                        Ok(()) => {
928                            tracing::debug!("Reconnected successfully");
929                            inner.backoff.reset();
930                            // Only invoke reconnect handler if still active
931                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
932                                if let Some(ref handler) = post_reconnection {
933                                    handler();
934                                    tracing::debug!("Called `post_reconnection` handler");
935                                }
936                            } else {
937                                tracing::debug!(
938                                    "Skipping post_reconnection handlers due to disconnect state"
939                                );
940                            }
941                        }
942                        Err(e) => {
943                            let duration = inner.backoff.next_duration();
944                            tracing::warn!("Reconnect attempt failed: {e}");
945                            if !duration.is_zero() {
946                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
947                            }
948                            tokio::time::sleep(duration).await;
949                        }
950                    }
951                }
952            }
953            inner
954                .connection_mode
955                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
956
957            log_task_stopped("controller");
958        })
959    }
960}
961
962// Abort controller task on drop to clean up background tasks
963impl Drop for SocketClient {
964    fn drop(&mut self) {
965        if !self.controller_task.is_finished() {
966            self.controller_task.abort();
967            log_task_aborted("controller");
968        }
969    }
970}
971
972////////////////////////////////////////////////////////////////////////////////
973// Tests
974////////////////////////////////////////////////////////////////////////////////
975
// Integration tests that run `SocketClient` against local TCP servers bound to ephemeral
// ports. Every test calls `Python::initialize()`, hence the `python` feature gate.
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::Python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an OS-assigned ephemeral port on 127.0.0.1.
    ///
    /// Returns the chosen port together with the listener.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Line-oriented echo server for a single connection.
    ///
    /// Buffers incoming bytes, splits them on `\r\n`, and writes each line back with the
    /// `\r\n` terminator restored. A line equal to `close` shuts the socket down and
    /// returns. EOF or a read error ends the server.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    // Drain every complete `\r\n`-terminated line currently buffered; a
                    // trailing partial line stays in `buf` for the next read.
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Remove trailing \r\n
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        // NOTE(review): this `break` only exits the inner `while`; the outer
                        // read loop continues after a failed write.
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    /// Sends two lines to an echo server, then `close`. After the server has shut the
    /// connection down, the client must not report CLOSED.
    #[tokio::test]
    async fn test_basic_send_receive() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Wait a bit for the server to echo them back
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// `SocketClient::connect` must return an error when nothing listens on the port.
    ///
    /// NOTE(review): despite the name, this exercises the initial connect failure, not
    /// exhaustion of reconnect attempts.
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        drop(listener); // We drop it immediately -> no server is listening

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    /// An explicit `close()` must leave the client in the CLOSED state.
    #[tokio::test]
    async fn test_user_disconnect() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            // Single non-blocking read attempt; the result (likely WouldBlock) is ignored.
            let _ = socket.try_read(&mut buf);

            // Keep the connection open until the test aborts this task.
            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    /// With a 1s heartbeat configured, the server must receive at least two `ping` lines
    /// within ~3s.
    #[tokio::test]
    async fn test_heartbeat() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        // Collects every `\r\n`-terminated line into `received`, polling with
        // non-blocking reads and sleeping 10ms whenever a read returns an error
        // (e.g. WouldBlock).
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat every 1 second
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait ~3 seconds to collect some heartbeats
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    /// Exercises a server-side drop followed by a second accepted connection, then sends
    /// a message and closes the client.
    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        // Spawn a server task that:
        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
        // 2. Waits a bit and then accepts a new connection and runs the echo server
        let server_task = task::spawn(async move {
            // Accept first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Wait briefly and then force-close the connection
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            // Wait for the client's reconnect attempt
            sleep(Duration::from_millis(500)).await;

            // Run the echo server on the new connection
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Wait (up to 10s) for the client to report ACTIVE.
        // NOTE(review): the client is already ACTIVE here, so this likely returns at once
        // and does not observe the drop/RECONNECT cycle. The send below may therefore go
        // over the first connection, and the reconnect path is not strictly asserted.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1266
1267#[cfg(test)]
1268mod rust_tests {
1269    use rstest::rstest;
1270    use tokio::{
1271        io::{AsyncReadExt, AsyncWriteExt},
1272        net::TcpListener,
1273        task,
1274        time::{Duration, sleep},
1275    };
1276
1277    use super::*;
1278
1279    #[rstest]
1280    #[tokio::test]
1281    async fn test_reconnect_then_close() {
1282        // Bind an ephemeral port
1283        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1284        let port = listener.local_addr().unwrap().port();
1285
1286        // Server task: accept one connection and then drop it
1287        let server = task::spawn(async move {
1288            if let Ok((mut sock, _)) = listener.accept().await {
1289                drop(sock.shutdown());
1290            }
1291            // Keep listener alive briefly to avoid premature exit
1292            sleep(Duration::from_secs(1)).await;
1293        });
1294
1295        // Configure client with a short reconnect backoff
1296        let config = SocketConfig {
1297            url: format!("127.0.0.1:{port}"),
1298            mode: Mode::Plain,
1299            suffix: b"\r\n".to_vec(),
1300            message_handler: None,
1301            heartbeat: None,
1302            reconnect_timeout_ms: Some(1_000),
1303            reconnect_delay_initial_ms: Some(50),
1304            reconnect_delay_max_ms: Some(100),
1305            reconnect_backoff_factor: Some(1.0),
1306            reconnect_jitter_ms: Some(0),
1307            certs_dir: None,
1308        };
1309
1310        // Connect client (handler=None)
1311        let client = SocketClient::connect(config.clone(), None, None, None)
1312            .await
1313            .unwrap();
1314
1315        // Allow server to drop connection and client to notice
1316        sleep(Duration::from_millis(100)).await;
1317
1318        // Now close the client
1319        client.close().await;
1320        assert!(client.is_closed());
1321        server.abort();
1322    }
1323
1324    #[rstest]
1325    #[tokio::test]
1326    async fn test_reconnect_state_flips_when_reader_stops() {
1327        // Bind an ephemeral port and accept a single connection which we immediately close.
1328        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1329        let port = listener.local_addr().unwrap().port();
1330
1331        let server = task::spawn(async move {
1332            if let Ok((sock, _)) = listener.accept().await {
1333                drop(sock);
1334            }
1335            // Give the client a moment to observe the closed connection.
1336            sleep(Duration::from_millis(50)).await;
1337        });
1338
1339        let config = SocketConfig {
1340            url: format!("127.0.0.1:{port}"),
1341            mode: Mode::Plain,
1342            suffix: b"\r\n".to_vec(),
1343            message_handler: None,
1344            heartbeat: None,
1345            reconnect_timeout_ms: Some(1_000),
1346            reconnect_delay_initial_ms: Some(50),
1347            reconnect_delay_max_ms: Some(100),
1348            reconnect_backoff_factor: Some(1.0),
1349            reconnect_jitter_ms: Some(0),
1350            certs_dir: None,
1351        };
1352
1353        let client = SocketClient::connect(config, None, None, None)
1354            .await
1355            .unwrap();
1356
1357        tokio::time::timeout(Duration::from_secs(2), async {
1358            loop {
1359                if client.is_reconnecting() {
1360                    break;
1361                }
1362                tokio::time::sleep(Duration::from_millis(10)).await;
1363            }
1364        })
1365        .await
1366        .expect("client did not enter RECONNECT state");
1367
1368        client.close().await;
1369        server.abort();
1370    }
1371
1372    #[rstest]
1373    fn test_parse_socket_url_raw_address() {
1374        // Raw socket address with TLS mode
1375        let (socket_addr, request_url) =
1376            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1377        assert_eq!(socket_addr, "example.com:6130");
1378        assert_eq!(request_url, "wss://example.com:6130");
1379
1380        // Raw socket address with Plain mode
1381        let (socket_addr, request_url) =
1382            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1383        assert_eq!(socket_addr, "localhost:8080");
1384        assert_eq!(request_url, "ws://localhost:8080");
1385    }
1386
1387    #[rstest]
1388    fn test_parse_socket_url_with_scheme() {
1389        // Full URL with wss scheme
1390        let (socket_addr, request_url) =
1391            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1392        assert_eq!(socket_addr, "example.com:443");
1393        assert_eq!(request_url, "wss://example.com:443/path");
1394
1395        // Full URL with ws scheme
1396        let (socket_addr, request_url) =
1397            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1398        assert_eq!(socket_addr, "localhost:8080");
1399        assert_eq!(request_url, "ws://localhost:8080");
1400    }
1401
1402    #[rstest]
1403    fn test_parse_socket_url_default_ports() {
1404        // wss without explicit port defaults to 443
1405        let (socket_addr, _) =
1406            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1407        assert_eq!(socket_addr, "example.com:443");
1408
1409        // ws without explicit port defaults to 80
1410        let (socket_addr, _) =
1411            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1412        assert_eq!(socket_addr, "example.com:80");
1413
1414        // https defaults to 443
1415        let (socket_addr, _) =
1416            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1417        assert_eq!(socket_addr, "example.com:443");
1418
1419        // http defaults to 80
1420        let (socket_addr, _) =
1421            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1422        assert_eq!(socket_addr, "example.com:80");
1423    }
1424
1425    #[rstest]
1426    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1427        // Unknown scheme defaults to mode-based port
1428        let (socket_addr, _) =
1429            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1430        assert_eq!(socket_addr, "example.com:443");
1431
1432        let (socket_addr, _) =
1433            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1434        assert_eq!(socket_addr, "example.com:80");
1435    }
1436
1437    #[rstest]
1438    fn test_parse_socket_url_ipv6() {
1439        // IPv6 address with port
1440        let (socket_addr, request_url) =
1441            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1442        assert_eq!(socket_addr, "[::1]:8080");
1443        assert_eq!(request_url, "ws://[::1]:8080");
1444
1445        // IPv6 in URL
1446        let (socket_addr, _) =
1447            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1448        assert_eq!(socket_addr, "[::1]:8080");
1449    }
1450
1451    #[rstest]
1452    #[tokio::test]
1453    async fn test_url_parsing_raw_socket_address() {
1454        // Test that raw socket addresses (host:port) work correctly
1455        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1456        let port = listener.local_addr().unwrap().port();
1457
1458        let server = task::spawn(async move {
1459            if let Ok((sock, _)) = listener.accept().await {
1460                drop(sock);
1461            }
1462            sleep(Duration::from_millis(50)).await;
1463        });
1464
1465        let config = SocketConfig {
1466            url: format!("127.0.0.1:{port}"), // Raw socket address format
1467            mode: Mode::Plain,
1468            suffix: b"\r\n".to_vec(),
1469            message_handler: None,
1470            heartbeat: None,
1471            reconnect_timeout_ms: Some(1_000),
1472            reconnect_delay_initial_ms: Some(50),
1473            reconnect_delay_max_ms: Some(100),
1474            reconnect_backoff_factor: Some(1.0),
1475            reconnect_jitter_ms: Some(0),
1476            certs_dir: None,
1477        };
1478
1479        // Should successfully connect with raw socket address
1480        let client = SocketClient::connect(config, None, None, None).await;
1481        assert!(
1482            client.is_ok(),
1483            "Client should connect with raw socket address format"
1484        );
1485
1486        if let Ok(client) = client {
1487            client.close().await;
1488        }
1489        server.abort();
1490    }
1491
1492    #[rstest]
1493    #[tokio::test]
1494    async fn test_url_parsing_with_scheme() {
1495        // Test that URLs with schemes also work
1496        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1497        let port = listener.local_addr().unwrap().port();
1498
1499        let server = task::spawn(async move {
1500            if let Ok((sock, _)) = listener.accept().await {
1501                drop(sock);
1502            }
1503            sleep(Duration::from_millis(50)).await;
1504        });
1505
1506        let config = SocketConfig {
1507            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1508            mode: Mode::Plain,
1509            suffix: b"\r\n".to_vec(),
1510            message_handler: None,
1511            heartbeat: None,
1512            reconnect_timeout_ms: Some(1_000),
1513            reconnect_delay_initial_ms: Some(50),
1514            reconnect_delay_max_ms: Some(100),
1515            reconnect_backoff_factor: Some(1.0),
1516            reconnect_jitter_ms: Some(0),
1517            certs_dir: None,
1518        };
1519
1520        // Should successfully connect with URL format
1521        let client = SocketClient::connect(config, None, None, None).await;
1522        assert!(
1523            client.is_ok(),
1524            "Client should connect with URL scheme format"
1525        );
1526
1527        if let Ok(client) = client {
1528            client.close().await;
1529        }
1530        server.abort();
1531    }
1532
1533    #[rstest]
1534    fn test_parse_socket_url_ipv6_with_zone() {
1535        // IPv6 with zone ID (link-local address)
1536        let (socket_addr, request_url) =
1537            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1538        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1539        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1540
1541        // Verify zone is preserved in URL format too
1542        let (socket_addr, request_url) =
1543            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1544        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1545        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1546    }
1547
1548    #[rstest]
1549    #[tokio::test]
1550    async fn test_ipv6_loopback_connection() {
1551        // Test IPv6 loopback address connection
1552        // Skip if IPv6 is not available on the system
1553        if TcpListener::bind("[::1]:0").await.is_err() {
1554            eprintln!("IPv6 not available, skipping test");
1555            return;
1556        }
1557
1558        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1559        let port = listener.local_addr().unwrap().port();
1560
1561        let server = task::spawn(async move {
1562            if let Ok((mut sock, _)) = listener.accept().await {
1563                let mut buf = vec![0u8; 1024];
1564                if let Ok(n) = sock.read(&mut buf).await {
1565                    // Echo back
1566                    let _ = sock.write_all(&buf[..n]).await;
1567                }
1568            }
1569            sleep(Duration::from_millis(50)).await;
1570        });
1571
1572        let config = SocketConfig {
1573            url: format!("[::1]:{port}"), // IPv6 loopback
1574            mode: Mode::Plain,
1575            suffix: b"\r\n".to_vec(),
1576            message_handler: None,
1577            heartbeat: None,
1578            reconnect_timeout_ms: Some(1_000),
1579            reconnect_delay_initial_ms: Some(50),
1580            reconnect_delay_max_ms: Some(100),
1581            reconnect_backoff_factor: Some(1.0),
1582            reconnect_jitter_ms: Some(0),
1583            certs_dir: None,
1584        };
1585
1586        let client = SocketClient::connect(config, None, None, None).await;
1587        assert!(
1588            client.is_ok(),
1589            "Client should connect to IPv6 loopback address"
1590        );
1591
1592        if let Ok(client) = client {
1593            client.close().await;
1594        }
1595        server.abort();
1596    }
1597
1598    #[rstest]
1599    #[tokio::test]
1600    async fn test_send_waits_during_reconnection() {
1601        // Test that send operations wait for reconnection to complete (up to configured timeout)
1602        use nautilus_common::testing::wait_until_async;
1603
1604        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1605        let port = listener.local_addr().unwrap().port();
1606
1607        let server = task::spawn(async move {
1608            // First connection - accept and immediately close
1609            if let Ok((sock, _)) = listener.accept().await {
1610                drop(sock);
1611            }
1612
1613            // Wait before accepting second connection
1614            sleep(Duration::from_millis(500)).await;
1615
1616            // Second connection - accept and keep alive
1617            if let Ok((mut sock, _)) = listener.accept().await {
1618                // Echo messages
1619                let mut buf = vec![0u8; 1024];
1620                while let Ok(n) = sock.read(&mut buf).await {
1621                    if n == 0 {
1622                        break;
1623                    }
1624                    if sock.write_all(&buf[..n]).await.is_err() {
1625                        break;
1626                    }
1627                }
1628            }
1629        });
1630
1631        let config = SocketConfig {
1632            url: format!("127.0.0.1:{port}"),
1633            mode: Mode::Plain,
1634            suffix: b"\r\n".to_vec(),
1635            message_handler: None,
1636            heartbeat: None,
1637            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1638            reconnect_delay_initial_ms: Some(100),
1639            reconnect_delay_max_ms: Some(200),
1640            reconnect_backoff_factor: Some(1.0),
1641            reconnect_jitter_ms: Some(0),
1642            certs_dir: None,
1643        };
1644
1645        let client = SocketClient::connect(config, None, None, None)
1646            .await
1647            .unwrap();
1648
1649        // Wait for reconnection to trigger
1650        wait_until_async(
1651            || async { client.is_reconnecting() },
1652            Duration::from_secs(2),
1653        )
1654        .await;
1655
1656        // Try to send while reconnecting - should wait and succeed after reconnect
1657        let send_result = tokio::time::timeout(
1658            Duration::from_secs(3),
1659            client.send_bytes(b"test_message".to_vec()),
1660        )
1661        .await;
1662
1663        assert!(
1664            send_result.is_ok() && send_result.unwrap().is_ok(),
1665            "Send should succeed after waiting for reconnection"
1666        );
1667
1668        client.close().await;
1669        server.abort();
1670    }
1671
1672    #[rstest]
1673    #[tokio::test]
1674    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1675        // Test that send_bytes operations respect the configured reconnect_timeout.
1676        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1677        use nautilus_common::testing::wait_until_async;
1678
1679        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1680        let port = listener.local_addr().unwrap().port();
1681
1682        let server = task::spawn(async move {
1683            // Accept first connection and immediately close it
1684            if let Ok((sock, _)) = listener.accept().await {
1685                drop(sock);
1686            }
1687            // Drop listener entirely so reconnection fails completely
1688            drop(listener);
1689            sleep(Duration::from_secs(60)).await;
1690        });
1691
1692        let config = SocketConfig {
1693            url: format!("127.0.0.1:{port}"),
1694            mode: Mode::Plain,
1695            suffix: b"\r\n".to_vec(),
1696            message_handler: None,
1697            heartbeat: None,
1698            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1699            reconnect_delay_initial_ms: Some(5_000), // Long backoff to keep client in RECONNECT
1700            reconnect_delay_max_ms: Some(5_000),
1701            reconnect_backoff_factor: Some(1.0),
1702            reconnect_jitter_ms: Some(0),
1703            certs_dir: None,
1704        };
1705
1706        let client = SocketClient::connect(config, None, None, None)
1707            .await
1708            .unwrap();
1709
1710        // Wait for client to enter RECONNECT state
1711        wait_until_async(
1712            || async { client.is_reconnecting() },
1713            Duration::from_secs(3),
1714        )
1715        .await;
1716
1717        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1718        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1719        let start = std::time::Instant::now();
1720        let send_result = client.send_bytes(b"test".to_vec()).await;
1721        let elapsed = start.elapsed();
1722
1723        assert!(
1724            send_result.is_err(),
1725            "Send should fail when client stuck in RECONNECT, got: {:?}",
1726            send_result
1727        );
1728        assert!(
1729            matches!(send_result, Err(crate::error::SendError::Timeout)),
1730            "Send should return Timeout error, got: {:?}",
1731            send_result
1732        );
1733        // Verify timeout respects configured value (1s), but don't check upper bound
1734        // as CI scheduler jitter can cause legitimate delays beyond the timeout
1735        assert!(
1736            elapsed >= Duration::from_millis(900),
1737            "Send should timeout after at least 1s (configured timeout), took {:?}",
1738            elapsed
1739        );
1740
1741        client.close().await;
1742        server.abort();
1743    }
1744}