1use std::{
32 fmt::Debug,
33 path::Path,
34 sync::{
35 Arc,
36 atomic::{AtomicU8, Ordering},
37 },
38 time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_core::CleanDrop;
43use nautilus_cryptography::providers::install_cryptographic_provider;
44use tokio::{
45 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
46 net::TcpStream,
47};
48use tokio_tungstenite::{
49 MaybeTlsStream,
50 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
51};
52
53use crate::{
54 backoff::ExponentialBackoff,
55 error::SendError,
56 fix::process_fix_buffer,
57 logging::{log_task_aborted, log_task_started, log_task_stopped},
58 mode::ConnectionMode,
59 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
60};
61
/// Polling interval (milliseconds) at which the background tasks and `close` re-check
/// the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (milliseconds) taken before tasks are aborted during a disconnect, and after a
/// new writer has been handed over during a reconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) on waiting for a writer shutdown or for the client to close.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Polling interval (milliseconds) used by `send_bytes` while waiting for the client
/// to become active.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
67
/// Write half of the (optionally TLS-wrapped) TCP stream.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of the (optionally TLS-wrapped) TCP stream.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Shared callback invoked with the bytes of each received message.
pub type TcpMessageHandler = Arc<dyn Fn(&[u8]) + Send + Sync>;
71
/// Configuration for a socket client connection.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// Address to connect to: either a raw `host:port` or a full URL containing `://`.
    pub url: String,
    /// Whether the stream is plain TCP or TLS.
    pub mode: Mode,
    /// Byte sequence that delimits received messages and is appended to every sent message.
    pub suffix: Vec<u8>,
    /// Optional callback invoked with each received message.
    pub message_handler: Option<TcpMessageHandler>,
    /// Optional heartbeat as `(interval in seconds, message bytes)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Timeout (milliseconds) for a single reconnection attempt; defaults to 10,000.
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnection backoff delay (milliseconds); defaults to 2,000.
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnection backoff delay (milliseconds); defaults to 30,000.
    pub reconnect_delay_max_ms: Option<u64>,
    /// Multiplicative factor for the reconnection backoff; defaults to 1.5.
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter (milliseconds) passed to the reconnection backoff; defaults to 100.
    pub reconnect_jitter_ms: Option<u64>,
    /// Optional directory from which a TLS configuration is created.
    pub certs_dir: Option<String>,
}
101
102impl Debug for SocketConfig {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct(stringify!(SocketConfig))
105 .field("url", &self.url)
106 .field("mode", &self.mode)
107 .field("suffix", &self.suffix)
108 .field(
109 "message_handler",
110 &self.message_handler.as_ref().map(|_| "<function>"),
111 )
112 .field("heartbeat", &self.heartbeat)
113 .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
114 .field(
115 "reconnect_delay_initial_ms",
116 &self.reconnect_delay_initial_ms,
117 )
118 .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
119 .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
120 .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
121 .field("certs_dir", &self.certs_dir)
122 .finish()
123 }
124}
125
126impl Clone for SocketConfig {
127 fn clone(&self) -> Self {
128 Self {
129 url: self.url.clone(),
130 mode: self.mode,
131 suffix: self.suffix.clone(),
132 message_handler: self.message_handler.clone(),
133 heartbeat: self.heartbeat.clone(),
134 reconnect_timeout_ms: self.reconnect_timeout_ms,
135 reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
136 reconnect_delay_max_ms: self.reconnect_delay_max_ms,
137 reconnect_backoff_factor: self.reconnect_backoff_factor,
138 reconnect_jitter_ms: self.reconnect_jitter_ms,
139 certs_dir: self.certs_dir.clone(),
140 }
141 }
142}
143
/// Commands accepted by the writer task over its channel.
#[derive(Debug)]
pub enum WriterCommand {
    /// Replace the active writer half with a new one (sent after a reconnect).
    Update(TcpWriter),
    /// Write the given bytes, followed by the configured suffix.
    Send(Bytes),
}
152
/// Connection state owned by the controller task: the configuration, the handles of the
/// read/write/heartbeat tasks, and the reconnection settings.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; reused on reconnect.
    config: SocketConfig,
    /// TLS connector built from `certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced on every successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the channel feeding the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared with all background tasks, stored as a `u8`.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff state used between failed reconnection attempts.
    backoff: ExponentialBackoff,
    /// Message handler passed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
}
184
185impl SocketClientInner {
186 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
192 install_cryptographic_provider();
193
194 let SocketConfig {
195 url,
196 mode,
197 heartbeat,
198 suffix,
199 message_handler,
200 reconnect_timeout_ms,
201 reconnect_delay_initial_ms,
202 reconnect_delay_max_ms,
203 reconnect_backoff_factor,
204 reconnect_jitter_ms,
205 certs_dir,
206 } = &config.clone();
207 let connector = if let Some(dir) = certs_dir {
208 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
209 Some(Connector::Rustls(Arc::new(config)))
210 } else {
211 None
212 };
213
214 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
215 tracing::debug!("Connected");
216
217 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
218
219 let read_task = Arc::new(Self::spawn_read_task(
220 connection_mode.clone(),
221 reader,
222 message_handler.clone(),
223 suffix.clone(),
224 ));
225
226 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
227
228 let write_task =
229 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
230
231 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233 Self::spawn_heartbeat_task(
234 connection_mode.clone(),
235 heartbeat.clone(),
236 writer_tx.clone(),
237 )
238 });
239
240 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
241 let backoff = ExponentialBackoff::new(
242 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
243 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
244 reconnect_backoff_factor.unwrap_or(1.5),
245 reconnect_jitter_ms.unwrap_or(100),
246 true, )?;
248
249 Ok(Self {
250 config,
251 connector,
252 read_task,
253 write_task,
254 writer_tx,
255 heartbeat_task,
256 connection_mode,
257 reconnect_timeout,
258 backoff,
259 handler: message_handler.clone(),
260 })
261 }
262
263 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
273 if url.contains("://") {
274 let parsed = url.parse::<http::Uri>().map_err(|e| {
276 Error::Io(std::io::Error::new(
277 std::io::ErrorKind::InvalidInput,
278 format!("Invalid URL: {e}"),
279 ))
280 })?;
281
282 let host = parsed.host().ok_or_else(|| {
283 Error::Io(std::io::Error::new(
284 std::io::ErrorKind::InvalidInput,
285 "URL missing host",
286 ))
287 })?;
288
289 let port = parsed
290 .port_u16()
291 .unwrap_or_else(|| match parsed.scheme_str() {
292 Some("wss") | Some("https") => 443,
293 Some("ws") | Some("http") => 80,
294 _ => match mode {
295 Mode::Tls => 443,
296 Mode::Plain => 80,
297 },
298 });
299
300 Ok((format!("{host}:{port}"), url.to_string()))
301 } else {
302 let scheme = match mode {
305 Mode::Tls => "wss",
306 Mode::Plain => "ws",
307 };
308 Ok((url.to_string(), format!("{scheme}://{url}")))
309 }
310 }
311
312 pub async fn tls_connect_with_server(
322 url: &str,
323 mode: Mode,
324 connector: Option<Connector>,
325 ) -> Result<(TcpReader, TcpWriter), Error> {
326 tracing::debug!("Connecting to {url}");
327
328 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
329 let tcp_result = TcpStream::connect(&socket_addr).await;
330
331 match tcp_result {
332 Ok(stream) => {
333 tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
334 let request = request_url.into_client_request()?;
335 tcp_tls(&request, mode, stream, connector)
336 .await
337 .map(tokio::io::split)
338 }
339 Err(e) => {
340 tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
341 Err(Error::Io(e))
342 }
343 }
344 }
345
    /// Re-establishes the connection using the stored configuration.
    ///
    /// The whole attempt is bounded by `reconnect_timeout`. On success the new writer is
    /// handed to the write task, the old read task is aborted, the mode is flipped from
    /// `Reconnect` back to `Active`, and a fresh read task is spawned.
    ///
    /// Returns `Ok(())` without completing the swap if a disconnect is observed at any
    /// checkpoint, or if the mode is no longer `Reconnect` at the final transition.
    ///
    /// # Errors
    ///
    /// Returns the error from `tls_connect_with_server`, or an `Error::Io` with kind
    /// `TimedOut` when the attempt exceeds `reconnect_timeout`.
    async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Only the url, mode and suffix are needed to rebuild the connection
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                certs_dir: _,
            } = &self.config;
            let connector = self.connector.clone();
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            // Checkpoint: a disconnect may have been requested while connecting
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            tracing::debug!("Connected");

            // Hand the new writer to the write task; a send failure is only logged
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
                tracing::error!("{e}");
            }

            // Pause before touching the old read task
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            // Checkpoint: a disconnect may have been requested during the pause
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Only move Reconnect -> Active; any other current mode means the state
            // changed underneath us and the new read task must not be spawned
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn the replacement read task on the new reader half
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
            ));

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
438
439 #[inline]
446 #[must_use]
447 pub fn is_alive(&self) -> bool {
448 !self.read_task.is_finished()
449 }
450
451 #[must_use]
452 fn spawn_read_task(
453 connection_state: Arc<AtomicU8>,
454 mut reader: TcpReader,
455 handler: Option<TcpMessageHandler>,
456 suffix: Vec<u8>,
457 ) -> tokio::task::JoinHandle<()> {
458 log_task_started("read");
459
460 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
462
463 tokio::task::spawn(async move {
464 let mut buf = Vec::new();
465
466 loop {
467 if !ConnectionMode::from_atomic(&connection_state).is_active() {
468 break;
469 }
470
471 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
472 Ok(Ok(0)) => {
474 tracing::debug!("Connection closed by server");
475 break;
476 }
477 Ok(Err(e)) => {
478 tracing::debug!("Connection ended: {e}");
479 break;
480 }
481 Ok(Ok(bytes)) => {
483 tracing::trace!("Received <binary> {bytes} bytes");
484
485 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
487
488 if is_fix && handler.is_some() {
489 if let Some(ref handler) = handler {
491 process_fix_buffer(&mut buf, handler);
492 }
493 } else {
494 while let Some((i, _)) = &buf
496 .windows(suffix.len())
497 .enumerate()
498 .find(|(_, pair)| pair.eq(&suffix))
499 {
500 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
501 data.truncate(data.len() - suffix.len());
502
503 if let Some(ref handler) = handler {
504 handler(&data);
505 }
506 }
507 }
508 }
509 Err(_) => {
510 continue;
512 }
513 }
514 }
515
516 log_task_stopped("read");
517 })
518 }
519
520 fn spawn_write_task(
521 connection_state: Arc<AtomicU8>,
522 writer: TcpWriter,
523 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
524 suffix: Vec<u8>,
525 ) -> tokio::task::JoinHandle<()> {
526 log_task_started("write");
527
528 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
530
531 tokio::task::spawn(async move {
532 let mut active_writer = writer;
533
534 loop {
535 if matches!(
536 ConnectionMode::from_atomic(&connection_state),
537 ConnectionMode::Disconnect | ConnectionMode::Closed
538 ) {
539 break;
540 }
541
542 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
543 Ok(Some(msg)) => {
544 let mode = ConnectionMode::from_atomic(&connection_state);
546 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
547 break;
548 }
549
550 match msg {
551 WriterCommand::Update(new_writer) => {
552 tracing::debug!("Received new writer");
553
554 tokio::time::sleep(Duration::from_millis(100)).await;
556
557 _ = tokio::time::timeout(
560 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
561 active_writer.shutdown(),
562 )
563 .await;
564
565 active_writer = new_writer;
566 tracing::debug!("Updated writer");
567 }
568 _ if mode.is_reconnect() => {
569 tracing::warn!("Skipping message while reconnecting, {msg:?}");
570 continue;
571 }
572 WriterCommand::Send(msg) => {
573 if let Err(e) = active_writer.write_all(&msg).await {
574 tracing::error!("Failed to send message: {e}");
575 tracing::warn!("Writer triggering reconnect");
577 connection_state
578 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
579 continue;
580 }
581 if let Err(e) = active_writer.write_all(&suffix).await {
582 tracing::error!("Failed to send suffix: {e}");
583 tracing::warn!("Writer triggering reconnect");
585 connection_state
586 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
587 continue;
588 }
589 }
590 }
591 }
592 Ok(None) => {
593 tracing::debug!("Writer channel closed, terminating writer task");
595 break;
596 }
597 Err(_) => {
598 continue;
600 }
601 }
602 }
603
604 _ = tokio::time::timeout(
607 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
608 active_writer.shutdown(),
609 )
610 .await;
611
612 log_task_stopped("write");
613 })
614 }
615
616 fn spawn_heartbeat_task(
617 connection_state: Arc<AtomicU8>,
618 heartbeat: (u64, Vec<u8>),
619 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
620 ) -> tokio::task::JoinHandle<()> {
621 log_task_started("heartbeat");
622 let (interval_secs, message) = heartbeat;
623
624 tokio::task::spawn(async move {
625 let interval = Duration::from_secs(interval_secs);
626
627 loop {
628 tokio::time::sleep(interval).await;
629
630 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
631 ConnectionMode::Active => {
632 let msg = WriterCommand::Send(message.clone().into());
633
634 match writer_tx.send(msg) {
635 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
636 Err(e) => {
637 tracing::error!("Failed to send heartbeat to writer task: {e}");
638 }
639 }
640 }
641 ConnectionMode::Reconnect => continue,
642 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
643 }
644 }
645
646 log_task_stopped("heartbeat");
647 })
648 }
649}
650
651impl Drop for SocketClientInner {
652 fn drop(&mut self) {
653 self.clean_drop();
655 }
656}
657
658impl CleanDrop for SocketClientInner {
660 fn clean_drop(&mut self) {
661 if !self.read_task.is_finished() {
662 self.read_task.abort();
663 log_task_aborted("read");
664 }
665
666 if !self.write_task.is_finished() {
667 self.write_task.abort();
668 log_task_aborted("write");
669 }
670
671 if let Some(ref handle) = self.heartbeat_task.take()
672 && !handle.is_finished()
673 {
674 handle.abort();
675 log_task_aborted("heartbeat");
676 }
677
678 #[cfg(feature = "python")]
679 {
680 self.config.message_handler = None;
682 }
683 }
684}
685
/// Public handle to a socket connection managed by a background controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task that drives reconnects and disconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller and the other background tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time `send_bytes` waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Sender side of the channel feeding the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
696
697impl Debug for SocketClient {
698 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
699 f.debug_struct(stringify!(SocketClient)).finish()
700 }
701}
702
703impl SocketClient {
704 pub async fn connect(
710 config: SocketConfig,
711 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
712 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
713 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
714 ) -> anyhow::Result<Self> {
715 let inner = SocketClientInner::connect_url(config).await?;
716 let writer_tx = inner.writer_tx.clone();
717 let connection_mode = inner.connection_mode.clone();
718 let reconnect_timeout = inner.reconnect_timeout;
719
720 let controller_task = Self::spawn_controller_task(
721 inner,
722 connection_mode.clone(),
723 post_reconnection,
724 post_disconnection,
725 );
726
727 if let Some(handler) = post_connection {
728 handler();
729 tracing::debug!("Called `post_connection` handler");
730 }
731
732 Ok(Self {
733 controller_task,
734 connection_mode,
735 reconnect_timeout,
736 writer_tx,
737 })
738 }
739
740 #[must_use]
742 pub fn connection_mode(&self) -> ConnectionMode {
743 ConnectionMode::from_atomic(&self.connection_mode)
744 }
745
746 #[inline]
751 #[must_use]
752 pub fn is_active(&self) -> bool {
753 self.connection_mode().is_active()
754 }
755
756 #[inline]
761 #[must_use]
762 pub fn is_reconnecting(&self) -> bool {
763 self.connection_mode().is_reconnect()
764 }
765
766 #[inline]
770 #[must_use]
771 pub fn is_disconnecting(&self) -> bool {
772 self.connection_mode().is_disconnect()
773 }
774
775 #[inline]
781 #[must_use]
782 pub fn is_closed(&self) -> bool {
783 self.connection_mode().is_closed()
784 }
785
786 pub async fn close(&self) {
791 self.connection_mode
792 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
793
794 match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
795 while !self.is_closed() {
796 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
797 }
798
799 if !self.controller_task.is_finished() {
800 self.controller_task.abort();
801 log_task_aborted("controller");
802 }
803 })
804 .await
805 {
806 Ok(()) => {
807 log_task_stopped("controller");
808 }
809 Err(_) => {
810 tracing::error!("Timeout waiting for controller task to finish");
811 if !self.controller_task.is_finished() {
812 self.controller_task.abort();
813 log_task_aborted("controller");
814 }
815 }
816 }
817 }
818
819 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
825 if self.is_closed() {
826 return Err(SendError::Closed);
827 }
828
829 let timeout = self.reconnect_timeout;
830 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
831
832 if !self.is_active() {
833 tracing::debug!("Waiting for client to become ACTIVE before sending...");
834
835 let inner = tokio::time::timeout(timeout, async {
836 loop {
837 if self.is_active() {
838 return Ok(());
839 }
840 if matches!(
841 self.connection_mode(),
842 ConnectionMode::Disconnect | ConnectionMode::Closed
843 ) {
844 return Err(());
845 }
846 tokio::time::sleep(check_interval).await;
847 }
848 })
849 .await
850 .map_err(|_| SendError::Timeout)?;
851 inner.map_err(|()| SendError::Closed)?;
852 }
853
854 let msg = WriterCommand::Send(data.into());
855 self.writer_tx
856 .send(msg)
857 .map_err(|e| SendError::BrokenPipe(e.to_string()))
858 }
859
    /// Spawns the controller task that owns `inner` and drives the connection lifecycle.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task:
    /// - on `Disconnect`: aborts the read and heartbeat tasks after a short delay, calls
    ///   `post_disconnection`, and exits the loop;
    /// - on `Active` with a finished read task: transitions to `Reconnect`;
    /// - on `Reconnect`: attempts a reconnect, calling `post_reconnection` on success if
    ///   the client is active again, or sleeping for the next backoff duration on failure.
    ///
    /// When the loop ends the mode is set to `Closed`.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Pause before aborting the remaining tasks
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    if let Some(ref handler) = post_disconnection {
                        handler();
                        tracing::debug!("Called `post_disconnection` handler");
                    }
                    // Leave the loop; the mode is set to `Closed` below
                    break;
                }

                // A finished read task while active means the connection was lost
                if mode.is_active() && !inner.is_alive() {
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read the mode: the exchange may have failed because it changed
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    match inner.reconnect().await {
                        Ok(()) => {
                            tracing::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            // `reconnect` also returns Ok when it aborts early, so only
                            // notify when the client is actually active again
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = post_reconnection {
                                    handler();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
960}
961
962impl Drop for SocketClient {
964 fn drop(&mut self) {
965 if !self.controller_task.is_finished() {
966 self.controller_task.abort();
967 log_task_aborted("controller");
968 }
969 }
970}
971
// Socket client tests that are compiled only with the `python` feature on Linux.
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")]
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::Python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an ephemeral local port and returns the port with the listener.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes each `\r\n`-terminated line back to the client; the line `close` shuts
    /// the socket down and ends the server.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Strip the trailing "\r\n"
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    /// Sends a few messages to an echo server and checks the client is not closed
    /// after the server ends the session.
    #[tokio::test]
    async fn test_basic_send_receive() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Give the server time to process the messages
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// Connecting to a port with no listener must fail.
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        // Drop the listener so nothing is accepting on the port
        drop(listener);

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    /// Closing the client while the server stays connected must end in the closed state.
    #[tokio::test]
    async fn test_user_disconnect() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            // Keep the connection open until the task is aborted
            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    /// With a 1-second heartbeat, the server must see at least two pings within 3 seconds.
    #[tokio::test]
    async fn test_heartbeat() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        // Nothing to read yet (or a read error): retry shortly
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat every 1 second
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Long enough for at least two heartbeat intervals to elapse
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    /// The server drops the first connection and accepts a second one; the client must
    /// still be able to send afterwards.
    // NOTE(review): the client is already active when `wait_until_async` is called, so
    // this may not actually wait for a reconnect to happen; confirm the intent.
    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        let server_task = task::spawn(async move {
            // Accept the first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Hold it briefly, then shut it down
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            // Accept the second connection and echo on it
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1266
1267#[cfg(test)]
1268mod rust_tests {
1269 use rstest::rstest;
1270 use tokio::{
1271 io::{AsyncReadExt, AsyncWriteExt},
1272 net::TcpListener,
1273 task,
1274 time::{Duration, sleep},
1275 };
1276
1277 use super::*;
1278
1279 #[rstest]
1280 #[tokio::test]
1281 async fn test_reconnect_then_close() {
1282 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1284 let port = listener.local_addr().unwrap().port();
1285
1286 let server = task::spawn(async move {
1288 if let Ok((mut sock, _)) = listener.accept().await {
1289 drop(sock.shutdown());
1290 }
1291 sleep(Duration::from_secs(1)).await;
1293 });
1294
1295 let config = SocketConfig {
1297 url: format!("127.0.0.1:{port}"),
1298 mode: Mode::Plain,
1299 suffix: b"\r\n".to_vec(),
1300 message_handler: None,
1301 heartbeat: None,
1302 reconnect_timeout_ms: Some(1_000),
1303 reconnect_delay_initial_ms: Some(50),
1304 reconnect_delay_max_ms: Some(100),
1305 reconnect_backoff_factor: Some(1.0),
1306 reconnect_jitter_ms: Some(0),
1307 certs_dir: None,
1308 };
1309
1310 let client = SocketClient::connect(config.clone(), None, None, None)
1312 .await
1313 .unwrap();
1314
1315 sleep(Duration::from_millis(100)).await;
1317
1318 client.close().await;
1320 assert!(client.is_closed());
1321 server.abort();
1322 }
1323
    /// When the server drops the connection, the client must enter the reconnecting
    /// state within 2 seconds.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_state_flips_when_reader_stops() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Accept the connection and drop it immediately
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            // Keep the listener alive briefly
            sleep(Duration::from_millis(50)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Poll until the client reports the reconnecting state
        tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                if client.is_reconnecting() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("client did not enter RECONNECT state");

        client.close().await;
        server.abort();
    }
1371
1372 #[rstest]
1373 fn test_parse_socket_url_raw_address() {
1374 let (socket_addr, request_url) =
1376 SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1377 assert_eq!(socket_addr, "example.com:6130");
1378 assert_eq!(request_url, "wss://example.com:6130");
1379
1380 let (socket_addr, request_url) =
1382 SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1383 assert_eq!(socket_addr, "localhost:8080");
1384 assert_eq!(request_url, "ws://localhost:8080");
1385 }
1386
1387 #[rstest]
1388 fn test_parse_socket_url_with_scheme() {
1389 let (socket_addr, request_url) =
1391 SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1392 assert_eq!(socket_addr, "example.com:443");
1393 assert_eq!(request_url, "wss://example.com:443/path");
1394
1395 let (socket_addr, request_url) =
1397 SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1398 assert_eq!(socket_addr, "localhost:8080");
1399 assert_eq!(request_url, "ws://localhost:8080");
1400 }
1401
1402 #[rstest]
1403 fn test_parse_socket_url_default_ports() {
1404 let (socket_addr, _) =
1406 SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1407 assert_eq!(socket_addr, "example.com:443");
1408
1409 let (socket_addr, _) =
1411 SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1412 assert_eq!(socket_addr, "example.com:80");
1413
1414 let (socket_addr, _) =
1416 SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1417 assert_eq!(socket_addr, "example.com:443");
1418
1419 let (socket_addr, _) =
1421 SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1422 assert_eq!(socket_addr, "example.com:80");
1423 }
1424
1425 #[rstest]
1426 fn test_parse_socket_url_unknown_scheme_uses_mode() {
1427 let (socket_addr, _) =
1429 SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1430 assert_eq!(socket_addr, "example.com:443");
1431
1432 let (socket_addr, _) =
1433 SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1434 assert_eq!(socket_addr, "example.com:80");
1435 }
1436
1437 #[rstest]
1438 fn test_parse_socket_url_ipv6() {
1439 let (socket_addr, request_url) =
1441 SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1442 assert_eq!(socket_addr, "[::1]:8080");
1443 assert_eq!(request_url, "ws://[::1]:8080");
1444
1445 let (socket_addr, _) =
1447 SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1448 assert_eq!(socket_addr, "[::1]:8080");
1449 }
1450
1451 #[rstest]
1452 #[tokio::test]
1453 async fn test_url_parsing_raw_socket_address() {
1454 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1456 let port = listener.local_addr().unwrap().port();
1457
1458 let server = task::spawn(async move {
1459 if let Ok((sock, _)) = listener.accept().await {
1460 drop(sock);
1461 }
1462 sleep(Duration::from_millis(50)).await;
1463 });
1464
1465 let config = SocketConfig {
1466 url: format!("127.0.0.1:{port}"), mode: Mode::Plain,
1468 suffix: b"\r\n".to_vec(),
1469 message_handler: None,
1470 heartbeat: None,
1471 reconnect_timeout_ms: Some(1_000),
1472 reconnect_delay_initial_ms: Some(50),
1473 reconnect_delay_max_ms: Some(100),
1474 reconnect_backoff_factor: Some(1.0),
1475 reconnect_jitter_ms: Some(0),
1476 certs_dir: None,
1477 };
1478
1479 let client = SocketClient::connect(config, None, None, None).await;
1481 assert!(
1482 client.is_ok(),
1483 "Client should connect with raw socket address format"
1484 );
1485
1486 if let Ok(client) = client {
1487 client.close().await;
1488 }
1489 server.abort();
1490 }
1491
1492 #[rstest]
1493 #[tokio::test]
1494 async fn test_url_parsing_with_scheme() {
1495 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1497 let port = listener.local_addr().unwrap().port();
1498
1499 let server = task::spawn(async move {
1500 if let Ok((sock, _)) = listener.accept().await {
1501 drop(sock);
1502 }
1503 sleep(Duration::from_millis(50)).await;
1504 });
1505
1506 let config = SocketConfig {
1507 url: format!("ws://127.0.0.1:{port}"), mode: Mode::Plain,
1509 suffix: b"\r\n".to_vec(),
1510 message_handler: None,
1511 heartbeat: None,
1512 reconnect_timeout_ms: Some(1_000),
1513 reconnect_delay_initial_ms: Some(50),
1514 reconnect_delay_max_ms: Some(100),
1515 reconnect_backoff_factor: Some(1.0),
1516 reconnect_jitter_ms: Some(0),
1517 certs_dir: None,
1518 };
1519
1520 let client = SocketClient::connect(config, None, None, None).await;
1522 assert!(
1523 client.is_ok(),
1524 "Client should connect with URL scheme format"
1525 );
1526
1527 if let Ok(client) = client {
1528 client.close().await;
1529 }
1530 server.abort();
1531 }
1532
1533 #[rstest]
1534 fn test_parse_socket_url_ipv6_with_zone() {
1535 let (socket_addr, request_url) =
1537 SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1538 assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1539 assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1540
1541 let (socket_addr, request_url) =
1543 SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1544 assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1545 assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1546 }
1547
1548 #[rstest]
1549 #[tokio::test]
1550 async fn test_ipv6_loopback_connection() {
1551 if TcpListener::bind("[::1]:0").await.is_err() {
1554 eprintln!("IPv6 not available, skipping test");
1555 return;
1556 }
1557
1558 let listener = TcpListener::bind("[::1]:0").await.unwrap();
1559 let port = listener.local_addr().unwrap().port();
1560
1561 let server = task::spawn(async move {
1562 if let Ok((mut sock, _)) = listener.accept().await {
1563 let mut buf = vec![0u8; 1024];
1564 if let Ok(n) = sock.read(&mut buf).await {
1565 let _ = sock.write_all(&buf[..n]).await;
1567 }
1568 }
1569 sleep(Duration::from_millis(50)).await;
1570 });
1571
1572 let config = SocketConfig {
1573 url: format!("[::1]:{port}"), mode: Mode::Plain,
1575 suffix: b"\r\n".to_vec(),
1576 message_handler: None,
1577 heartbeat: None,
1578 reconnect_timeout_ms: Some(1_000),
1579 reconnect_delay_initial_ms: Some(50),
1580 reconnect_delay_max_ms: Some(100),
1581 reconnect_backoff_factor: Some(1.0),
1582 reconnect_jitter_ms: Some(0),
1583 certs_dir: None,
1584 };
1585
1586 let client = SocketClient::connect(config, None, None, None).await;
1587 assert!(
1588 client.is_ok(),
1589 "Client should connect to IPv6 loopback address"
1590 );
1591
1592 if let Ok(client) = client {
1593 client.close().await;
1594 }
1595 server.abort();
1596 }
1597
1598 #[rstest]
1599 #[tokio::test]
1600 async fn test_send_waits_during_reconnection() {
1601 use nautilus_common::testing::wait_until_async;
1603
1604 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1605 let port = listener.local_addr().unwrap().port();
1606
1607 let server = task::spawn(async move {
1608 if let Ok((sock, _)) = listener.accept().await {
1610 drop(sock);
1611 }
1612
1613 sleep(Duration::from_millis(500)).await;
1615
1616 if let Ok((mut sock, _)) = listener.accept().await {
1618 let mut buf = vec![0u8; 1024];
1620 while let Ok(n) = sock.read(&mut buf).await {
1621 if n == 0 {
1622 break;
1623 }
1624 if sock.write_all(&buf[..n]).await.is_err() {
1625 break;
1626 }
1627 }
1628 }
1629 });
1630
1631 let config = SocketConfig {
1632 url: format!("127.0.0.1:{port}"),
1633 mode: Mode::Plain,
1634 suffix: b"\r\n".to_vec(),
1635 message_handler: None,
1636 heartbeat: None,
1637 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
1639 reconnect_delay_max_ms: Some(200),
1640 reconnect_backoff_factor: Some(1.0),
1641 reconnect_jitter_ms: Some(0),
1642 certs_dir: None,
1643 };
1644
1645 let client = SocketClient::connect(config, None, None, None)
1646 .await
1647 .unwrap();
1648
1649 wait_until_async(
1651 || async { client.is_reconnecting() },
1652 Duration::from_secs(2),
1653 )
1654 .await;
1655
1656 let send_result = tokio::time::timeout(
1658 Duration::from_secs(3),
1659 client.send_bytes(b"test_message".to_vec()),
1660 )
1661 .await;
1662
1663 assert!(
1664 send_result.is_ok() && send_result.unwrap().is_ok(),
1665 "Send should succeed after waiting for reconnection"
1666 );
1667
1668 client.close().await;
1669 server.abort();
1670 }
1671
1672 #[rstest]
1673 #[tokio::test]
1674 async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1675 use nautilus_common::testing::wait_until_async;
1678
1679 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1680 let port = listener.local_addr().unwrap().port();
1681
1682 let server = task::spawn(async move {
1683 if let Ok((sock, _)) = listener.accept().await {
1685 drop(sock);
1686 }
1687 drop(listener);
1689 sleep(Duration::from_secs(60)).await;
1690 });
1691
1692 let config = SocketConfig {
1693 url: format!("127.0.0.1:{port}"),
1694 mode: Mode::Plain,
1695 suffix: b"\r\n".to_vec(),
1696 message_handler: None,
1697 heartbeat: None,
1698 reconnect_timeout_ms: Some(1_000), reconnect_delay_initial_ms: Some(5_000), reconnect_delay_max_ms: Some(5_000),
1701 reconnect_backoff_factor: Some(1.0),
1702 reconnect_jitter_ms: Some(0),
1703 certs_dir: None,
1704 };
1705
1706 let client = SocketClient::connect(config, None, None, None)
1707 .await
1708 .unwrap();
1709
1710 wait_until_async(
1712 || async { client.is_reconnecting() },
1713 Duration::from_secs(3),
1714 )
1715 .await;
1716
1717 let start = std::time::Instant::now();
1720 let send_result = client.send_bytes(b"test".to_vec()).await;
1721 let elapsed = start.elapsed();
1722
1723 assert!(
1724 send_result.is_err(),
1725 "Send should fail when client stuck in RECONNECT, got: {:?}",
1726 send_result
1727 );
1728 assert!(
1729 matches!(send_result, Err(crate::error::SendError::Timeout)),
1730 "Send should return Timeout error, got: {:?}",
1731 send_result
1732 );
1733 assert!(
1736 elapsed >= Duration::from_millis(900),
1737 "Send should timeout after at least 1s (configured timeout), took {:?}",
1738 elapsed
1739 );
1740
1741 client.close().await;
1742 server.abort();
1743 }
1744}