1use std::{
32 collections::VecDeque,
33 fmt::Debug,
34 path::Path,
35 sync::{
36 Arc,
37 atomic::{AtomicU8, Ordering},
38 },
39 time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49 SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52 backoff::ExponentialBackoff,
53 error::SendError,
54 logging::{log_task_aborted, log_task_started, log_task_stopped},
55 mode::ConnectionMode,
56 net::TcpStream,
57 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
/// Interval (ms) at which the read, write and controller loops, and `SocketClient::close`,
/// re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;

/// Pause (ms) used during reconnection (after the writer is swapped) and before tasks are
/// torn down on disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;

/// Upper bound (s) for graceful shutdown steps: shutting down a writer, waiting in `close`,
/// and the controller's disconnect sequence.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

/// Interval (ms) at which `send_bytes` polls while waiting for the client to become ACTIVE.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
65
/// Connection state and background tasks that back a `SocketClient`.
///
/// Owns the read, write and optional heartbeat tasks, together with the settings needed to
/// re-establish the connection.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; `url`, `mode` and `suffix` are reused on reconnect.
    config: SocketConfig,
    /// TLS connector built from `config.certs_dir`, if one was given.
    connector: Option<Connector>,
    /// Task reading from the socket; replaced after each successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Task writing queued `WriterCommand`s to the socket.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the write task's command channel.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task; present only when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection mode (a `ConnectionMode` encoded as `u8`).
    connection_mode: Arc<AtomicU8>,
    /// Time limit for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Message handler given to each spawned read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum consecutive reconnect attempts before closing; `None` means no limit.
    reconnect_max_attempts: Option<u32>,
    /// Consecutive reconnect attempts since the last success (reset to 0 on success).
    reconnect_attempt_count: u32,
}
99
100impl SocketClientInner {
101 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
107 install_cryptographic_provider();
108
109 let SocketConfig {
110 url,
111 mode,
112 heartbeat,
113 suffix,
114 message_handler,
115 reconnect_timeout_ms,
116 reconnect_delay_initial_ms,
117 reconnect_delay_max_ms,
118 reconnect_backoff_factor,
119 reconnect_jitter_ms,
120 connection_max_retries,
121 reconnect_max_attempts,
122 certs_dir,
123 } = &config.clone();
124 let connector = if let Some(dir) = certs_dir {
125 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
126 Some(Connector::Rustls(Arc::new(config)))
127 } else {
128 None
129 };
130
131 const CONNECTION_TIMEOUT_SECS: u64 = 10;
133 let max_retries = connection_max_retries.unwrap_or(5);
134
135 let mut backoff = ExponentialBackoff::new(
136 Duration::from_millis(500),
137 Duration::from_millis(5000),
138 2.0,
139 250,
140 false,
141 )?;
142
143 #[allow(unused_assignments)]
144 let mut last_error = String::new();
145 let mut attempt = 0;
146 let (reader, writer) = loop {
147 attempt += 1;
148
149 match tokio::time::timeout(
150 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
151 Self::tls_connect_with_server(url, *mode, connector.clone()),
152 )
153 .await
154 {
155 Ok(Ok(result)) => {
156 if attempt > 1 {
157 tracing::info!("Socket connection established after {attempt} attempts");
158 }
159 break result;
160 }
161 Ok(Err(e)) => {
162 last_error = e.to_string();
163 tracing::warn!(
164 attempt,
165 max_retries,
166 url = %url,
167 error = %last_error,
168 "Socket connection attempt failed"
169 );
170 }
171 Err(_) => {
172 last_error = format!(
173 "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
174 );
175 tracing::warn!(
176 attempt,
177 max_retries,
178 url = %url,
179 "Socket connection attempt timed out"
180 );
181 }
182 }
183
184 if attempt >= max_retries {
185 anyhow::bail!(
186 "Failed to connect to {} after {} attempts: {}. \
187 If this is a DNS error, check your network configuration and DNS settings.",
188 url,
189 max_retries,
190 if last_error.is_empty() {
191 "unknown error"
192 } else {
193 &last_error
194 }
195 );
196 }
197
198 let delay = backoff.next_duration();
199 tracing::debug!(
200 "Retrying in {delay:?} (attempt {}/{})",
201 attempt + 1,
202 max_retries
203 );
204 tokio::time::sleep(delay).await;
205 };
206
207 tracing::debug!("Connected");
208
209 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
210
211 let read_task = Arc::new(Self::spawn_read_task(
212 connection_mode.clone(),
213 reader,
214 message_handler.clone(),
215 suffix.clone(),
216 ));
217
218 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219
220 let write_task =
221 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
222
223 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
225 Self::spawn_heartbeat_task(
226 connection_mode.clone(),
227 heartbeat.clone(),
228 writer_tx.clone(),
229 )
230 });
231
232 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
233 let backoff = ExponentialBackoff::new(
234 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
235 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
236 reconnect_backoff_factor.unwrap_or(1.5),
237 reconnect_jitter_ms.unwrap_or(100),
238 true, )?;
240
241 Ok(Self {
242 config,
243 connector,
244 read_task,
245 write_task,
246 writer_tx,
247 heartbeat_task,
248 connection_mode,
249 reconnect_timeout,
250 backoff,
251 handler: message_handler.clone(),
252 reconnect_max_attempts: *reconnect_max_attempts,
253 reconnect_attempt_count: 0,
254 })
255 }
256
257 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
267 if url.contains("://") {
268 let parsed = url.parse::<http::Uri>().map_err(|e| {
270 Error::Io(std::io::Error::new(
271 std::io::ErrorKind::InvalidInput,
272 format!("Invalid URL: {e}"),
273 ))
274 })?;
275
276 let host = parsed.host().ok_or_else(|| {
277 Error::Io(std::io::Error::new(
278 std::io::ErrorKind::InvalidInput,
279 "URL missing host",
280 ))
281 })?;
282
283 let port = parsed
284 .port_u16()
285 .unwrap_or_else(|| match parsed.scheme_str() {
286 Some("wss" | "https") => 443,
287 Some("ws" | "http") => 80,
288 _ => match mode {
289 Mode::Tls => 443,
290 Mode::Plain => 80,
291 },
292 });
293
294 Ok((format!("{host}:{port}"), url.to_string()))
295 } else {
296 let scheme = match mode {
299 Mode::Tls => "wss",
300 Mode::Plain => "ws",
301 };
302 Ok((url.to_string(), format!("{scheme}://{url}")))
303 }
304 }
305
306 pub async fn tls_connect_with_server(
316 url: &str,
317 mode: Mode,
318 connector: Option<Connector>,
319 ) -> Result<(TcpReader, TcpWriter), Error> {
320 tracing::debug!("Connecting to {url}");
321
322 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
323 let tcp_result = TcpStream::connect(&socket_addr).await;
324
325 match tcp_result {
326 Ok(stream) => {
327 tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
328 if let Err(e) = stream.set_nodelay(true) {
329 tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
330 }
331 let request = request_url.into_client_request()?;
332 tcp_tls(&request, mode, stream, connector)
333 .await
334 .map(tokio::io::split)
335 }
336 Err(e) => {
337 tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
338 Err(Error::Io(e))
339 }
340 }
341 }
342
/// Performs a single reconnection attempt, bounded by `self.reconnect_timeout`.
///
/// Sequence:
/// 1. Open a new connection.
/// 2. Hand the new writer to the write task and wait for it to report whether its reconnect
///    buffer was drained.
/// 3. Wait `GRACEFUL_SHUTDOWN_DELAY_MS`.
/// 4. Abort the old read task if it is still running.
/// 5. Move the mode from RECONNECT to ACTIVE.
/// 6. Spawn a new read task on the new reader.
///
/// Returns `Ok(())` without finishing the reconnect in two cases:
/// - the mode is DISCONNECT at any of the checkpoints;
/// - the mode is no longer RECONNECT when switching to ACTIVE.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - connecting fails;
/// - the write task cannot be reached or reports a failed drain;
/// - the whole attempt exceeds `self.reconnect_timeout` (`TimedOut`).
async fn reconnect(&mut self) -> Result<(), Error> {
    tracing::debug!("Reconnecting");

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        tracing::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    tokio::time::timeout(self.reconnect_timeout, async {
        // Only `url`, `mode` and `suffix` are needed to re-establish the connection.
        let SocketConfig {
            url,
            mode,
            heartbeat: _,
            suffix,
            message_handler: _,
            reconnect_timeout_ms: _,
            reconnect_delay_initial_ms: _,
            reconnect_backoff_factor: _,
            reconnect_delay_max_ms: _,
            reconnect_jitter_ms: _,
            connection_max_retries: _,
            reconnect_max_attempts: _,
            certs_dir: _,
        } = &self.config;
        let connector = self.connector.clone();
        let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

        // A disconnect requested while connecting takes priority; drop the new connection.
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }
        tracing::debug!("Connected");

        // Hand the new writer to the write task; it replies on `tx` with whether the messages
        // buffered during reconnection were sent successfully.
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            tracing::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                tracing::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // The old read task exits by itself once the mode leaves ACTIVE; abort it in case it
        // is still running so two readers never coexist.
        if !self.read_task.is_finished() {
            self.read_task.abort();
            log_task_aborted("read");
        }

        // Only go ACTIVE if still in RECONNECT: any other state was set by someone else
        // (e.g. a disconnect) and must not be overwritten.
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            tracing::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        // The mode is ACTIVE now, so the new read task will not exit immediately.
        self.read_task = Arc::new(Self::spawn_read_task(
            self.connection_mode.clone(),
            reader,
            self.handler.clone(),
            suffix.clone(),
        ));

        tracing::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
464
465 #[inline]
472 #[must_use]
473 pub fn is_alive(&self) -> bool {
474 !self.read_task.is_finished()
475 }
476
477 #[must_use]
478 fn spawn_read_task(
479 connection_state: Arc<AtomicU8>,
480 mut reader: TcpReader,
481 handler: Option<TcpMessageHandler>,
482 suffix: Vec<u8>,
483 ) -> tokio::task::JoinHandle<()> {
484 log_task_started("read");
485
486 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
488
489 tokio::task::spawn(async move {
490 let mut buf = Vec::new();
491
492 loop {
493 if !ConnectionMode::from_atomic(&connection_state).is_active() {
494 break;
495 }
496
497 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
498 Ok(Ok(0)) => {
500 tracing::debug!("Connection closed by server");
501 break;
502 }
503 Ok(Err(e)) => {
504 tracing::debug!("Connection ended: {e}");
505 break;
506 }
507 Ok(Ok(bytes)) => {
509 tracing::trace!("Received <binary> {bytes} bytes");
510
511 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
513
514 if is_fix && handler.is_some() {
515 if let Some(ref handler) = handler {
517 process_fix_buffer(&mut buf, handler);
518 }
519 } else {
520 while let Some((i, _)) = &buf
522 .windows(suffix.len())
523 .enumerate()
524 .find(|(_, pair)| pair.eq(&suffix))
525 {
526 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
527 data.truncate(data.len() - suffix.len());
528
529 if let Some(ref handler) = handler {
530 handler(&data);
531 }
532 }
533 }
534 }
535 Err(_) => {
536 continue;
538 }
539 }
540 }
541
542 log_task_stopped("read");
543 })
544 }
545
546 async fn drain_reconnect_buffer(
556 buffer: &mut VecDeque<Bytes>,
557 writer: &mut TcpWriter,
558 suffix: &[u8],
559 ) -> bool {
560 if buffer.is_empty() {
561 return false;
562 }
563
564 let initial_buffer_len = buffer.len();
565 tracing::info!(
566 "Sending {} buffered messages after reconnection",
567 initial_buffer_len
568 );
569
570 let mut send_error_occurred = false;
571
572 while let Some(buffered_msg) = buffer.front() {
573 let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
574 combined_msg.extend_from_slice(buffered_msg);
575 combined_msg.extend_from_slice(suffix);
576
577 if let Err(e) = writer.write_all(&combined_msg).await {
578 tracing::error!(
579 "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
580 buffer.len()
581 );
582 send_error_occurred = true;
583 break;
584 }
585
586 buffer.pop_front();
587 }
588
589 if buffer.is_empty() {
590 tracing::info!(
591 "Successfully sent all {} buffered messages",
592 initial_buffer_len
593 );
594 }
595
596 send_error_occurred
597 }
598
599 fn spawn_write_task(
600 connection_state: Arc<AtomicU8>,
601 writer: TcpWriter,
602 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
603 suffix: Vec<u8>,
604 ) -> tokio::task::JoinHandle<()> {
605 log_task_started("write");
606
607 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
609
610 tokio::task::spawn(async move {
611 let mut active_writer = writer;
612 let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
613
614 loop {
615 if matches!(
616 ConnectionMode::from_atomic(&connection_state),
617 ConnectionMode::Disconnect | ConnectionMode::Closed
618 ) {
619 break;
620 }
621
622 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
623 Ok(Some(msg)) => {
624 let mode = ConnectionMode::from_atomic(&connection_state);
626 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
627 break;
628 }
629
630 match msg {
631 WriterCommand::Update(new_writer, tx) => {
632 tracing::debug!("Received new writer");
633
634 tokio::time::sleep(Duration::from_millis(100)).await;
636
637 _ = tokio::time::timeout(
640 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
641 active_writer.shutdown(),
642 )
643 .await;
644
645 active_writer = new_writer;
646 tracing::debug!("Updated writer");
647
648 let send_error = Self::drain_reconnect_buffer(
649 &mut reconnect_buffer,
650 &mut active_writer,
651 &suffix,
652 )
653 .await;
654
655 if let Err(e) = tx.send(!send_error) {
656 tracing::error!(
657 "Failed to report drain status to controller: {e:?}"
658 );
659 }
660 }
661 _ if mode.is_reconnect() => {
662 if let WriterCommand::Send(data) = msg {
663 tracing::debug!(
664 "Buffering message while reconnecting ({} bytes)",
665 data.len()
666 );
667 reconnect_buffer.push_back(data);
668 }
669 continue;
670 }
671 WriterCommand::Send(msg) => {
672 if let Err(e) = active_writer.write_all(&msg).await {
673 tracing::error!("Failed to send message: {e}");
674 tracing::warn!("Writer triggering reconnect");
675 reconnect_buffer.push_back(msg);
676 connection_state
677 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
678 continue;
679 }
680 if let Err(e) = active_writer.write_all(&suffix).await {
681 tracing::error!("Failed to send suffix: {e}");
682 tracing::warn!("Writer triggering reconnect");
683 reconnect_buffer.push_back(msg);
685 connection_state
686 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
687 continue;
688 }
689 }
690 }
691 }
692 Ok(None) => {
693 tracing::debug!("Writer channel closed, terminating writer task");
695 break;
696 }
697 Err(_) => {
698 continue;
700 }
701 }
702 }
703
704 _ = tokio::time::timeout(
707 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
708 active_writer.shutdown(),
709 )
710 .await;
711
712 log_task_stopped("write");
713 })
714 }
715
716 fn spawn_heartbeat_task(
717 connection_state: Arc<AtomicU8>,
718 heartbeat: (u64, Vec<u8>),
719 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
720 ) -> tokio::task::JoinHandle<()> {
721 log_task_started("heartbeat");
722 let (interval_secs, message) = heartbeat;
723
724 tokio::task::spawn(async move {
725 let interval = Duration::from_secs(interval_secs);
726
727 loop {
728 tokio::time::sleep(interval).await;
729
730 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
731 ConnectionMode::Active => {
732 let msg = WriterCommand::Send(message.clone().into());
733
734 match writer_tx.send(msg) {
735 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
736 Err(e) => {
737 tracing::error!("Failed to send heartbeat to writer task: {e}");
738 }
739 }
740 }
741 ConnectionMode::Reconnect => continue,
742 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
743 }
744 }
745
746 log_task_stopped("heartbeat");
747 })
748 }
749}
750
751impl Drop for SocketClientInner {
752 fn drop(&mut self) {
753 self.clean_drop();
755 }
756}
757
758impl CleanDrop for SocketClientInner {
760 fn clean_drop(&mut self) {
761 if !self.read_task.is_finished() {
762 self.read_task.abort();
763 log_task_aborted("read");
764 }
765
766 if !self.write_task.is_finished() {
767 self.write_task.abort();
768 log_task_aborted("write");
769 }
770
771 if let Some(ref handle) = self.heartbeat_task.take()
772 && !handle.is_finished()
773 {
774 handle.abort();
775 log_task_aborted("heartbeat");
776 }
777
778 #[cfg(feature = "python")]
779 {
780 self.config.message_handler = None;
782 }
783 }
784}
785
/// Socket client whose connection is supervised by a background controller task. The
/// controller reconnects when the connection is lost and tears it down on `close`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Controller task that owns the inner client and drives reconnects and disconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller and the I/O tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time `send_bytes` waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Sender for commands to the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
796
797impl Debug for SocketClient {
798 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
799 f.debug_struct(stringify!(SocketClient)).finish()
800 }
801}
802
803impl SocketClient {
804 pub async fn connect(
810 config: SocketConfig,
811 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
812 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
813 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
814 ) -> anyhow::Result<Self> {
815 let inner = SocketClientInner::connect_url(config).await?;
816 let writer_tx = inner.writer_tx.clone();
817 let connection_mode = inner.connection_mode.clone();
818 let reconnect_timeout = inner.reconnect_timeout;
819
820 let controller_task = Self::spawn_controller_task(
821 inner,
822 connection_mode.clone(),
823 post_reconnection,
824 post_disconnection,
825 );
826
827 if let Some(handler) = post_connection {
828 handler();
829 tracing::debug!("Called `post_connection` handler");
830 }
831
832 Ok(Self {
833 controller_task,
834 connection_mode,
835 reconnect_timeout,
836 writer_tx,
837 })
838 }
839
840 #[must_use]
842 pub fn connection_mode(&self) -> ConnectionMode {
843 ConnectionMode::from_atomic(&self.connection_mode)
844 }
845
846 #[inline]
851 #[must_use]
852 pub fn is_active(&self) -> bool {
853 self.connection_mode().is_active()
854 }
855
856 #[inline]
861 #[must_use]
862 pub fn is_reconnecting(&self) -> bool {
863 self.connection_mode().is_reconnect()
864 }
865
866 #[inline]
870 #[must_use]
871 pub fn is_disconnecting(&self) -> bool {
872 self.connection_mode().is_disconnect()
873 }
874
875 #[inline]
881 #[must_use]
882 pub fn is_closed(&self) -> bool {
883 self.connection_mode().is_closed()
884 }
885
886 pub async fn close(&self) {
891 self.connection_mode
892 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
893
894 if let Ok(()) =
895 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
896 while !self.is_closed() {
897 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
898 .await;
899 }
900
901 if !self.controller_task.is_finished() {
902 self.controller_task.abort();
903 log_task_aborted("controller");
904 }
905 })
906 .await
907 {
908 log_task_stopped("controller");
909 } else {
910 tracing::error!("Timeout waiting for controller task to finish");
911 if !self.controller_task.is_finished() {
912 self.controller_task.abort();
913 log_task_aborted("controller");
914 }
915 }
916 }
917
918 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
924 if self.is_closed() || self.is_disconnecting() {
926 return Err(SendError::Closed);
927 }
928
929 let timeout = self.reconnect_timeout;
930 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
931
932 if !self.is_active() {
933 tracing::debug!("Waiting for client to become ACTIVE before sending...");
934
935 let inner = tokio::time::timeout(timeout, async {
936 loop {
937 if self.is_active() {
938 return Ok(());
939 }
940 if matches!(
941 self.connection_mode(),
942 ConnectionMode::Disconnect | ConnectionMode::Closed
943 ) {
944 return Err(());
945 }
946 tokio::time::sleep(check_interval).await;
947 }
948 })
949 .await
950 .map_err(|_| SendError::Timeout)?;
951 inner.map_err(|()| SendError::Closed)?;
952 }
953
954 let msg = WriterCommand::Send(data.into());
955 self.writer_tx
956 .send(msg)
957 .map_err(|e| SendError::BrokenPipe(e.to_string()))
958 }
959
/// Spawns the controller task that supervises the connection for its whole lifetime.
///
/// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task reads the shared connection mode and
/// acts on it:
/// - DISCONNECT: after `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat tasks (the
///   whole step bounded by `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`), calls `post_disconnection`, and
///   leaves the loop.
/// - ACTIVE with a finished read task: switches the mode to RECONNECT.
/// - RECONNECT, with `reconnect_max_attempts` set and reached: sets CLOSED and leaves the loop.
/// - RECONNECT otherwise: calls `SocketClientInner::reconnect`.
///   - On success, resets the backoff and attempt counter, and calls `post_reconnection` if
///     the mode is ACTIVE.
///   - On failure, sleeps for the next backoff duration.
///
/// When the loop ends, the mode is set to CLOSED. `inner` is dropped with the task, and its
/// `Drop` impl aborts any tasks still running.
fn spawn_controller_task(
    mut inner: SocketClientInner,
    connection_mode: Arc<AtomicU8>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        log_task_started("controller");

        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        loop {
            tokio::time::sleep(check_interval).await;
            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                tracing::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    // Short grace period before aborting the tasks.
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if !inner.read_task.is_finished() {
                        inner.read_task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                tracing::debug!("Closed");

                if let Some(ref handler) = post_disconnection {
                    handler();
                    tracing::debug!("Called `post_disconnection` handler");
                }
                // The write task stops on its own once it observes DISCONNECT.
                break;
            }

            // A finished read task on an ACTIVE connection means the connection was lost.
            // compare_exchange avoids overwriting a state another task set concurrently.
            if mode.is_active() && !inner.is_alive() {
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                // Give up once the configured number of consecutive attempts has been made.
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnect_attempt_count >= max_attempts
                {
                    tracing::error!(
                        "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                        max_attempts
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    break;
                }

                inner.reconnect_attempt_count += 1;
                match inner.reconnect().await {
                    Ok(()) => {
                        tracing::debug!("Reconnected successfully");
                        inner.backoff.reset();
                        inner.reconnect_attempt_count = 0;
                        // `reconnect` also returns Ok when aborted by a disconnect, so only
                        // notify when the connection is actually ACTIVE again.
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = post_reconnection {
                                handler();
                                tracing::debug!("Called `post_reconnection` handler");
                            }
                        } else {
                            tracing::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Err(e) => {
                        let duration = inner.backoff.next_duration();
                        tracing::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnect_attempt_count
                        );
                        if !duration.is_zero() {
                            tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                        }
                        tokio::time::sleep(duration).await;
                    }
                }
            }
        }
        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
1077}
1078
1079impl Drop for SocketClient {
1081 fn drop(&mut self) {
1082 if !self.controller_task.is_finished() {
1083 self.controller_task.abort();
1084 log_task_aborted("controller");
1085 }
1086 }
1087}
1088
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")]
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::Python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an ephemeral localhost port and returns the port with the listener.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes each `\r\n`-terminated line back to the client; a `close` line shuts the socket.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        // Release the port so that nothing is listening on it.
        drop(listener);
        wait_until_async(
            || async {
                TcpStream::connect(format!("127.0.0.1:{port}"))
                    .await
                    .is_err()
            },
            Duration::from_secs(2),
        )
        .await;

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(100),
            reconnect_delay_initial_ms: Some(50),
            reconnect_backoff_factor: Some(1.0),
            reconnect_delay_max_ms: Some(50),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat every second.
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        let server_task = task::spawn(async move {
            // First connection: accept, then shut it down after a delay to force a reconnect.
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            // Second connection (the client's reconnect): serve echoes.
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        // Count successful reconnections through the `post_reconnection` callback. Waiting on
        // this, not on `is_active()` (already true right after connecting), verifies that a
        // reconnect really happened.
        let reconnections = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let reconnections_cb = reconnections.clone();
        let post_reconnection: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
            reconnections_cb.fetch_add(1, Ordering::SeqCst);
        });

        let client = SocketClient::connect(config, None, Some(post_reconnection), None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        wait_until_async(
            || async { reconnections.load(Ordering::SeqCst) >= 1 },
            Duration::from_secs(10),
        )
        .await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1404
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod rust_tests {
    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Closing a client while it is reconnecting must leave it in the closed state.
    ///
    /// The server accepts one connection and shuts it down, which drives the client into
    /// reconnect mode; the test then closes the client and asserts `is_closed()`.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_then_close() {
        // Bind to port 0 so the OS assigns a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // `shutdown()` returns a lazy future: it must be awaited for the shutdown to
                // actually happen (merely dropping the future does nothing).
                let _ = sock.shutdown().await;
            }
            sleep(Duration::from_secs(1)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait for the client to observe the closed connection and enter reconnect mode.
        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }

    /// The client must switch to the reconnecting state once its reader stops because the
    /// server dropped the connection.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_state_flips_when_reader_stops() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept a single connection and drop it immediately.
        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        server.abort();
    }

    /// A bare `host:port` is used as the socket address unchanged, and the request URL gets
    /// a `wss://` or `ws://` scheme chosen from `mode`.
    #[rstest]
    fn test_parse_socket_url_raw_address() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:6130");
        assert_eq!(request_url, "wss://example.com:6130");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    /// A URL with an explicit scheme yields `host:port` as the socket address and is kept
    /// unchanged (including any path) as the request URL.
    #[rstest]
    fn test_parse_socket_url_with_scheme() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");
        assert_eq!(request_url, "wss://example.com:443/path");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    /// When the port is omitted, `wss`/`https` default to 443 and `ws`/`http` default to 80.
    #[rstest]
    fn test_parse_socket_url_default_ports() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    /// For an unrecognized scheme, the default port follows `mode`: 443 for TLS, 80 for plain.
    #[rstest]
    fn test_parse_socket_url_unknown_scheme_uses_mode() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    /// Bracketed IPv6 literals keep their brackets in the socket address, with or without
    /// a scheme.
    #[rstest]
    fn test_parse_socket_url_ipv6() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
        assert_eq!(request_url, "ws://[::1]:8080");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
    }

    /// A real connection succeeds when the URL is a raw `host:port` with no scheme.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_raw_socket_address() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        // `expect` both asserts success and reports the underlying error on failure.
        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client should connect with raw socket address format");

        client.close().await;
        server.abort();
    }

    /// A real connection succeeds when the URL carries an explicit `ws://` scheme.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_with_scheme() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = SocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client should connect with URL scheme format");

        client.close().await;
        server.abort();
    }

    /// IPv6 literals with a zone identifier (e.g. `%eth0`) are preserved verbatim in both
    /// the socket address and the request URL.
    #[rstest]
    fn test_parse_socket_url_ipv6_with_zone() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
    }

    /// The client can connect to the IPv6 loopback address `[::1]:port`.
    /// The test is skipped (returns early) when IPv6 loopback is unavailable on the host.
    #[rstest]
    #[tokio::test]
    async fn test_ipv6_loopback_connection() {
        // Bind once: success both proves IPv6 is available and yields the listener to use.
        let Ok(listener) = TcpListener::bind("[::1]:0").await else {
            eprintln!("IPv6 not available, skipping test");
            return;
        };
        let port = listener.local_addr().unwrap().port();

        // Echo back whatever a single read returns.
        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = vec![0u8; 1024];
                if let Ok(n) = sock.read(&mut buf).await {
                    let _ = sock.write_all(&buf[..n]).await;
                }
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = SocketConfig {
            url: format!("[::1]:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client should connect to IPv6 loopback address");

        client.close().await;
        server.abort();
    }

    /// A send issued while the client is reconnecting should wait for the reconnection to
    /// complete and then succeed, rather than failing immediately.
    #[rstest]
    #[tokio::test]
    async fn test_send_waits_during_reconnection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Drop the first connection immediately to force a reconnect.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }

            // Wait before accepting the client's reconnection.
            sleep(Duration::from_millis(500)).await;

            // Echo everything back on the new connection until EOF or a write error.
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = vec![0u8; 1024];
                while let Ok(n) = sock.read(&mut buf).await {
                    if n == 0 {
                        break;
                    }
                    if sock.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(100),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Wait until the client has noticed the dropped connection.
        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        // Send while reconnecting; the outer timeout bounds how long the test can hang.
        let send_result = tokio::time::timeout(
            Duration::from_secs(3),
            client.send_bytes(b"test_message".to_vec()),
        )
        .await;

        assert!(
            matches!(send_result, Ok(Ok(_))),
            "Send should succeed after waiting for reconnection, was: {send_result:?}"
        );

        client.close().await;
        server.abort();
    }

    /// When the client cannot reconnect, `send_bytes` should give up with
    /// `SendError::Timeout` only after the configured `reconnect_timeout_ms` (1s here).
    #[rstest]
    #[tokio::test]
    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Drop the only connection, then close the listener so every reconnect attempt
            // fails and the client stays in RECONNECT.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            drop(listener);
            sleep(Duration::from_secs(60)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(200),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(3),
        )
        .await;

        let start = std::time::Instant::now();
        let send_result = client.send_bytes(b"test".to_vec()).await;
        let elapsed = start.elapsed();

        assert!(
            send_result.is_err(),
            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
        );
        assert!(
            matches!(send_result, Err(crate::error::SendError::Timeout)),
            "Send should return Timeout error, was: {send_result:?}"
        );
        // Leave 100ms of slack below the configured 1s for timer/scheduling jitter.
        assert!(
            elapsed >= Duration::from_millis(900),
            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
        );

        client.close().await;
        server.abort();
    }
}