1use std::{
32 collections::VecDeque,
33 fmt::Debug,
34 path::Path,
35 sync::{
36 Arc,
37 atomic::{AtomicU8, Ordering},
38 },
39 time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49 SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52 backoff::ExponentialBackoff,
53 error::SendError,
54 logging::{log_task_aborted, log_task_started, log_task_stopped},
55 mode::ConnectionMode,
56 net::TcpStream,
57 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
/// Interval (milliseconds) at which the read, write and controller tasks, as well as
/// `SocketClient::close`, re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) used during reconnect after the writer has been swapped, and by the
/// controller before it aborts the read/heartbeat tasks on disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) for writer shutdown, for the controller's disconnect sequence, and
/// for `SocketClient::close` waiting on the CLOSED state.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Polling interval (milliseconds) used by `SocketClient::send_bytes` while waiting for the
/// client to become active.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

/// Maximum number of unprocessed bytes the read task keeps buffered (10 MiB) before it logs
/// an error and stops reading.
const MAX_READ_BUFFER_BYTES: usize = 10 * 1024 * 1024;
68
/// Connection state owned by the controller task: the configuration, the handles of the
/// background tasks, and the bookkeeping needed to reconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; read again on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `config.certs_dir`, if a certificate directory was given.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced after a successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel used to send `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared with all background tasks and with `SocketClient`.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff used by the controller between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Optional message handler handed to every read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum number of consecutive reconnect attempts; `None` means no limit is enforced.
    reconnect_max_attempts: Option<u32>,
    /// Number of reconnect attempts made since the last successful (re)connection.
    reconnect_attempt_count: u32,
}
102
103impl SocketClientInner {
104 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110 const CONNECTION_TIMEOUT_SECS: u64 = 10;
111
112 install_cryptographic_provider();
113
114 if config.suffix.is_empty() {
116 anyhow::bail!("Socket suffix cannot be empty: suffix is required for message framing");
117 }
118
119 let SocketConfig {
120 url,
121 mode,
122 heartbeat,
123 suffix,
124 message_handler,
125 reconnect_timeout_ms,
126 reconnect_delay_initial_ms,
127 reconnect_delay_max_ms,
128 reconnect_backoff_factor,
129 reconnect_jitter_ms,
130 connection_max_retries,
131 reconnect_max_attempts,
132 certs_dir,
133 } = &config.clone();
134 let connector = if let Some(dir) = certs_dir {
135 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
136 Some(Connector::Rustls(Arc::new(config)))
137 } else {
138 None
139 };
140
141 let max_retries = connection_max_retries.unwrap_or(5);
143
144 let mut backoff = ExponentialBackoff::new(
145 Duration::from_millis(500),
146 Duration::from_millis(5000),
147 2.0,
148 250,
149 false,
150 )?;
151
152 #[allow(unused_assignments)]
153 let mut last_error = String::new();
154 let mut attempt = 0;
155 let (reader, writer) = loop {
156 attempt += 1;
157
158 match tokio::time::timeout(
159 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
160 Self::tls_connect_with_server(url, *mode, connector.clone()),
161 )
162 .await
163 {
164 Ok(Ok(result)) => {
165 if attempt > 1 {
166 log::info!("Socket connection established after {attempt} attempts");
167 }
168 break result;
169 }
170 Ok(Err(e)) => {
171 last_error = e.to_string();
172 log::warn!(
173 "Socket connection attempt {attempt}/{max_retries} to {url} failed: {last_error}"
174 );
175 }
176 Err(_) => {
177 last_error = format!(
178 "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
179 );
180 log::warn!(
181 "Socket connection attempt {attempt}/{max_retries} to {url} timed out"
182 );
183 }
184 }
185
186 if attempt >= max_retries {
187 anyhow::bail!(
188 "Failed to connect to {} after {} attempts: {}. \
189 If this is a DNS error, check your network configuration and DNS settings.",
190 url,
191 max_retries,
192 if last_error.is_empty() {
193 "unknown error"
194 } else {
195 &last_error
196 }
197 );
198 }
199
200 let delay = backoff.next_duration();
201 log::debug!(
202 "Retrying in {delay:?} (attempt {}/{})",
203 attempt + 1,
204 max_retries
205 );
206 tokio::time::sleep(delay).await;
207 };
208
209 log::debug!("Connected");
210
211 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
212
213 let read_task = Arc::new(Self::spawn_read_task(
214 connection_mode.clone(),
215 reader,
216 message_handler.clone(),
217 suffix.clone(),
218 ));
219
220 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
221
222 let write_task =
223 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
224
225 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
227 Self::spawn_heartbeat_task(
228 connection_mode.clone(),
229 heartbeat.clone(),
230 writer_tx.clone(),
231 )
232 });
233
234 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
235 let backoff = ExponentialBackoff::new(
236 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
237 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
238 reconnect_backoff_factor.unwrap_or(1.5),
239 reconnect_jitter_ms.unwrap_or(100),
240 true, )?;
242
243 Ok(Self {
244 config,
245 connector,
246 read_task,
247 write_task,
248 writer_tx,
249 heartbeat_task,
250 connection_mode,
251 reconnect_timeout,
252 backoff,
253 handler: message_handler.clone(),
254 reconnect_max_attempts: *reconnect_max_attempts,
255 reconnect_attempt_count: 0,
256 })
257 }
258
259 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
269 if url.contains("://") {
270 let parsed = url.parse::<http::Uri>().map_err(|e| {
272 Error::Io(std::io::Error::new(
273 std::io::ErrorKind::InvalidInput,
274 format!("Invalid URL: {e}"),
275 ))
276 })?;
277
278 let host = parsed.host().ok_or_else(|| {
279 Error::Io(std::io::Error::new(
280 std::io::ErrorKind::InvalidInput,
281 "URL missing host",
282 ))
283 })?;
284
285 let port = parsed
286 .port_u16()
287 .unwrap_or_else(|| match parsed.scheme_str() {
288 Some("wss" | "https") => 443,
289 Some("ws" | "http") => 80,
290 _ => match mode {
291 Mode::Tls => 443,
292 Mode::Plain => 80,
293 },
294 });
295
296 Ok((format!("{host}:{port}"), url.to_string()))
297 } else {
298 let scheme = match mode {
301 Mode::Tls => "wss",
302 Mode::Plain => "ws",
303 };
304 Ok((url.to_string(), format!("{scheme}://{url}")))
305 }
306 }
307
308 pub async fn tls_connect_with_server(
318 url: &str,
319 mode: Mode,
320 connector: Option<Connector>,
321 ) -> Result<(TcpReader, TcpWriter), Error> {
322 log::debug!("Connecting to {url}");
323
324 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
325 let tcp_result = TcpStream::connect(&socket_addr).await;
326
327 match tcp_result {
328 Ok(stream) => {
329 log::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
330 if let Err(e) = stream.set_nodelay(true) {
331 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
332 }
333 let request = request_url.into_client_request()?;
334 tcp_tls(&request, mode, stream, connector)
335 .await
336 .map(tokio::io::split)
337 }
338 Err(e) => {
339 log::error!("TCP connection failed to {socket_addr}: {e:?}");
340 Err(Error::Io(e))
341 }
342 }
343 }
344
/// Performs one reconnect attempt, bounded by `self.reconnect_timeout`.
///
/// Opens a new connection, hands the new writer to the write task (which reports whether it
/// could flush its buffered messages), then replaces the read task. The attempt is abandoned
/// with `Ok(())` whenever the connection mode is found to be DISCONNECT, or when the mode is
/// no longer RECONNECT at the point where it would be switched back to ACTIVE.
///
/// # Errors
///
/// Returns an error if connecting fails, the write task cannot be reached or fails to drain
/// its buffer, or the whole attempt exceeds `self.reconnect_timeout`.
async fn reconnect(&mut self) -> Result<(), Error> {
    log::debug!("Reconnecting");

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        log::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    tokio::time::timeout(self.reconnect_timeout, async {
        // Only `url`, `mode` and `suffix` are needed; the exhaustive pattern forces a
        // compile error here if `SocketConfig` gains a field
        let SocketConfig {
            url,
            mode,
            heartbeat: _,
            suffix,
            message_handler: _,
            reconnect_timeout_ms: _,
            reconnect_delay_initial_ms: _,
            reconnect_backoff_factor: _,
            reconnect_delay_max_ms: _,
            reconnect_jitter_ms: _,
            connection_max_retries: _,
            reconnect_max_attempts: _,
            certs_dir: _,
        } = &self.config;
        let connector = self.connector.clone();
        let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

        // Re-check after the await: a disconnect may have been requested while connecting
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }
        log::debug!("Connected");

        // Hand the new writer to the write task and wait for its drain report
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            log::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => log::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                log::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                log::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // Stop the previous read task before a new one is spawned
        if !self.read_task.is_finished() {
            self.read_task.abort();
            log_task_aborted("read");
        }

        // Only RECONNECT -> ACTIVE is allowed; any other state means the mode changed
        // underneath this attempt, so it is abandoned
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            log::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        // The mode is ACTIVE at this point, which the new read task requires to keep running
        self.read_task = Arc::new(Self::spawn_read_task(
            self.connection_mode.clone(),
            reader,
            self.handler.clone(),
            suffix.clone(),
        ));

        log::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
466
467 #[inline]
474 #[must_use]
475 pub fn is_alive(&self) -> bool {
476 !self.read_task.is_finished()
477 }
478
479 #[must_use]
480 fn spawn_read_task(
481 connection_state: Arc<AtomicU8>,
482 mut reader: TcpReader,
483 handler: Option<TcpMessageHandler>,
484 suffix: Vec<u8>,
485 ) -> tokio::task::JoinHandle<()> {
486 log_task_started("read");
487
488 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
490
491 tokio::task::spawn(async move {
492 let mut buf = Vec::new();
493
494 loop {
495 if !ConnectionMode::from_atomic(&connection_state).is_active() {
496 break;
497 }
498
499 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
500 Ok(Ok(0)) => {
502 log::debug!("Connection closed by server");
503 break;
504 }
505 Ok(Err(e)) => {
506 log::debug!("Connection ended: {e}");
507 break;
508 }
509 Ok(Ok(bytes)) => {
511 log::trace!("Received <binary> {bytes} bytes");
512
513 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
515
516 if is_fix && handler.is_some() {
517 if let Some(ref handler) = handler {
519 process_fix_buffer(&mut buf, handler);
520 }
521 } else {
522 while let Some((i, _)) = &buf
524 .windows(suffix.len())
525 .enumerate()
526 .find(|(_, pair)| pair.eq(&suffix))
527 {
528 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
529 data.truncate(data.len() - suffix.len());
530
531 if let Some(ref handler) = handler {
532 handler(&data);
533 }
534 }
535 }
536
537 if buf.len() > MAX_READ_BUFFER_BYTES {
538 log::error!(
539 "Read buffer exceeded maximum size ({MAX_READ_BUFFER_BYTES} bytes), closing connection"
540 );
541 break;
542 }
543 }
544 Err(_) => {
545 continue;
547 }
548 }
549 }
550
551 log_task_stopped("read");
552 })
553 }
554
555 async fn drain_reconnect_buffer(
565 buffer: &mut VecDeque<Bytes>,
566 writer: &mut TcpWriter,
567 suffix: &[u8],
568 ) -> bool {
569 if buffer.is_empty() {
570 return false;
571 }
572
573 let initial_buffer_len = buffer.len();
574 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
575
576 let mut send_error_occurred = false;
577
578 while let Some(buffered_msg) = buffer.front() {
579 let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
580 combined_msg.extend_from_slice(buffered_msg);
581 combined_msg.extend_from_slice(suffix);
582
583 if let Err(e) = writer.write_all(&combined_msg).await {
584 log::error!(
585 "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
586 buffer.len()
587 );
588 send_error_occurred = true;
589 break;
590 }
591
592 buffer.pop_front();
593 }
594
595 if buffer.is_empty() {
596 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
597 }
598
599 send_error_occurred
600 }
601
602 fn spawn_write_task(
603 connection_state: Arc<AtomicU8>,
604 writer: TcpWriter,
605 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
606 suffix: Vec<u8>,
607 ) -> tokio::task::JoinHandle<()> {
608 log_task_started("write");
609
610 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
612
613 tokio::task::spawn(async move {
614 let mut active_writer = writer;
615 let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
616
617 loop {
618 if matches!(
619 ConnectionMode::from_atomic(&connection_state),
620 ConnectionMode::Disconnect | ConnectionMode::Closed
621 ) {
622 break;
623 }
624
625 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
626 Ok(Some(msg)) => {
627 let mode = ConnectionMode::from_atomic(&connection_state);
629 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
630 break;
631 }
632
633 match msg {
634 WriterCommand::Update(new_writer, tx) => {
635 log::debug!("Received new writer");
636
637 tokio::time::sleep(Duration::from_millis(100)).await;
639
640 _ = tokio::time::timeout(
643 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
644 active_writer.shutdown(),
645 )
646 .await;
647
648 active_writer = new_writer;
649 log::debug!("Updated writer");
650
651 let send_error = Self::drain_reconnect_buffer(
652 &mut reconnect_buffer,
653 &mut active_writer,
654 &suffix,
655 )
656 .await;
657
658 if let Err(e) = tx.send(!send_error) {
659 log::error!(
660 "Failed to report drain status to controller: {e:?}"
661 );
662 }
663 }
664 _ if mode.is_reconnect() => {
665 if let WriterCommand::Send(data) = msg {
666 log::debug!(
667 "Buffering message while reconnecting ({} bytes)",
668 data.len()
669 );
670 reconnect_buffer.push_back(data);
671 }
672 continue;
673 }
674 WriterCommand::Send(msg) => {
675 if let Err(e) = active_writer.write_all(&msg).await {
676 log::error!("Failed to send message: {e}");
677 log::warn!("Writer triggering reconnect");
678 reconnect_buffer.push_back(msg);
679 connection_state
680 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
681 continue;
682 }
683 if let Err(e) = active_writer.write_all(&suffix).await {
684 log::error!("Failed to send suffix: {e}");
685 log::warn!("Writer triggering reconnect");
686 reconnect_buffer.push_back(msg);
688 connection_state
689 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
690 continue;
691 }
692 }
693 }
694 }
695 Ok(None) => {
696 log::debug!("Writer channel closed, terminating writer task");
698 break;
699 }
700 Err(_) => {
701 continue;
703 }
704 }
705 }
706
707 _ = tokio::time::timeout(
710 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
711 active_writer.shutdown(),
712 )
713 .await;
714
715 log_task_stopped("write");
716 })
717 }
718
719 fn spawn_heartbeat_task(
720 connection_state: Arc<AtomicU8>,
721 heartbeat: (u64, Vec<u8>),
722 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
723 ) -> tokio::task::JoinHandle<()> {
724 log_task_started("heartbeat");
725 let (interval_secs, message) = heartbeat;
726
727 tokio::task::spawn(async move {
728 let interval = Duration::from_secs(interval_secs);
729
730 loop {
731 tokio::time::sleep(interval).await;
732
733 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
734 ConnectionMode::Active => {
735 let msg = WriterCommand::Send(message.clone().into());
736
737 match writer_tx.send(msg) {
738 Ok(()) => log::trace!("Sent heartbeat to writer task"),
739 Err(e) => {
740 log::error!("Failed to send heartbeat to writer task: {e}");
741 }
742 }
743 }
744 ConnectionMode::Reconnect => continue,
745 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
746 }
747 }
748
749 log_task_stopped("heartbeat");
750 })
751 }
752}
753
754impl Drop for SocketClientInner {
755 fn drop(&mut self) {
756 self.clean_drop();
758 }
759}
760
761impl CleanDrop for SocketClientInner {
763 fn clean_drop(&mut self) {
764 if !self.read_task.is_finished() {
765 self.read_task.abort();
766 log_task_aborted("read");
767 }
768
769 if !self.write_task.is_finished() {
770 self.write_task.abort();
771 log_task_aborted("write");
772 }
773
774 if let Some(ref handle) = self.heartbeat_task.take()
775 && !handle.is_finished()
776 {
777 handle.abort();
778 log_task_aborted("heartbeat");
779 }
780
781 #[cfg(feature = "python")]
782 {
783 self.config.message_handler = None;
785 }
786 }
787}
788
/// Public handle to a socket connection that is supervised by a background controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task that drives reconnects and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller and the background tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time `send_bytes` waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Channel used to queue commands for the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
799
800impl Debug for SocketClient {
801 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
802 f.debug_struct(stringify!(SocketClient)).finish()
803 }
804}
805
806impl SocketClient {
807 pub async fn connect(
813 config: SocketConfig,
814 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
815 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
816 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
817 ) -> anyhow::Result<Self> {
818 let inner = SocketClientInner::connect_url(config).await?;
819 let writer_tx = inner.writer_tx.clone();
820 let connection_mode = inner.connection_mode.clone();
821 let reconnect_timeout = inner.reconnect_timeout;
822
823 let controller_task = Self::spawn_controller_task(
824 inner,
825 connection_mode.clone(),
826 post_reconnection,
827 post_disconnection,
828 );
829
830 if let Some(handler) = post_connection {
831 handler();
832 log::debug!("Called `post_connection` handler");
833 }
834
835 Ok(Self {
836 controller_task,
837 connection_mode,
838 reconnect_timeout,
839 writer_tx,
840 })
841 }
842
843 #[must_use]
845 pub fn connection_mode(&self) -> ConnectionMode {
846 ConnectionMode::from_atomic(&self.connection_mode)
847 }
848
849 #[inline]
854 #[must_use]
855 pub fn is_active(&self) -> bool {
856 self.connection_mode().is_active()
857 }
858
859 #[inline]
864 #[must_use]
865 pub fn is_reconnecting(&self) -> bool {
866 self.connection_mode().is_reconnect()
867 }
868
869 #[inline]
873 #[must_use]
874 pub fn is_disconnecting(&self) -> bool {
875 self.connection_mode().is_disconnect()
876 }
877
878 #[inline]
884 #[must_use]
885 pub fn is_closed(&self) -> bool {
886 self.connection_mode().is_closed()
887 }
888
889 pub async fn close(&self) {
894 self.connection_mode
895 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
896
897 if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
898 while !self.is_closed() {
899 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
900 }
901
902 if !self.controller_task.is_finished() {
903 self.controller_task.abort();
904 log_task_aborted("controller");
905 }
906 })
907 .await
908 == Ok(())
909 {
910 log_task_stopped("controller");
911 } else {
912 log::error!("Timeout waiting for controller task to finish");
913 if !self.controller_task.is_finished() {
914 self.controller_task.abort();
915 log_task_aborted("controller");
916 }
917 }
918 }
919
920 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
926 if self.is_closed() || self.is_disconnecting() {
928 return Err(SendError::Closed);
929 }
930
931 let timeout = self.reconnect_timeout;
932 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
933
934 if !self.is_active() {
935 log::debug!("Waiting for client to become ACTIVE before sending...");
936
937 let inner = tokio::time::timeout(timeout, async {
938 loop {
939 if self.is_active() {
940 return Ok(());
941 }
942 if matches!(
943 self.connection_mode(),
944 ConnectionMode::Disconnect | ConnectionMode::Closed
945 ) {
946 return Err(());
947 }
948 tokio::time::sleep(check_interval).await;
949 }
950 })
951 .await
952 .map_err(|_| SendError::Timeout)?;
953 inner.map_err(|()| SendError::Closed)?;
954 }
955
956 let msg = WriterCommand::Send(data.into());
957 self.writer_tx
958 .send(msg)
959 .map_err(|e| SendError::BrokenPipe(e.to_string()))
960 }
961
/// Spawns the controller task that supervises the connection owned by `inner`.
///
/// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` milliseconds the task inspects the shared
/// connection mode:
/// - DISCONNECT: aborts the read and heartbeat tasks, calls `post_disconnection` and stops.
/// - CLOSED: stops.
/// - ACTIVE with a finished read task: switches the mode to RECONNECT.
/// - RECONNECT: attempts a reconnect, backing off after failures and giving up once
///   `reconnect_max_attempts` is exceeded.
///
/// When the loop ends, the mode is set to CLOSED.
fn spawn_controller_task(
    mut inner: SocketClientInner,
    connection_mode: Arc<AtomicU8>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        log_task_started("controller");

        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        loop {
            tokio::time::sleep(check_interval).await;
            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                log::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    // Short pause before the read and heartbeat tasks are aborted
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if !inner.read_task.is_finished() {
                        inner.read_task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    log::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                log::debug!("Closed");

                if let Some(ref handler) = post_disconnection {
                    handler();
                    log::debug!("Called `post_disconnection` handler");
                }
                break;
            }

            if mode.is_closed() {
                log::debug!("Connection closed");
                break;
            }

            if mode.is_active() && !inner.is_alive() {
                // Only ACTIVE -> RECONNECT is performed, so a mode changed concurrently is
                // never overwritten
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    log::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                // Reload the mode so the reconnect branch below sees the current value
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnect_attempt_count >= max_attempts
                {
                    log::error!(
                        "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    break;
                }

                inner.reconnect_attempt_count += 1;
                match inner.reconnect().await {
                    Ok(()) => {
                        log::debug!("Reconnected successfully");
                        inner.backoff.reset();
                        inner.reconnect_attempt_count = 0;
                        // `reconnect` also returns Ok(()) when it was abandoned, so the
                        // handler is only called if the mode is actually ACTIVE
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = post_reconnection {
                                handler();
                                log::debug!("Called `post_reconnection` handler");
                            }
                        } else {
                            log::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Err(e) => {
                        let duration = inner.backoff.next_duration();
                        log::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnect_attempt_count
                        );
                        if !duration.is_zero() {
                            log::warn!("Backing off for {}s...", duration.as_secs_f64());
                        }
                        tokio::time::sleep(duration).await;
                    }
                }
            }
        }
        // Every exit path from the loop ends in the CLOSED state
        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
1083}
1084
1085impl Drop for SocketClient {
1087 fn drop(&mut self) {
1088 if !self.controller_task.is_finished() {
1089 self.controller_task.abort();
1090 log_task_aborted("controller");
1091 }
1092 }
1093}
1094
1095#[cfg(test)]
1096#[cfg(feature = "python")]
1097#[cfg(target_os = "linux")] mod tests {
1099 use nautilus_common::testing::wait_until_async;
1100 use pyo3::Python;
1101 use tokio::{
1102 io::{AsyncReadExt, AsyncWriteExt},
1103 net::{TcpListener, TcpStream},
1104 sync::Mutex,
1105 task,
1106 time::{Duration, sleep},
1107 };
1108
1109 use super::*;
1110
1111 async fn bind_test_server() -> (u16, TcpListener) {
1112 let listener = TcpListener::bind("127.0.0.1:0")
1113 .await
1114 .expect("Failed to bind ephemeral port");
1115 let port = listener.local_addr().unwrap().port();
1116 (port, listener)
1117 }
1118
1119 async fn run_echo_server(mut socket: TcpStream) {
1120 let mut buf = Vec::new();
1121 loop {
1122 match socket.read_buf(&mut buf).await {
1123 Ok(0) => {
1124 break;
1125 }
1126 Ok(_n) => {
1127 while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1128 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1129 line.truncate(line.len() - 2);
1131
1132 if line == b"close" {
1133 let _ = socket.shutdown().await;
1134 return;
1135 }
1136
1137 let mut echo_data = line;
1138 echo_data.extend_from_slice(b"\r\n");
1139 if socket.write_all(&echo_data).await.is_err() {
1140 break;
1141 }
1142 }
1143 }
1144 Err(e) => {
1145 eprintln!("Server read error: {e}");
1146 break;
1147 }
1148 }
1149 }
1150 }
1151
1152 #[tokio::test]
1153 async fn test_basic_send_receive() {
1154 Python::initialize();
1155
1156 let (port, listener) = bind_test_server().await;
1157 let server_task = task::spawn(async move {
1158 let (socket, _) = listener.accept().await.unwrap();
1159 run_echo_server(socket).await;
1160 });
1161
1162 let config = SocketConfig {
1163 url: format!("127.0.0.1:{port}"),
1164 mode: Mode::Plain,
1165 suffix: b"\r\n".to_vec(),
1166 message_handler: None,
1167 heartbeat: None,
1168 reconnect_timeout_ms: None,
1169 reconnect_delay_initial_ms: None,
1170 reconnect_backoff_factor: None,
1171 reconnect_delay_max_ms: None,
1172 reconnect_jitter_ms: None,
1173 reconnect_max_attempts: None,
1174 connection_max_retries: None,
1175 certs_dir: None,
1176 };
1177
1178 let client = SocketClient::connect(config, None, None, None)
1179 .await
1180 .expect("Client connect failed unexpectedly");
1181
1182 client.send_bytes(b"Hello".into()).await.unwrap();
1183 client.send_bytes(b"World".into()).await.unwrap();
1184
1185 sleep(Duration::from_millis(100)).await;
1187
1188 client.send_bytes(b"close".into()).await.unwrap();
1189 server_task.await.unwrap();
1190 assert!(!client.is_closed());
1191 }
1192
1193 #[tokio::test]
1194 async fn test_reconnect_fail_exhausted() {
1195 Python::initialize();
1196
1197 let (port, listener) = bind_test_server().await;
1198 drop(listener); wait_until_async(
1202 || async {
1203 TcpStream::connect(format!("127.0.0.1:{port}"))
1204 .await
1205 .is_err()
1206 },
1207 Duration::from_secs(2),
1208 )
1209 .await;
1210
1211 let config = SocketConfig {
1212 url: format!("127.0.0.1:{port}"),
1213 mode: Mode::Plain,
1214 suffix: b"\r\n".to_vec(),
1215 message_handler: None,
1216 heartbeat: None,
1217 reconnect_timeout_ms: Some(100),
1218 reconnect_delay_initial_ms: Some(50),
1219 reconnect_backoff_factor: Some(1.0),
1220 reconnect_delay_max_ms: Some(50),
1221 reconnect_jitter_ms: Some(0),
1222 connection_max_retries: Some(1),
1223 reconnect_max_attempts: None,
1224 certs_dir: None,
1225 };
1226
1227 let client_res = SocketClient::connect(config, None, None, None).await;
1228 assert!(
1229 client_res.is_err(),
1230 "Should fail quickly with no server listening"
1231 );
1232 }
1233
1234 #[tokio::test]
1235 async fn test_user_disconnect() {
1236 Python::initialize();
1237
1238 let (port, listener) = bind_test_server().await;
1239 let server_task = task::spawn(async move {
1240 let (socket, _) = listener.accept().await.unwrap();
1241 let mut buf = [0u8; 1024];
1242 let _ = socket.try_read(&mut buf);
1243
1244 loop {
1245 sleep(Duration::from_secs(1)).await;
1246 }
1247 });
1248
1249 let config = SocketConfig {
1250 url: format!("127.0.0.1:{port}"),
1251 mode: Mode::Plain,
1252 suffix: b"\r\n".to_vec(),
1253 message_handler: None,
1254 heartbeat: None,
1255 reconnect_timeout_ms: None,
1256 reconnect_delay_initial_ms: None,
1257 reconnect_backoff_factor: None,
1258 reconnect_delay_max_ms: None,
1259 reconnect_jitter_ms: None,
1260 reconnect_max_attempts: None,
1261 connection_max_retries: None,
1262 certs_dir: None,
1263 };
1264
1265 let client = SocketClient::connect(config, None, None, None)
1266 .await
1267 .unwrap();
1268
1269 client.close().await;
1270 assert!(client.is_closed());
1271 server_task.abort();
1272 }
1273
/// With a one-second heartbeat configured, the server must have received at least two
/// `ping` lines after three seconds.
#[tokio::test]
async fn test_heartbeat() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;
    // Lines received by the server, shared between the server task and the assertions
    let received = Arc::new(Mutex::new(Vec::new()));
    let received2 = received.clone();

    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();

        let mut buf = Vec::new();
        loop {
            // Non-blocking read; on error (e.g. no data available yet) retry after 10 ms
            match socket.try_read_buf(&mut buf) {
                Ok(0) => break,
                Ok(_) => {
                    // Split the buffer into `\r\n`-terminated lines and record each one
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);
                        received2.lock().await.push(line);
                    }
                }
                Err(_) => {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            }
        }
    });

    // Heartbeat of `ping` every second
    let heartbeat = Some((1, b"ping".to_vec()));

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();

    // Long enough for at least two heartbeats to be sent
    sleep(Duration::from_secs(3)).await;

    {
        let lock = received.lock().await;
        let pings = lock
            .iter()
            .filter(|line| line == &&b"ping".to_vec())
            .count();
        assert!(
            pings >= 2,
            "Expected at least 2 heartbeat pings; got {pings}"
        );
    }

    client.close().await;
    server_task.abort();
}
1344
/// The server drops the first connection and accepts a second one; the client must be
/// active and able to send afterwards.
#[tokio::test]
async fn test_reconnect_success() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;

    let server_task = task::spawn(async move {
        // First connection: accept it, then shut it down after 500 ms
        let (mut socket, _) = listener.accept().await.expect("First accept failed");

        sleep(Duration::from_millis(500)).await;
        let _ = socket.shutdown().await;

        sleep(Duration::from_millis(500)).await;

        // Second connection: serve it with the echo server
        let (socket, _) = listener.accept().await.expect("Second accept failed");
        run_echo_server(socket).await;
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(5_000),
        reconnect_delay_initial_ms: Some(500),
        reconnect_delay_max_ms: Some(5_000),
        reconnect_backoff_factor: Some(2.0),
        reconnect_jitter_ms: Some(50),
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .expect("Client connect failed unexpectedly");

    assert!(client.is_active(), "Client should start as active");

    // NOTE(review): the client is already active here, so this wait can return
    // before the server drops the first connection; confirm this is intended
    wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

    client
        .send_bytes(b"TestReconnect".into())
        .await
        .expect("Send failed");

    client.close().await;
    server_task.abort();
}
1405}
1406
1407#[cfg(test)]
1408#[cfg(not(feature = "turmoil"))]
1409mod rust_tests {
1410 use nautilus_common::testing::wait_until_async;
1411 use rstest::rstest;
1412 use tokio::{
1413 io::{AsyncReadExt, AsyncWriteExt},
1414 net::TcpListener,
1415 task,
1416 time::{Duration, sleep},
1417 };
1418
1419 use super::*;
1420
1421 #[rstest]
1422 #[tokio::test]
1423 async fn test_reconnect_then_close() {
1424 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1426 let port = listener.local_addr().unwrap().port();
1427
1428 let server = task::spawn(async move {
1430 if let Ok((mut sock, _)) = listener.accept().await {
1431 drop(sock.shutdown());
1432 }
1433 sleep(Duration::from_secs(1)).await;
1435 });
1436
1437 let config = SocketConfig {
1439 url: format!("127.0.0.1:{port}"),
1440 mode: Mode::Plain,
1441 suffix: b"\r\n".to_vec(),
1442 message_handler: None,
1443 heartbeat: None,
1444 reconnect_timeout_ms: Some(1_000),
1445 reconnect_delay_initial_ms: Some(50),
1446 reconnect_delay_max_ms: Some(100),
1447 reconnect_backoff_factor: Some(1.0),
1448 reconnect_jitter_ms: Some(0),
1449 connection_max_retries: Some(1),
1450 reconnect_max_attempts: None,
1451 certs_dir: None,
1452 };
1453
1454 let client = SocketClient::connect(config.clone(), None, None, None)
1456 .await
1457 .unwrap();
1458
1459 wait_until_async(
1461 || async { client.is_reconnecting() },
1462 Duration::from_secs(2),
1463 )
1464 .await;
1465
1466 client.close().await;
1468 assert!(client.is_closed());
1469 server.abort();
1470 }
1471
1472 #[rstest]
1473 #[tokio::test]
1474 async fn test_reconnect_state_flips_when_reader_stops() {
1475 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1477 let port = listener.local_addr().unwrap().port();
1478
1479 let server = task::spawn(async move {
1480 if let Ok((sock, _)) = listener.accept().await {
1481 drop(sock);
1482 }
1483 sleep(Duration::from_millis(50)).await;
1485 });
1486
1487 let config = SocketConfig {
1488 url: format!("127.0.0.1:{port}"),
1489 mode: Mode::Plain,
1490 suffix: b"\r\n".to_vec(),
1491 message_handler: None,
1492 heartbeat: None,
1493 reconnect_timeout_ms: Some(1_000),
1494 reconnect_delay_initial_ms: Some(50),
1495 reconnect_delay_max_ms: Some(100),
1496 reconnect_backoff_factor: Some(1.0),
1497 reconnect_jitter_ms: Some(0),
1498 connection_max_retries: Some(1),
1499 reconnect_max_attempts: None,
1500 certs_dir: None,
1501 };
1502
1503 let client = SocketClient::connect(config, None, None, None)
1504 .await
1505 .unwrap();
1506
1507 wait_until_async(
1508 || async { client.is_reconnecting() },
1509 Duration::from_secs(2),
1510 )
1511 .await;
1512
1513 client.close().await;
1514 server.abort();
1515 }
1516
1517 #[rstest]
1518 fn test_parse_socket_url_raw_address() {
1519 let (socket_addr, request_url) =
1521 SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1522 assert_eq!(socket_addr, "example.com:6130");
1523 assert_eq!(request_url, "wss://example.com:6130");
1524
1525 let (socket_addr, request_url) =
1527 SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1528 assert_eq!(socket_addr, "localhost:8080");
1529 assert_eq!(request_url, "ws://localhost:8080");
1530 }
1531
1532 #[rstest]
1533 fn test_parse_socket_url_with_scheme() {
1534 let (socket_addr, request_url) =
1536 SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1537 assert_eq!(socket_addr, "example.com:443");
1538 assert_eq!(request_url, "wss://example.com:443/path");
1539
1540 let (socket_addr, request_url) =
1542 SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1543 assert_eq!(socket_addr, "localhost:8080");
1544 assert_eq!(request_url, "ws://localhost:8080");
1545 }
1546
1547 #[rstest]
1548 fn test_parse_socket_url_default_ports() {
1549 let (socket_addr, _) =
1551 SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1552 assert_eq!(socket_addr, "example.com:443");
1553
1554 let (socket_addr, _) =
1556 SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1557 assert_eq!(socket_addr, "example.com:80");
1558
1559 let (socket_addr, _) =
1561 SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1562 assert_eq!(socket_addr, "example.com:443");
1563
1564 let (socket_addr, _) =
1566 SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1567 assert_eq!(socket_addr, "example.com:80");
1568 }
1569
1570 #[rstest]
1571 fn test_parse_socket_url_unknown_scheme_uses_mode() {
1572 let (socket_addr, _) =
1574 SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1575 assert_eq!(socket_addr, "example.com:443");
1576
1577 let (socket_addr, _) =
1578 SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1579 assert_eq!(socket_addr, "example.com:80");
1580 }
1581
1582 #[rstest]
1583 fn test_parse_socket_url_ipv6() {
1584 let (socket_addr, request_url) =
1586 SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1587 assert_eq!(socket_addr, "[::1]:8080");
1588 assert_eq!(request_url, "ws://[::1]:8080");
1589
1590 let (socket_addr, _) =
1592 SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1593 assert_eq!(socket_addr, "[::1]:8080");
1594 }
1595
1596 #[rstest]
1597 #[tokio::test]
1598 async fn test_url_parsing_raw_socket_address() {
1599 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1601 let port = listener.local_addr().unwrap().port();
1602
1603 let server = task::spawn(async move {
1604 if let Ok((sock, _)) = listener.accept().await {
1605 drop(sock);
1606 }
1607 sleep(Duration::from_millis(50)).await;
1608 });
1609
1610 let config = SocketConfig {
1611 url: format!("127.0.0.1:{port}"), mode: Mode::Plain,
1613 suffix: b"\r\n".to_vec(),
1614 message_handler: None,
1615 heartbeat: None,
1616 reconnect_timeout_ms: Some(1_000),
1617 reconnect_delay_initial_ms: Some(50),
1618 reconnect_delay_max_ms: Some(100),
1619 reconnect_backoff_factor: Some(1.0),
1620 reconnect_jitter_ms: Some(0),
1621 connection_max_retries: Some(1),
1622 reconnect_max_attempts: None,
1623 certs_dir: None,
1624 };
1625
1626 let client = SocketClient::connect(config, None, None, None).await;
1628 assert!(
1629 client.is_ok(),
1630 "Client should connect with raw socket address format"
1631 );
1632
1633 if let Ok(client) = client {
1634 client.close().await;
1635 }
1636 server.abort();
1637 }
1638
1639 #[rstest]
1640 #[tokio::test]
1641 async fn test_url_parsing_with_scheme() {
1642 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1644 let port = listener.local_addr().unwrap().port();
1645
1646 let server = task::spawn(async move {
1647 if let Ok((sock, _)) = listener.accept().await {
1648 drop(sock);
1649 }
1650 sleep(Duration::from_millis(50)).await;
1651 });
1652
1653 let config = SocketConfig {
1654 url: format!("ws://127.0.0.1:{port}"), mode: Mode::Plain,
1656 suffix: b"\r\n".to_vec(),
1657 message_handler: None,
1658 heartbeat: None,
1659 reconnect_timeout_ms: Some(1_000),
1660 reconnect_delay_initial_ms: Some(50),
1661 reconnect_delay_max_ms: Some(100),
1662 reconnect_backoff_factor: Some(1.0),
1663 reconnect_jitter_ms: Some(0),
1664 connection_max_retries: Some(1),
1665 reconnect_max_attempts: None,
1666 certs_dir: None,
1667 };
1668
1669 let client = SocketClient::connect(config, None, None, None).await;
1671 assert!(
1672 client.is_ok(),
1673 "Client should connect with URL scheme format"
1674 );
1675
1676 if let Ok(client) = client {
1677 client.close().await;
1678 }
1679 server.abort();
1680 }
1681
1682 #[rstest]
1683 fn test_parse_socket_url_ipv6_with_zone() {
1684 let (socket_addr, request_url) =
1686 SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1687 assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1688 assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1689
1690 let (socket_addr, request_url) =
1692 SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1693 assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1694 assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1695 }
1696
1697 #[rstest]
1698 #[tokio::test]
1699 async fn test_ipv6_loopback_connection() {
1700 if TcpListener::bind("[::1]:0").await.is_err() {
1703 eprintln!("IPv6 not available, skipping test");
1704 return;
1705 }
1706
1707 let listener = TcpListener::bind("[::1]:0").await.unwrap();
1708 let port = listener.local_addr().unwrap().port();
1709
1710 let server = task::spawn(async move {
1711 if let Ok((mut sock, _)) = listener.accept().await {
1712 let mut buf = vec![0u8; 1024];
1713 if let Ok(n) = sock.read(&mut buf).await {
1714 let _ = sock.write_all(&buf[..n]).await;
1716 }
1717 }
1718 sleep(Duration::from_millis(50)).await;
1719 });
1720
1721 let config = SocketConfig {
1722 url: format!("[::1]:{port}"), mode: Mode::Plain,
1724 suffix: b"\r\n".to_vec(),
1725 message_handler: None,
1726 heartbeat: None,
1727 reconnect_timeout_ms: Some(1_000),
1728 reconnect_delay_initial_ms: Some(50),
1729 reconnect_delay_max_ms: Some(100),
1730 reconnect_backoff_factor: Some(1.0),
1731 reconnect_jitter_ms: Some(0),
1732 connection_max_retries: Some(1),
1733 reconnect_max_attempts: None,
1734 certs_dir: None,
1735 };
1736
1737 let client = SocketClient::connect(config, None, None, None).await;
1738 assert!(
1739 client.is_ok(),
1740 "Client should connect to IPv6 loopback address"
1741 );
1742
1743 if let Ok(client) = client {
1744 client.close().await;
1745 }
1746 server.abort();
1747 }
1748
1749 #[rstest]
1750 #[tokio::test]
1751 async fn test_send_waits_during_reconnection() {
1752 use nautilus_common::testing::wait_until_async;
1754
1755 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1756 let port = listener.local_addr().unwrap().port();
1757
1758 let server = task::spawn(async move {
1759 if let Ok((sock, _)) = listener.accept().await {
1761 drop(sock);
1762 }
1763
1764 sleep(Duration::from_millis(500)).await;
1766
1767 if let Ok((mut sock, _)) = listener.accept().await {
1769 let mut buf = vec![0u8; 1024];
1771 while let Ok(n) = sock.read(&mut buf).await {
1772 if n == 0 {
1773 break;
1774 }
1775 if sock.write_all(&buf[..n]).await.is_err() {
1776 break;
1777 }
1778 }
1779 }
1780 });
1781
1782 let config = SocketConfig {
1783 url: format!("127.0.0.1:{port}"),
1784 mode: Mode::Plain,
1785 suffix: b"\r\n".to_vec(),
1786 message_handler: None,
1787 heartbeat: None,
1788 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
1790 reconnect_delay_max_ms: Some(200),
1791 reconnect_backoff_factor: Some(1.0),
1792 reconnect_jitter_ms: Some(0),
1793 connection_max_retries: Some(1),
1794 reconnect_max_attempts: None,
1795 certs_dir: None,
1796 };
1797
1798 let client = SocketClient::connect(config, None, None, None)
1799 .await
1800 .unwrap();
1801
1802 wait_until_async(
1804 || async { client.is_reconnecting() },
1805 Duration::from_secs(2),
1806 )
1807 .await;
1808
1809 let send_result = tokio::time::timeout(
1811 Duration::from_secs(3),
1812 client.send_bytes(b"test_message".to_vec()),
1813 )
1814 .await;
1815
1816 assert!(
1817 send_result.is_ok() && send_result.unwrap().is_ok(),
1818 "Send should succeed after waiting for reconnection"
1819 );
1820
1821 client.close().await;
1822 server.abort();
1823 }
1824
1825 #[rstest]
1826 #[tokio::test]
1827 async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1828 use nautilus_common::testing::wait_until_async;
1831
1832 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1833 let port = listener.local_addr().unwrap().port();
1834
1835 let server = task::spawn(async move {
1836 if let Ok((sock, _)) = listener.accept().await {
1838 drop(sock);
1839 }
1840 drop(listener);
1842 sleep(Duration::from_secs(60)).await;
1843 });
1844
1845 let config = SocketConfig {
1846 url: format!("127.0.0.1:{port}"),
1847 mode: Mode::Plain,
1848 suffix: b"\r\n".to_vec(),
1849 message_handler: None,
1850 heartbeat: None,
1851 reconnect_timeout_ms: Some(1_000), reconnect_delay_initial_ms: Some(200), reconnect_delay_max_ms: Some(200),
1854 reconnect_backoff_factor: Some(1.0),
1855 reconnect_jitter_ms: Some(0),
1856 connection_max_retries: Some(1),
1857 reconnect_max_attempts: None,
1858 certs_dir: None,
1859 };
1860
1861 let client = SocketClient::connect(config, None, None, None)
1862 .await
1863 .unwrap();
1864
1865 wait_until_async(
1867 || async { client.is_reconnecting() },
1868 Duration::from_secs(3),
1869 )
1870 .await;
1871
1872 let start = std::time::Instant::now();
1875 let send_result = client.send_bytes(b"test".to_vec()).await;
1876 let elapsed = start.elapsed();
1877
1878 assert!(
1879 send_result.is_err(),
1880 "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
1881 );
1882 assert!(
1883 matches!(send_result, Err(crate::error::SendError::Timeout)),
1884 "Send should return Timeout error, was: {send_result:?}"
1885 );
1886 assert!(
1889 elapsed >= Duration::from_millis(900),
1890 "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
1891 );
1892
1893 client.close().await;
1894 server.abort();
1895 }
1896
1897 #[rstest]
1898 #[tokio::test]
1899 async fn test_empty_suffix_rejected() {
1900 let config = SocketConfig {
1901 url: "127.0.0.1:9999".to_string(),
1902 mode: Mode::Plain,
1903 suffix: vec![],
1904 message_handler: None,
1905 heartbeat: None,
1906 reconnect_timeout_ms: None,
1907 reconnect_delay_initial_ms: None,
1908 reconnect_delay_max_ms: None,
1909 reconnect_backoff_factor: None,
1910 reconnect_jitter_ms: None,
1911 reconnect_max_attempts: None,
1912 connection_max_retries: Some(1),
1913 certs_dir: None,
1914 };
1915
1916 let result = SocketClient::connect(config, None, None, None).await;
1917
1918 assert!(
1919 result.is_err(),
1920 "Empty suffix should cause connection to fail"
1921 );
1922 let err_msg = result.unwrap_err().to_string();
1923 assert!(
1924 err_msg.contains("suffix cannot be empty"),
1925 "Error should mention empty suffix, got: {err_msg}"
1926 );
1927 }
1928}