1use std::{
32 collections::VecDeque,
33 fmt::Debug,
34 path::Path,
35 sync::{
36 Arc,
37 atomic::{AtomicU8, Ordering},
38 },
39 time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49 SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52 backoff::ExponentialBackoff,
53 error::SendError,
54 logging::{log_task_aborted, log_task_started, log_task_stopped},
55 mode::ConnectionMode,
56 net::TcpStream,
57 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
/// Polling interval (milliseconds) used by the read, write and controller loops,
/// and by `SocketClient::close`, between checks of the shared connection state.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) applied during reconnect, after the writer has been swapped,
/// and before the controller aborts tasks on disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on writer shutdown, on the controller's disconnect sequence,
/// and on how long `SocketClient::close` waits for the CLOSED state.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Sleep (milliseconds) between state checks while `SocketClient::send_bytes`
/// waits for the client to become ACTIVE.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
65
/// Internal connection state owned by the controller task: the spawned I/O tasks,
/// the channel feeding the writer, and the settings used when reconnecting.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; `reconnect` reads the URL,
    /// mode and suffix from it.
    config: SocketConfig,
    /// TLS connector built from `certs_dir` when one is configured, otherwise `None`.
    connector: Option<Connector>,
    /// Handle of the task reading from the socket; replaced on every successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the task writing to the socket.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender half of the command channel consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection state (a `ConnectionMode` stored as `u8`) shared with all tasks.
    connection_mode: Arc<AtomicU8>,
    /// Maximum time a single reconnect attempt may take.
    reconnect_timeout: Duration,
    /// Backoff applied by the controller between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Message handler passed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum number of consecutive reconnect attempts; `None` means no limit is enforced.
    reconnect_max_attempts: Option<u32>,
    /// Number of reconnect attempts made since the last successful (re)connection.
    reconnect_attempt_count: u32,
}
99
100impl SocketClientInner {
101 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
107 install_cryptographic_provider();
108
109 let SocketConfig {
110 url,
111 mode,
112 heartbeat,
113 suffix,
114 message_handler,
115 reconnect_timeout_ms,
116 reconnect_delay_initial_ms,
117 reconnect_delay_max_ms,
118 reconnect_backoff_factor,
119 reconnect_jitter_ms,
120 connection_max_retries,
121 reconnect_max_attempts,
122 certs_dir,
123 } = &config.clone();
124 let connector = if let Some(dir) = certs_dir {
125 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
126 Some(Connector::Rustls(Arc::new(config)))
127 } else {
128 None
129 };
130
131 const CONNECTION_TIMEOUT_SECS: u64 = 10;
133 let max_retries = connection_max_retries.unwrap_or(5);
134
135 let mut backoff = ExponentialBackoff::new(
136 Duration::from_millis(500),
137 Duration::from_millis(5000),
138 2.0,
139 250,
140 false,
141 )?;
142
143 #[allow(unused_assignments)]
144 let mut last_error = String::new();
145 let mut attempt = 0;
146 let (reader, writer) = loop {
147 attempt += 1;
148
149 match tokio::time::timeout(
150 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
151 Self::tls_connect_with_server(url, *mode, connector.clone()),
152 )
153 .await
154 {
155 Ok(Ok(result)) => {
156 if attempt > 1 {
157 tracing::info!("Socket connection established after {attempt} attempts");
158 }
159 break result;
160 }
161 Ok(Err(e)) => {
162 last_error = e.to_string();
163 tracing::warn!(
164 attempt,
165 max_retries,
166 url = %url,
167 error = %last_error,
168 "Socket connection attempt failed"
169 );
170 }
171 Err(_) => {
172 last_error = format!(
173 "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
174 );
175 tracing::warn!(
176 attempt,
177 max_retries,
178 url = %url,
179 "Socket connection attempt timed out"
180 );
181 }
182 }
183
184 if attempt >= max_retries {
185 anyhow::bail!(
186 "Failed to connect to {} after {} attempts: {}. \
187 If this is a DNS error, check your network configuration and DNS settings.",
188 url,
189 max_retries,
190 if last_error.is_empty() {
191 "unknown error"
192 } else {
193 &last_error
194 }
195 );
196 }
197
198 let delay = backoff.next_duration();
199 tracing::debug!(
200 "Retrying in {delay:?} (attempt {}/{})",
201 attempt + 1,
202 max_retries
203 );
204 tokio::time::sleep(delay).await;
205 };
206
207 tracing::debug!("Connected");
208
209 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
210
211 let read_task = Arc::new(Self::spawn_read_task(
212 connection_mode.clone(),
213 reader,
214 message_handler.clone(),
215 suffix.clone(),
216 ));
217
218 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219
220 let write_task =
221 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
222
223 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
225 Self::spawn_heartbeat_task(
226 connection_mode.clone(),
227 heartbeat.clone(),
228 writer_tx.clone(),
229 )
230 });
231
232 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
233 let backoff = ExponentialBackoff::new(
234 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
235 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
236 reconnect_backoff_factor.unwrap_or(1.5),
237 reconnect_jitter_ms.unwrap_or(100),
238 true, )?;
240
241 Ok(Self {
242 config,
243 connector,
244 read_task,
245 write_task,
246 writer_tx,
247 heartbeat_task,
248 connection_mode,
249 reconnect_timeout,
250 backoff,
251 handler: message_handler.clone(),
252 reconnect_max_attempts: *reconnect_max_attempts,
253 reconnect_attempt_count: 0,
254 })
255 }
256
257 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
267 if url.contains("://") {
268 let parsed = url.parse::<http::Uri>().map_err(|e| {
270 Error::Io(std::io::Error::new(
271 std::io::ErrorKind::InvalidInput,
272 format!("Invalid URL: {e}"),
273 ))
274 })?;
275
276 let host = parsed.host().ok_or_else(|| {
277 Error::Io(std::io::Error::new(
278 std::io::ErrorKind::InvalidInput,
279 "URL missing host",
280 ))
281 })?;
282
283 let port = parsed
284 .port_u16()
285 .unwrap_or_else(|| match parsed.scheme_str() {
286 Some("wss" | "https") => 443,
287 Some("ws" | "http") => 80,
288 _ => match mode {
289 Mode::Tls => 443,
290 Mode::Plain => 80,
291 },
292 });
293
294 Ok((format!("{host}:{port}"), url.to_string()))
295 } else {
296 let scheme = match mode {
299 Mode::Tls => "wss",
300 Mode::Plain => "ws",
301 };
302 Ok((url.to_string(), format!("{scheme}://{url}")))
303 }
304 }
305
306 pub async fn tls_connect_with_server(
316 url: &str,
317 mode: Mode,
318 connector: Option<Connector>,
319 ) -> Result<(TcpReader, TcpWriter), Error> {
320 tracing::debug!("Connecting to {url}");
321
322 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
323 let tcp_result = TcpStream::connect(&socket_addr).await;
324
325 match tcp_result {
326 Ok(stream) => {
327 tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
328 if let Err(e) = stream.set_nodelay(true) {
329 tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
330 }
331 let request = request_url.into_client_request()?;
332 tcp_tls(&request, mode, stream, connector)
333 .await
334 .map(tokio::io::split)
335 }
336 Err(e) => {
337 tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
338 Err(Error::Io(e))
339 }
340 }
341 }
342
/// Re-establishes the connection and hands the new stream halves to the
/// existing tasks, bounded by `self.reconnect_timeout`.
///
/// The new writer is sent to the write task, which reports whether it could
/// flush its reconnect buffer; the old read task is then aborted and a new one
/// is spawned on the new reader. The state is flipped from RECONNECT back to
/// ACTIVE just before the new read task is spawned.
///
/// Returns `Ok(())` without completing the reconnect if a disconnect is
/// observed at any checkpoint, or if the state is no longer RECONNECT when the
/// final transition is attempted.
///
/// # Errors
///
/// Returns an error if connecting fails, the writer command channel is closed,
/// the write task fails to drain its buffer or drops the response channel, or
/// the whole operation exceeds `self.reconnect_timeout`.
async fn reconnect(&mut self) -> Result<(), Error> {
    tracing::debug!("Reconnecting");

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        tracing::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    tokio::time::timeout(self.reconnect_timeout, async {
        // Only the URL, mode and suffix are needed here
        let SocketConfig {
            url,
            mode,
            heartbeat: _,
            suffix,
            message_handler: _,
            reconnect_timeout_ms: _,
            reconnect_delay_initial_ms: _,
            reconnect_backoff_factor: _,
            reconnect_delay_max_ms: _,
            reconnect_jitter_ms: _,
            connection_max_retries: _,
            reconnect_max_attempts: _,
            certs_dir: _,
        } = &self.config;
        let connector = self.connector.clone();
        let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

        // Checkpoint: a disconnect may have been requested while connecting
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }
        tracing::debug!("Connected");

        // Hand the new writer to the write task and wait for its drain report
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            tracing::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                tracing::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        // Checkpoint: a disconnect may have been requested during the delay
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // Stop the old reader before replacing it
        if !self.read_task.is_finished() {
            self.read_task.abort();
            log_task_aborted("read");
        }

        // Only go back to ACTIVE if the state is still RECONNECT; any other
        // state means something else changed it in the meantime
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            tracing::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        self.read_task = Arc::new(Self::spawn_read_task(
            self.connection_mode.clone(),
            reader,
            self.handler.clone(),
            suffix.clone(),
        ));

        tracing::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    // The outer `Err` is the timeout elapsing; the inner result is propagated by `?`
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
464
465 #[inline]
472 #[must_use]
473 pub fn is_alive(&self) -> bool {
474 !self.read_task.is_finished()
475 }
476
477 #[must_use]
478 fn spawn_read_task(
479 connection_state: Arc<AtomicU8>,
480 mut reader: TcpReader,
481 handler: Option<TcpMessageHandler>,
482 suffix: Vec<u8>,
483 ) -> tokio::task::JoinHandle<()> {
484 log_task_started("read");
485
486 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
488
489 tokio::task::spawn(async move {
490 let mut buf = Vec::new();
491
492 loop {
493 if !ConnectionMode::from_atomic(&connection_state).is_active() {
494 break;
495 }
496
497 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
498 Ok(Ok(0)) => {
500 tracing::debug!("Connection closed by server");
501 break;
502 }
503 Ok(Err(e)) => {
504 tracing::debug!("Connection ended: {e}");
505 break;
506 }
507 Ok(Ok(bytes)) => {
509 tracing::trace!("Received <binary> {bytes} bytes");
510
511 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
513
514 if is_fix && handler.is_some() {
515 if let Some(ref handler) = handler {
517 process_fix_buffer(&mut buf, handler);
518 }
519 } else {
520 while let Some((i, _)) = &buf
522 .windows(suffix.len())
523 .enumerate()
524 .find(|(_, pair)| pair.eq(&suffix))
525 {
526 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
527 data.truncate(data.len() - suffix.len());
528
529 if let Some(ref handler) = handler {
530 handler(&data);
531 }
532 }
533 }
534 }
535 Err(_) => {
536 continue;
538 }
539 }
540 }
541
542 log_task_stopped("read");
543 })
544 }
545
546 async fn drain_reconnect_buffer(
556 buffer: &mut VecDeque<Bytes>,
557 writer: &mut TcpWriter,
558 suffix: &[u8],
559 ) -> bool {
560 if buffer.is_empty() {
561 return false;
562 }
563
564 let initial_buffer_len = buffer.len();
565 tracing::info!(
566 "Sending {} buffered messages after reconnection",
567 initial_buffer_len
568 );
569
570 let mut send_error_occurred = false;
571
572 while let Some(buffered_msg) = buffer.front() {
573 let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
574 combined_msg.extend_from_slice(buffered_msg);
575 combined_msg.extend_from_slice(suffix);
576
577 if let Err(e) = writer.write_all(&combined_msg).await {
578 tracing::error!(
579 "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
580 buffer.len()
581 );
582 send_error_occurred = true;
583 break;
584 }
585
586 buffer.pop_front();
587 }
588
589 if buffer.is_empty() {
590 tracing::info!(
591 "Successfully sent all {} buffered messages",
592 initial_buffer_len
593 );
594 }
595
596 send_error_occurred
597 }
598
599 fn spawn_write_task(
600 connection_state: Arc<AtomicU8>,
601 writer: TcpWriter,
602 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
603 suffix: Vec<u8>,
604 ) -> tokio::task::JoinHandle<()> {
605 log_task_started("write");
606
607 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
609
610 tokio::task::spawn(async move {
611 let mut active_writer = writer;
612 let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
613
614 loop {
615 if matches!(
616 ConnectionMode::from_atomic(&connection_state),
617 ConnectionMode::Disconnect | ConnectionMode::Closed
618 ) {
619 break;
620 }
621
622 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
623 Ok(Some(msg)) => {
624 let mode = ConnectionMode::from_atomic(&connection_state);
626 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
627 break;
628 }
629
630 match msg {
631 WriterCommand::Update(new_writer, tx) => {
632 tracing::debug!("Received new writer");
633
634 tokio::time::sleep(Duration::from_millis(100)).await;
636
637 _ = tokio::time::timeout(
640 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
641 active_writer.shutdown(),
642 )
643 .await;
644
645 active_writer = new_writer;
646 tracing::debug!("Updated writer");
647
648 let send_error = Self::drain_reconnect_buffer(
649 &mut reconnect_buffer,
650 &mut active_writer,
651 &suffix,
652 )
653 .await;
654
655 if let Err(e) = tx.send(!send_error) {
656 tracing::error!(
657 "Failed to report drain status to controller: {e:?}"
658 );
659 }
660 }
661 _ if mode.is_reconnect() => {
662 if let WriterCommand::Send(data) = msg {
663 tracing::debug!(
664 "Buffering message while reconnecting ({} bytes)",
665 data.len()
666 );
667 reconnect_buffer.push_back(data);
668 }
669 continue;
670 }
671 WriterCommand::Send(msg) => {
672 if let Err(e) = active_writer.write_all(&msg).await {
673 tracing::error!("Failed to send message: {e}");
674 tracing::warn!("Writer triggering reconnect");
675 reconnect_buffer.push_back(msg);
676 connection_state
677 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
678 continue;
679 }
680 if let Err(e) = active_writer.write_all(&suffix).await {
681 tracing::error!("Failed to send suffix: {e}");
682 tracing::warn!("Writer triggering reconnect");
683 reconnect_buffer.push_back(msg);
685 connection_state
686 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
687 continue;
688 }
689 }
690 }
691 }
692 Ok(None) => {
693 tracing::debug!("Writer channel closed, terminating writer task");
695 break;
696 }
697 Err(_) => {
698 continue;
700 }
701 }
702 }
703
704 _ = tokio::time::timeout(
707 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
708 active_writer.shutdown(),
709 )
710 .await;
711
712 log_task_stopped("write");
713 })
714 }
715
716 fn spawn_heartbeat_task(
717 connection_state: Arc<AtomicU8>,
718 heartbeat: (u64, Vec<u8>),
719 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
720 ) -> tokio::task::JoinHandle<()> {
721 log_task_started("heartbeat");
722 let (interval_secs, message) = heartbeat;
723
724 tokio::task::spawn(async move {
725 let interval = Duration::from_secs(interval_secs);
726
727 loop {
728 tokio::time::sleep(interval).await;
729
730 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
731 ConnectionMode::Active => {
732 let msg = WriterCommand::Send(message.clone().into());
733
734 match writer_tx.send(msg) {
735 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
736 Err(e) => {
737 tracing::error!("Failed to send heartbeat to writer task: {e}");
738 }
739 }
740 }
741 ConnectionMode::Reconnect => continue,
742 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
743 }
744 }
745
746 log_task_stopped("heartbeat");
747 })
748 }
749}
750
751impl Drop for SocketClientInner {
752 fn drop(&mut self) {
753 self.clean_drop();
755 }
756}
757
758impl CleanDrop for SocketClientInner {
760 fn clean_drop(&mut self) {
761 if !self.read_task.is_finished() {
762 self.read_task.abort();
763 log_task_aborted("read");
764 }
765
766 if !self.write_task.is_finished() {
767 self.write_task.abort();
768 log_task_aborted("write");
769 }
770
771 if let Some(ref handle) = self.heartbeat_task.take()
772 && !handle.is_finished()
773 {
774 handle.abort();
775 log_task_aborted("heartbeat");
776 }
777
778 #[cfg(feature = "python")]
779 {
780 self.config.message_handler = None;
782 }
783 }
784}
785
/// Public handle to a socket connection managed by a background controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task that owns the inner client and drives
    /// reconnects and disconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection state (a `ConnectionMode` stored as `u8`) shared with the inner tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Reconnect timeout; also bounds how long `send_bytes` waits for the ACTIVE state.
    pub(crate) reconnect_timeout: Duration,
    /// Sender half of the command channel consumed by the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
796
797impl Debug for SocketClient {
798 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
799 f.debug_struct(stringify!(SocketClient)).finish()
800 }
801}
802
803impl SocketClient {
804 pub async fn connect(
810 config: SocketConfig,
811 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
812 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
813 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
814 ) -> anyhow::Result<Self> {
815 let inner = SocketClientInner::connect_url(config).await?;
816 let writer_tx = inner.writer_tx.clone();
817 let connection_mode = inner.connection_mode.clone();
818 let reconnect_timeout = inner.reconnect_timeout;
819
820 let controller_task = Self::spawn_controller_task(
821 inner,
822 connection_mode.clone(),
823 post_reconnection,
824 post_disconnection,
825 );
826
827 if let Some(handler) = post_connection {
828 handler();
829 tracing::debug!("Called `post_connection` handler");
830 }
831
832 Ok(Self {
833 controller_task,
834 connection_mode,
835 reconnect_timeout,
836 writer_tx,
837 })
838 }
839
840 #[must_use]
842 pub fn connection_mode(&self) -> ConnectionMode {
843 ConnectionMode::from_atomic(&self.connection_mode)
844 }
845
846 #[inline]
851 #[must_use]
852 pub fn is_active(&self) -> bool {
853 self.connection_mode().is_active()
854 }
855
856 #[inline]
861 #[must_use]
862 pub fn is_reconnecting(&self) -> bool {
863 self.connection_mode().is_reconnect()
864 }
865
866 #[inline]
870 #[must_use]
871 pub fn is_disconnecting(&self) -> bool {
872 self.connection_mode().is_disconnect()
873 }
874
875 #[inline]
881 #[must_use]
882 pub fn is_closed(&self) -> bool {
883 self.connection_mode().is_closed()
884 }
885
886 pub async fn close(&self) {
891 self.connection_mode
892 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
893
894 if let Ok(()) =
895 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
896 while !self.is_closed() {
897 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
898 .await;
899 }
900
901 if !self.controller_task.is_finished() {
902 self.controller_task.abort();
903 log_task_aborted("controller");
904 }
905 })
906 .await
907 {
908 log_task_stopped("controller");
909 } else {
910 tracing::error!("Timeout waiting for controller task to finish");
911 if !self.controller_task.is_finished() {
912 self.controller_task.abort();
913 log_task_aborted("controller");
914 }
915 }
916 }
917
918 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
924 if self.is_closed() || self.is_disconnecting() {
926 return Err(SendError::Closed);
927 }
928
929 let timeout = self.reconnect_timeout;
930 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
931
932 if !self.is_active() {
933 tracing::debug!("Waiting for client to become ACTIVE before sending...");
934
935 let inner = tokio::time::timeout(timeout, async {
936 loop {
937 if self.is_active() {
938 return Ok(());
939 }
940 if matches!(
941 self.connection_mode(),
942 ConnectionMode::Disconnect | ConnectionMode::Closed
943 ) {
944 return Err(());
945 }
946 tokio::time::sleep(check_interval).await;
947 }
948 })
949 .await
950 .map_err(|_| SendError::Timeout)?;
951 inner.map_err(|()| SendError::Closed)?;
952 }
953
954 let msg = WriterCommand::Send(data.into());
955 self.writer_tx
956 .send(msg)
957 .map_err(|e| SendError::BrokenPipe(e.to_string()))
958 }
959
/// Spawns the controller task that supervises the connection.
///
/// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task inspects the shared
/// state and:
///
/// - on DISCONNECT: aborts the read and heartbeat tasks, calls
///   `post_disconnection` and exits;
/// - on ACTIVE with a finished read task: switches the state to RECONNECT;
/// - on RECONNECT: attempts a reconnect, backing off after failures, and
///   exits once `reconnect_max_attempts` has been reached.
///
/// The state is set to CLOSED when the task exits its loop.
fn spawn_controller_task(
    mut inner: SocketClientInner,
    connection_mode: Arc<AtomicU8>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        log_task_started("controller");

        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        loop {
            tokio::time::sleep(check_interval).await;
            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                tracing::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    // Short pause before aborting the tasks
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if !inner.read_task.is_finished() {
                        inner.read_task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                tracing::debug!("Closed");

                if let Some(ref handler) = post_disconnection {
                    handler();
                    tracing::debug!("Called `post_disconnection` handler");
                }
                break;
            }

            // A finished read task while ACTIVE means the connection was lost
            if mode.is_active() && !inner.is_alive() {
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                // Reload: the state may have been changed by someone else
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                // Give up once the configured number of attempts has been used
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnect_attempt_count >= max_attempts
                {
                    tracing::error!(
                        "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                        max_attempts
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    break;
                }

                inner.reconnect_attempt_count += 1;
                match inner.reconnect().await {
                    Ok(()) => {
                        tracing::debug!("Reconnected successfully");
                        inner.backoff.reset();
                        inner.reconnect_attempt_count = 0;
                        // `reconnect` also returns `Ok` when it was aborted, so
                        // only notify when the client is actually ACTIVE again
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = post_reconnection {
                                handler();
                                tracing::debug!("Called `post_reconnection` handler");
                            }
                        } else {
                            tracing::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Err(e) => {
                        let duration = inner.backoff.next_duration();
                        tracing::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnect_attempt_count
                        );
                        if !duration.is_zero() {
                            tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                        }
                        tokio::time::sleep(duration).await;
                    }
                }
            }
        }
        // Reached after either `break` above
        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
1077}
1078
1079impl Drop for SocketClient {
1081 fn drop(&mut self) {
1082 if !self.controller_task.is_finished() {
1083 self.controller_task.abort();
1084 log_task_aborted("controller");
1085 }
1086 }
1087}
1088
1089#[cfg(test)]
1090#[cfg(feature = "python")]
1091#[cfg(target_os = "linux")] mod tests {
1093 use nautilus_common::testing::wait_until_async;
1094 use pyo3::Python;
1095 use tokio::{
1096 io::{AsyncReadExt, AsyncWriteExt},
1097 net::{TcpListener, TcpStream},
1098 sync::Mutex,
1099 task,
1100 time::{Duration, sleep},
1101 };
1102
1103 use super::*;
1104
1105 async fn bind_test_server() -> (u16, TcpListener) {
1106 let listener = TcpListener::bind("127.0.0.1:0")
1107 .await
1108 .expect("Failed to bind ephemeral port");
1109 let port = listener.local_addr().unwrap().port();
1110 (port, listener)
1111 }
1112
1113 async fn run_echo_server(mut socket: TcpStream) {
1114 let mut buf = Vec::new();
1115 loop {
1116 match socket.read_buf(&mut buf).await {
1117 Ok(0) => {
1118 break;
1119 }
1120 Ok(_n) => {
1121 while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1122 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1123 line.truncate(line.len() - 2);
1125
1126 if line == b"close" {
1127 let _ = socket.shutdown().await;
1128 return;
1129 }
1130
1131 let mut echo_data = line;
1132 echo_data.extend_from_slice(b"\r\n");
1133 if socket.write_all(&echo_data).await.is_err() {
1134 break;
1135 }
1136 }
1137 }
1138 Err(e) => {
1139 eprintln!("Server read error: {e}");
1140 break;
1141 }
1142 }
1143 }
1144 }
1145
/// Sends a few messages to an echo server, then asks it to close, and checks
/// that the client is not CLOSED once the server task has finished.
#[tokio::test]
async fn test_basic_send_receive() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;
    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();
        run_echo_server(socket).await;
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .expect("Client connect failed unexpectedly");

    client.send_bytes(b"Hello".into()).await.unwrap();
    client.send_bytes(b"World".into()).await.unwrap();

    // Give the server time to process the messages
    sleep(Duration::from_millis(100)).await;

    // The echo server shuts its socket down when it receives "close"
    client.send_bytes(b"close".into()).await.unwrap();
    server_task.await.unwrap();
    assert!(!client.is_closed());
}
1186
/// Connecting to a port nobody listens on must fail once the single allowed
/// connection attempt has been used.
#[tokio::test]
async fn test_reconnect_fail_exhausted() {
    Python::initialize();

    // Reserve a port, then release it so that nothing is listening there
    let (port, listener) = bind_test_server().await;
    drop(listener);
    // Wait until connecting to the released port actually fails
    wait_until_async(
        || async {
            TcpStream::connect(format!("127.0.0.1:{port}"))
                .await
                .is_err()
        },
        Duration::from_secs(2),
    )
    .await;

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(100),
        reconnect_delay_initial_ms: Some(50),
        reconnect_backoff_factor: Some(1.0),
        reconnect_delay_max_ms: Some(50),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        certs_dir: None,
    };

    let client_res = SocketClient::connect(config, None, None, None).await;
    assert!(
        client_res.is_err(),
        "Should fail quickly with no server listening"
    );
}
1227
/// Closing the client explicitly must leave it in the CLOSED state.
#[tokio::test]
async fn test_user_disconnect() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;
    // Server that accepts one connection and then idles until aborted
    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();
        let mut buf = [0u8; 1024];
        let _ = socket.try_read(&mut buf);

        loop {
            sleep(Duration::from_secs(1)).await;
        }
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();

    client.close().await;
    assert!(client.is_closed());
    server_task.abort();
}
1267
/// With a 1 second heartbeat configured, the server must have received at
/// least two `ping` lines after 3 seconds.
#[tokio::test]
async fn test_heartbeat() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;
    // Lines received by the server, shared with the assertions below
    let received = Arc::new(Mutex::new(Vec::new()));
    let received2 = received.clone();

    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();

        let mut buf = Vec::new();
        loop {
            match socket.try_read_buf(&mut buf) {
                Ok(0) => break,
                Ok(_) => {
                    // Record every complete `\r\n`-terminated line, without the delimiter
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);
                        received2.lock().await.push(line);
                    }
                }
                Err(_) => {
                    // Non-blocking read failed (e.g. nothing to read yet): retry shortly
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            }
        }
    });

    // (interval in seconds, message)
    let heartbeat = Some((1, b"ping".to_vec()));

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();

    // Long enough for at least two heartbeats at a 1 second interval
    sleep(Duration::from_secs(3)).await;

    {
        let lock = received.lock().await;
        let pings = lock
            .iter()
            .filter(|line| line == &&b"ping".to_vec())
            .count();
        assert!(
            pings >= 2,
            "Expected at least 2 heartbeat pings; got {pings}"
        );
    }

    client.close().await;
    server_task.abort();
}
1338
/// The server drops the first connection and accepts a second one; the
/// client must be able to send once it is active.
#[tokio::test]
async fn test_reconnect_success() {
    Python::initialize();

    let (port, listener) = bind_test_server().await;

    let server_task = task::spawn(async move {
        // First connection: accept, wait, then shut it down
        let (mut socket, _) = listener.accept().await.expect("First accept failed");

        sleep(Duration::from_millis(500)).await;
        let _ = socket.shutdown().await;

        sleep(Duration::from_millis(500)).await;

        // Second connection: accept and echo
        let (socket, _) = listener.accept().await.expect("Second accept failed");
        run_echo_server(socket).await;
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(5_000),
        reconnect_delay_initial_ms: Some(500),
        reconnect_delay_max_ms: Some(5_000),
        reconnect_backoff_factor: Some(2.0),
        reconnect_jitter_ms: Some(50),
        reconnect_max_attempts: None,
        connection_max_retries: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .expect("Client connect failed unexpectedly");

    assert!(client.is_active(), "Client should start as active");

    // Wait (up to 10 seconds) for the client to report ACTIVE
    wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

    client
        .send_bytes(b"TestReconnect".into())
        .await
        .expect("Send failed");

    client.close().await;
    server_task.abort();
}
1399}
1400
// Socket client tests that run without a Python runtime, against real loopback
// listeners. Compiled only for test builds without the `turmoil` feature.
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod rust_tests {
    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    // Builds a plain-TCP config for `url` with the short reconnect settings shared by
    // most tests in this module. Tests needing other timings override the fields.
    fn fast_reconnect_config(url: String) -> SocketConfig {
        SocketConfig {
            url,
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        }
    }

    // Spawns a server task that accepts one connection, drops it straight away, and
    // then lingers for 50 ms before the task (and the listener it owns) ends.
    fn spawn_accept_and_drop_server(listener: TcpListener) -> task::JoinHandle<()> {
        task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        })
    }

    // Closing a client that is in the reconnecting state must leave it closed.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_then_close() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // Await the shutdown; a dropped, un-polled future performs no work.
                let _ = sock.shutdown().await;
            }
            // Keep the listener alive for another second.
            sleep(Duration::from_secs(1)).await;
        });

        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait for the client to report the reconnecting state (2 s limit).
        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }

    // After the server drops the only connection, the client must report that it is
    // reconnecting.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_state_flips_when_reader_stops() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = spawn_accept_and_drop_server(listener);

        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        server.abort();
    }

    // A bare `host:port` is used as the socket address, and the request URL gets a
    // scheme chosen from the mode (`wss` for TLS, `ws` for plain).
    #[rstest]
    fn test_parse_socket_url_raw_address() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:6130");
        assert_eq!(request_url, "wss://example.com:6130");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    // A URL with a scheme yields `host:port` as the socket address and is returned
    // unchanged as the request URL.
    #[rstest]
    fn test_parse_socket_url_with_scheme() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");
        assert_eq!(request_url, "wss://example.com:443/path");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    // When the URL omits the port, `wss`/`https` resolve to 443 and `ws`/`http` to 80.
    #[rstest]
    fn test_parse_socket_url_default_ports() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    // For an unrecognised scheme the default port follows the mode: 443 for TLS,
    // 80 for plain.
    #[rstest]
    fn test_parse_socket_url_unknown_scheme_uses_mode() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    // Bracketed IPv6 literals keep their brackets, with or without a scheme.
    #[rstest]
    fn test_parse_socket_url_ipv6() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
        assert_eq!(request_url, "ws://[::1]:8080");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
    }

    // The client connects when the config URL is a bare `ip:port`.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_raw_socket_address() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = spawn_accept_and_drop_server(listener);

        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));

        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect with raw socket address format"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    // The client connects when the config URL carries a `ws://` scheme.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_with_scheme() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = spawn_accept_and_drop_server(listener);

        let config = fast_reconnect_config(format!("ws://127.0.0.1:{port}"));

        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect with URL scheme format"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    // IPv6 literals with a zone identifier pass through unchanged in both outputs.
    #[rstest]
    fn test_parse_socket_url_ipv6_with_zone() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
    }

    // The client connects to an IPv6 loopback listener. The test returns early when
    // the `[::1]` bind fails.
    #[rstest]
    #[tokio::test]
    async fn test_ipv6_loopback_connection() {
        // Bind once: a failed bind doubles as the "IPv6 unavailable" probe.
        let Ok(listener) = TcpListener::bind("[::1]:0").await else {
            eprintln!("IPv6 not available, skipping test");
            return;
        };
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // Echo back a single read.
                let mut buf = vec![0u8; 1024];
                if let Ok(n) = sock.read(&mut buf).await {
                    let _ = sock.write_all(&buf[..n]).await;
                }
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = fast_reconnect_config(format!("[::1]:{port}"));

        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect to IPv6 loopback address"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    // A send issued while the client is reconnecting must succeed within 3 s. The
    // server drops the first connection and only accepts again 500 ms later.
    #[rstest]
    #[tokio::test]
    async fn test_send_waits_during_reconnection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // First connection: accept and drop immediately.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }

            sleep(Duration::from_millis(500)).await;

            // Second connection: echo until EOF or an I/O error.
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = vec![0u8; 1024];
                while let Ok(n) = sock.read(&mut buf).await {
                    if n == 0 {
                        break;
                    }
                    if sock.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
            }
        });

        let mut config = fast_reconnect_config(format!("127.0.0.1:{port}"));
        config.reconnect_timeout_ms = Some(5_000);
        config.reconnect_delay_initial_ms = Some(100);
        config.reconnect_delay_max_ms = Some(200);

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        let send_result = tokio::time::timeout(
            Duration::from_secs(3),
            client.send_bytes(b"test_message".to_vec()),
        )
        .await;

        // Outer `Ok`: the 3 s guard did not elapse. Inner `Ok`: the send succeeded.
        assert!(
            matches!(send_result, Ok(Ok(_))),
            "Send should succeed after waiting for reconnection"
        );

        client.close().await;
        server.abort();
    }

    // With the listener gone, a send issued while reconnecting must fail with
    // `SendError::Timeout` after roughly the configured 1 s reconnect timeout.
    #[rstest]
    #[tokio::test]
    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            // Close the listener so no further connections are accepted.
            drop(listener);
            sleep(Duration::from_secs(60)).await;
        });

        let mut config = fast_reconnect_config(format!("127.0.0.1:{port}"));
        config.reconnect_timeout_ms = Some(1_000);
        config.reconnect_delay_initial_ms = Some(200);
        config.reconnect_delay_max_ms = Some(200);

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(3),
        )
        .await;

        let start = std::time::Instant::now();
        let send_result = client.send_bytes(b"test".to_vec()).await;
        let elapsed = start.elapsed();

        assert!(
            send_result.is_err(),
            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
        );
        assert!(
            matches!(send_result, Err(crate::error::SendError::Timeout)),
            "Send should return Timeout error, was: {send_result:?}"
        );
        // 100 ms of slack below the configured 1 s timeout.
        assert!(
            elapsed >= Duration::from_millis(900),
            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
        );

        client.close().await;
        server.abort();
    }
}