1use std::{
32 collections::VecDeque,
33 fmt::Debug,
34 path::Path,
35 sync::{
36 Arc,
37 atomic::{AtomicU8, Ordering},
38 },
39 time::Duration,
40};
41
42use bytes::Bytes;
43use nautilus_core::CleanDrop;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
47
48use super::{
49 SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand, fix::process_fix_buffer,
50};
51use crate::{
52 backoff::ExponentialBackoff,
53 error::SendError,
54 logging::{log_task_aborted, log_task_started, log_task_stopped},
55 mode::ConnectionMode,
56 net::TcpStream,
57 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
// --- Polling and shutdown timings -------------------------------------------------------

/// Interval (ms) at which the read, write and controller tasks (and `close`) re-check the
/// shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) taken before tearing tasks down during a reconnect or a disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on graceful writer shutdown and on waiting for the controller to close.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Interval (ms) at which `send_bytes` polls while waiting for the client to become ACTIVE.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

// --- Buffer limits ----------------------------------------------------------------------

/// Maximum number of unframed bytes the read task buffers before closing the connection
/// (10 MiB).
const MAX_READ_BUFFER_BYTES: usize = 10 << 20;
68
/// Owns a live socket connection and the background tasks that service it on behalf of a
/// [`SocketClient`].
///
/// The stream (plain or TLS) is split into a read half, serviced by a read task that frames
/// incoming bytes on the configured suffix (or as FIX messages) and hands them to the
/// message handler, and a write half, owned by a write task that receives
/// [`WriterCommand`]s over a channel and appends the suffix to every sent message. An
/// optional heartbeat task periodically queues a configured message to the writer.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created from; reused when reconnecting.
    config: SocketConfig,
    /// TLS connector built from `certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle to the current read task; replaced on every successful reconnect.
    // NOTE(review): reason for the `Arc` wrapper is not visible in this file — confirm.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle to the write task (one per client; it swaps writers on reconnect).
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the channel feeding the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle to the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode` (stored as `u8`) observed by every task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Message handler passed to each read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum consecutive reconnect attempts; `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    /// Reconnect attempts made since the last successful reconnect.
    reconnect_attempt_count: u32,
}
102
103impl SocketClientInner {
104 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110 install_cryptographic_provider();
111
112 let SocketConfig {
113 url,
114 mode,
115 heartbeat,
116 suffix,
117 message_handler,
118 reconnect_timeout_ms,
119 reconnect_delay_initial_ms,
120 reconnect_delay_max_ms,
121 reconnect_backoff_factor,
122 reconnect_jitter_ms,
123 connection_max_retries,
124 reconnect_max_attempts,
125 certs_dir,
126 } = &config.clone();
127 let connector = if let Some(dir) = certs_dir {
128 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
129 Some(Connector::Rustls(Arc::new(config)))
130 } else {
131 None
132 };
133
134 const CONNECTION_TIMEOUT_SECS: u64 = 10;
136 let max_retries = connection_max_retries.unwrap_or(5);
137
138 let mut backoff = ExponentialBackoff::new(
139 Duration::from_millis(500),
140 Duration::from_millis(5000),
141 2.0,
142 250,
143 false,
144 )?;
145
146 #[allow(unused_assignments)]
147 let mut last_error = String::new();
148 let mut attempt = 0;
149 let (reader, writer) = loop {
150 attempt += 1;
151
152 match tokio::time::timeout(
153 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
154 Self::tls_connect_with_server(url, *mode, connector.clone()),
155 )
156 .await
157 {
158 Ok(Ok(result)) => {
159 if attempt > 1 {
160 tracing::info!("Socket connection established after {attempt} attempts");
161 }
162 break result;
163 }
164 Ok(Err(e)) => {
165 last_error = e.to_string();
166 tracing::warn!(
167 attempt,
168 max_retries,
169 url = %url,
170 error = %last_error,
171 "Socket connection attempt failed"
172 );
173 }
174 Err(_) => {
175 last_error = format!(
176 "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
177 );
178 tracing::warn!(
179 attempt,
180 max_retries,
181 url = %url,
182 "Socket connection attempt timed out"
183 );
184 }
185 }
186
187 if attempt >= max_retries {
188 anyhow::bail!(
189 "Failed to connect to {} after {} attempts: {}. \
190 If this is a DNS error, check your network configuration and DNS settings.",
191 url,
192 max_retries,
193 if last_error.is_empty() {
194 "unknown error"
195 } else {
196 &last_error
197 }
198 );
199 }
200
201 let delay = backoff.next_duration();
202 tracing::debug!(
203 "Retrying in {delay:?} (attempt {}/{})",
204 attempt + 1,
205 max_retries
206 );
207 tokio::time::sleep(delay).await;
208 };
209
210 tracing::debug!("Connected");
211
212 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
213
214 let read_task = Arc::new(Self::spawn_read_task(
215 connection_mode.clone(),
216 reader,
217 message_handler.clone(),
218 suffix.clone(),
219 ));
220
221 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
222
223 let write_task =
224 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
225
226 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
228 Self::spawn_heartbeat_task(
229 connection_mode.clone(),
230 heartbeat.clone(),
231 writer_tx.clone(),
232 )
233 });
234
235 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
236 let backoff = ExponentialBackoff::new(
237 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
238 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
239 reconnect_backoff_factor.unwrap_or(1.5),
240 reconnect_jitter_ms.unwrap_or(100),
241 true, )?;
243
244 Ok(Self {
245 config,
246 connector,
247 read_task,
248 write_task,
249 writer_tx,
250 heartbeat_task,
251 connection_mode,
252 reconnect_timeout,
253 backoff,
254 handler: message_handler.clone(),
255 reconnect_max_attempts: *reconnect_max_attempts,
256 reconnect_attempt_count: 0,
257 })
258 }
259
260 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
270 if url.contains("://") {
271 let parsed = url.parse::<http::Uri>().map_err(|e| {
273 Error::Io(std::io::Error::new(
274 std::io::ErrorKind::InvalidInput,
275 format!("Invalid URL: {e}"),
276 ))
277 })?;
278
279 let host = parsed.host().ok_or_else(|| {
280 Error::Io(std::io::Error::new(
281 std::io::ErrorKind::InvalidInput,
282 "URL missing host",
283 ))
284 })?;
285
286 let port = parsed
287 .port_u16()
288 .unwrap_or_else(|| match parsed.scheme_str() {
289 Some("wss" | "https") => 443,
290 Some("ws" | "http") => 80,
291 _ => match mode {
292 Mode::Tls => 443,
293 Mode::Plain => 80,
294 },
295 });
296
297 Ok((format!("{host}:{port}"), url.to_string()))
298 } else {
299 let scheme = match mode {
302 Mode::Tls => "wss",
303 Mode::Plain => "ws",
304 };
305 Ok((url.to_string(), format!("{scheme}://{url}")))
306 }
307 }
308
309 pub async fn tls_connect_with_server(
319 url: &str,
320 mode: Mode,
321 connector: Option<Connector>,
322 ) -> Result<(TcpReader, TcpWriter), Error> {
323 tracing::debug!("Connecting to {url}");
324
325 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
326 let tcp_result = TcpStream::connect(&socket_addr).await;
327
328 match tcp_result {
329 Ok(stream) => {
330 tracing::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
331 if let Err(e) = stream.set_nodelay(true) {
332 tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
333 }
334 let request = request_url.into_client_request()?;
335 tcp_tls(&request, mode, stream, connector)
336 .await
337 .map(tokio::io::split)
338 }
339 Err(e) => {
340 tracing::error!("TCP connection failed to {socket_addr}: {e:?}");
341 Err(Error::Io(e))
342 }
343 }
344 }
345
346 async fn reconnect(&mut self) -> Result<(), Error> {
351 tracing::debug!("Reconnecting");
352
353 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
354 tracing::debug!("Reconnect aborted due to disconnect state");
355 return Ok(());
356 }
357
358 tokio::time::timeout(self.reconnect_timeout, async {
359 let SocketConfig {
360 url,
361 mode,
362 heartbeat: _,
363 suffix,
364 message_handler: _,
365 reconnect_timeout_ms: _,
366 reconnect_delay_initial_ms: _,
367 reconnect_backoff_factor: _,
368 reconnect_delay_max_ms: _,
369 reconnect_jitter_ms: _,
370 connection_max_retries: _,
371 reconnect_max_attempts: _,
372 certs_dir: _,
373 } = &self.config;
374 let connector = self.connector.clone();
376 let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
378
379 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
380 tracing::debug!("Reconnect aborted mid-flight (after connect)");
381 return Ok(());
382 }
383 tracing::debug!("Connected");
384
385 let (tx, rx) = tokio::sync::oneshot::channel();
389 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
390 tracing::error!("{e}");
391 return Err(Error::Io(std::io::Error::new(
392 std::io::ErrorKind::BrokenPipe,
393 format!("Failed to send update command: {e}"),
394 )));
395 }
396
397 match rx.await {
399 Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
400 Ok(false) => {
401 tracing::warn!("Writer failed to drain buffer, aborting reconnect");
402 return Err(Error::Io(std::io::Error::other(
404 "Failed to drain reconnection buffer",
405 )));
406 }
407 Err(e) => {
408 tracing::error!("Writer dropped update channel: {e}");
409 return Err(Error::Io(std::io::Error::new(
410 std::io::ErrorKind::BrokenPipe,
411 "Writer task dropped response channel",
412 )));
413 }
414 }
415
416 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
418
419 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
420 tracing::debug!("Reconnect aborted mid-flight (after delay)");
421 return Ok(());
422 }
423
424 if !self.read_task.is_finished() {
425 self.read_task.abort();
426 log_task_aborted("read");
427 }
428
429 if self
432 .connection_mode
433 .compare_exchange(
434 ConnectionMode::Reconnect.as_u8(),
435 ConnectionMode::Active.as_u8(),
436 Ordering::SeqCst,
437 Ordering::SeqCst,
438 )
439 .is_err()
440 {
441 tracing::debug!("Reconnect aborted (state changed during reconnect)");
442 return Ok(());
443 }
444
445 self.read_task = Arc::new(Self::spawn_read_task(
447 self.connection_mode.clone(),
448 reader,
449 self.handler.clone(),
450 suffix.clone(),
451 ));
452
453 tracing::debug!("Reconnect succeeded");
454 Ok(())
455 })
456 .await
457 .map_err(|_| {
458 Error::Io(std::io::Error::new(
459 std::io::ErrorKind::TimedOut,
460 format!(
461 "reconnection timed out after {}s",
462 self.reconnect_timeout.as_secs_f64()
463 ),
464 ))
465 })?
466 }
467
468 #[inline]
475 #[must_use]
476 pub fn is_alive(&self) -> bool {
477 !self.read_task.is_finished()
478 }
479
480 #[must_use]
481 fn spawn_read_task(
482 connection_state: Arc<AtomicU8>,
483 mut reader: TcpReader,
484 handler: Option<TcpMessageHandler>,
485 suffix: Vec<u8>,
486 ) -> tokio::task::JoinHandle<()> {
487 log_task_started("read");
488
489 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
491
492 tokio::task::spawn(async move {
493 let mut buf = Vec::new();
494
495 loop {
496 if !ConnectionMode::from_atomic(&connection_state).is_active() {
497 break;
498 }
499
500 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
501 Ok(Ok(0)) => {
503 tracing::debug!("Connection closed by server");
504 break;
505 }
506 Ok(Err(e)) => {
507 tracing::debug!("Connection ended: {e}");
508 break;
509 }
510 Ok(Ok(bytes)) => {
512 tracing::trace!("Received <binary> {bytes} bytes");
513
514 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
516
517 if is_fix && handler.is_some() {
518 if let Some(ref handler) = handler {
520 process_fix_buffer(&mut buf, handler);
521 }
522 } else {
523 while let Some((i, _)) = &buf
525 .windows(suffix.len())
526 .enumerate()
527 .find(|(_, pair)| pair.eq(&suffix))
528 {
529 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
530 data.truncate(data.len() - suffix.len());
531
532 if let Some(ref handler) = handler {
533 handler(&data);
534 }
535 }
536 }
537
538 if buf.len() > MAX_READ_BUFFER_BYTES {
539 tracing::error!(
540 "Read buffer exceeded maximum size ({} bytes), closing connection",
541 MAX_READ_BUFFER_BYTES
542 );
543 break;
544 }
545 }
546 Err(_) => {
547 continue;
549 }
550 }
551 }
552
553 log_task_stopped("read");
554 })
555 }
556
557 async fn drain_reconnect_buffer(
567 buffer: &mut VecDeque<Bytes>,
568 writer: &mut TcpWriter,
569 suffix: &[u8],
570 ) -> bool {
571 if buffer.is_empty() {
572 return false;
573 }
574
575 let initial_buffer_len = buffer.len();
576 tracing::info!(
577 "Sending {} buffered messages after reconnection",
578 initial_buffer_len
579 );
580
581 let mut send_error_occurred = false;
582
583 while let Some(buffered_msg) = buffer.front() {
584 let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
585 combined_msg.extend_from_slice(buffered_msg);
586 combined_msg.extend_from_slice(suffix);
587
588 if let Err(e) = writer.write_all(&combined_msg).await {
589 tracing::error!(
590 "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
591 buffer.len()
592 );
593 send_error_occurred = true;
594 break;
595 }
596
597 buffer.pop_front();
598 }
599
600 if buffer.is_empty() {
601 tracing::info!(
602 "Successfully sent all {} buffered messages",
603 initial_buffer_len
604 );
605 }
606
607 send_error_occurred
608 }
609
610 fn spawn_write_task(
611 connection_state: Arc<AtomicU8>,
612 writer: TcpWriter,
613 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
614 suffix: Vec<u8>,
615 ) -> tokio::task::JoinHandle<()> {
616 log_task_started("write");
617
618 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
620
621 tokio::task::spawn(async move {
622 let mut active_writer = writer;
623 let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
624
625 loop {
626 if matches!(
627 ConnectionMode::from_atomic(&connection_state),
628 ConnectionMode::Disconnect | ConnectionMode::Closed
629 ) {
630 break;
631 }
632
633 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
634 Ok(Some(msg)) => {
635 let mode = ConnectionMode::from_atomic(&connection_state);
637 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
638 break;
639 }
640
641 match msg {
642 WriterCommand::Update(new_writer, tx) => {
643 tracing::debug!("Received new writer");
644
645 tokio::time::sleep(Duration::from_millis(100)).await;
647
648 _ = tokio::time::timeout(
651 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
652 active_writer.shutdown(),
653 )
654 .await;
655
656 active_writer = new_writer;
657 tracing::debug!("Updated writer");
658
659 let send_error = Self::drain_reconnect_buffer(
660 &mut reconnect_buffer,
661 &mut active_writer,
662 &suffix,
663 )
664 .await;
665
666 if let Err(e) = tx.send(!send_error) {
667 tracing::error!(
668 "Failed to report drain status to controller: {e:?}"
669 );
670 }
671 }
672 _ if mode.is_reconnect() => {
673 if let WriterCommand::Send(data) = msg {
674 tracing::debug!(
675 "Buffering message while reconnecting ({} bytes)",
676 data.len()
677 );
678 reconnect_buffer.push_back(data);
679 }
680 continue;
681 }
682 WriterCommand::Send(msg) => {
683 if let Err(e) = active_writer.write_all(&msg).await {
684 tracing::error!("Failed to send message: {e}");
685 tracing::warn!("Writer triggering reconnect");
686 reconnect_buffer.push_back(msg);
687 connection_state
688 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
689 continue;
690 }
691 if let Err(e) = active_writer.write_all(&suffix).await {
692 tracing::error!("Failed to send suffix: {e}");
693 tracing::warn!("Writer triggering reconnect");
694 reconnect_buffer.push_back(msg);
696 connection_state
697 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
698 continue;
699 }
700 }
701 }
702 }
703 Ok(None) => {
704 tracing::debug!("Writer channel closed, terminating writer task");
706 break;
707 }
708 Err(_) => {
709 continue;
711 }
712 }
713 }
714
715 _ = tokio::time::timeout(
718 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
719 active_writer.shutdown(),
720 )
721 .await;
722
723 log_task_stopped("write");
724 })
725 }
726
727 fn spawn_heartbeat_task(
728 connection_state: Arc<AtomicU8>,
729 heartbeat: (u64, Vec<u8>),
730 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
731 ) -> tokio::task::JoinHandle<()> {
732 log_task_started("heartbeat");
733 let (interval_secs, message) = heartbeat;
734
735 tokio::task::spawn(async move {
736 let interval = Duration::from_secs(interval_secs);
737
738 loop {
739 tokio::time::sleep(interval).await;
740
741 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
742 ConnectionMode::Active => {
743 let msg = WriterCommand::Send(message.clone().into());
744
745 match writer_tx.send(msg) {
746 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
747 Err(e) => {
748 tracing::error!("Failed to send heartbeat to writer task: {e}");
749 }
750 }
751 }
752 ConnectionMode::Reconnect => continue,
753 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
754 }
755 }
756
757 log_task_stopped("heartbeat");
758 })
759 }
760}
761
762impl Drop for SocketClientInner {
763 fn drop(&mut self) {
764 self.clean_drop();
766 }
767}
768
769impl CleanDrop for SocketClientInner {
771 fn clean_drop(&mut self) {
772 if !self.read_task.is_finished() {
773 self.read_task.abort();
774 log_task_aborted("read");
775 }
776
777 if !self.write_task.is_finished() {
778 self.write_task.abort();
779 log_task_aborted("write");
780 }
781
782 if let Some(ref handle) = self.heartbeat_task.take()
783 && !handle.is_finished()
784 {
785 handle.abort();
786 log_task_aborted("heartbeat");
787 }
788
789 #[cfg(feature = "python")]
790 {
791 self.config.message_handler = None;
793 }
794 }
795}
796
/// Public handle to a supervised socket connection.
///
/// A controller task owns the underlying connection. It reconnects when the connection
/// drops and tears everything down when [`SocketClient::close`] is called. Messages are
/// queued to the writer through `writer_tx`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Task supervising the connection (disconnects and reconnects).
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode` (stored as `u8`), the same atomic the background tasks read.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// How long `send_bytes` waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Sender side of the channel feeding the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
807
808impl Debug for SocketClient {
809 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
810 f.debug_struct(stringify!(SocketClient)).finish()
811 }
812}
813
814impl SocketClient {
815 pub async fn connect(
821 config: SocketConfig,
822 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
823 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
824 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
825 ) -> anyhow::Result<Self> {
826 let inner = SocketClientInner::connect_url(config).await?;
827 let writer_tx = inner.writer_tx.clone();
828 let connection_mode = inner.connection_mode.clone();
829 let reconnect_timeout = inner.reconnect_timeout;
830
831 let controller_task = Self::spawn_controller_task(
832 inner,
833 connection_mode.clone(),
834 post_reconnection,
835 post_disconnection,
836 );
837
838 if let Some(handler) = post_connection {
839 handler();
840 tracing::debug!("Called `post_connection` handler");
841 }
842
843 Ok(Self {
844 controller_task,
845 connection_mode,
846 reconnect_timeout,
847 writer_tx,
848 })
849 }
850
851 #[must_use]
853 pub fn connection_mode(&self) -> ConnectionMode {
854 ConnectionMode::from_atomic(&self.connection_mode)
855 }
856
857 #[inline]
862 #[must_use]
863 pub fn is_active(&self) -> bool {
864 self.connection_mode().is_active()
865 }
866
867 #[inline]
872 #[must_use]
873 pub fn is_reconnecting(&self) -> bool {
874 self.connection_mode().is_reconnect()
875 }
876
877 #[inline]
881 #[must_use]
882 pub fn is_disconnecting(&self) -> bool {
883 self.connection_mode().is_disconnect()
884 }
885
886 #[inline]
892 #[must_use]
893 pub fn is_closed(&self) -> bool {
894 self.connection_mode().is_closed()
895 }
896
897 pub async fn close(&self) {
902 self.connection_mode
903 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
904
905 if let Ok(()) =
906 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
907 while !self.is_closed() {
908 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
909 .await;
910 }
911
912 if !self.controller_task.is_finished() {
913 self.controller_task.abort();
914 log_task_aborted("controller");
915 }
916 })
917 .await
918 {
919 log_task_stopped("controller");
920 } else {
921 tracing::error!("Timeout waiting for controller task to finish");
922 if !self.controller_task.is_finished() {
923 self.controller_task.abort();
924 log_task_aborted("controller");
925 }
926 }
927 }
928
929 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
935 if self.is_closed() || self.is_disconnecting() {
937 return Err(SendError::Closed);
938 }
939
940 let timeout = self.reconnect_timeout;
941 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
942
943 if !self.is_active() {
944 tracing::debug!("Waiting for client to become ACTIVE before sending...");
945
946 let inner = tokio::time::timeout(timeout, async {
947 loop {
948 if self.is_active() {
949 return Ok(());
950 }
951 if matches!(
952 self.connection_mode(),
953 ConnectionMode::Disconnect | ConnectionMode::Closed
954 ) {
955 return Err(());
956 }
957 tokio::time::sleep(check_interval).await;
958 }
959 })
960 .await
961 .map_err(|_| SendError::Timeout)?;
962 inner.map_err(|()| SendError::Closed)?;
963 }
964
965 let msg = WriterCommand::Send(data.into());
966 self.writer_tx
967 .send(msg)
968 .map_err(|e| SendError::BrokenPipe(e.to_string()))
969 }
970
971 fn spawn_controller_task(
972 mut inner: SocketClientInner,
973 connection_mode: Arc<AtomicU8>,
974 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
975 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
976 ) -> tokio::task::JoinHandle<()> {
977 tokio::task::spawn(async move {
978 log_task_started("controller");
979
980 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
981
982 loop {
983 tokio::time::sleep(check_interval).await;
984 let mut mode = ConnectionMode::from_atomic(&connection_mode);
985
986 if mode.is_disconnect() {
987 tracing::debug!("Disconnecting");
988
989 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
990 if tokio::time::timeout(timeout, async {
991 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
993
994 if !inner.read_task.is_finished() {
995 inner.read_task.abort();
996 log_task_aborted("read");
997 }
998
999 if let Some(task) = &inner.heartbeat_task
1000 && !task.is_finished()
1001 {
1002 task.abort();
1003 log_task_aborted("heartbeat");
1004 }
1005 })
1006 .await
1007 .is_err()
1008 {
1009 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1010 }
1011
1012 tracing::debug!("Closed");
1013
1014 if let Some(ref handler) = post_disconnection {
1015 handler();
1016 tracing::debug!("Called `post_disconnection` handler");
1017 }
1018 break; }
1020
1021 if mode.is_closed() {
1022 tracing::debug!("Connection closed");
1023 break;
1024 }
1025
1026 if mode.is_active() && !inner.is_alive() {
1027 if connection_mode
1028 .compare_exchange(
1029 ConnectionMode::Active.as_u8(),
1030 ConnectionMode::Reconnect.as_u8(),
1031 Ordering::SeqCst,
1032 Ordering::SeqCst,
1033 )
1034 .is_ok()
1035 {
1036 tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1037 }
1038 mode = ConnectionMode::from_atomic(&connection_mode);
1039 }
1040
1041 if mode.is_reconnect() {
1042 if let Some(max_attempts) = inner.reconnect_max_attempts
1044 && inner.reconnect_attempt_count >= max_attempts
1045 {
1046 tracing::error!(
1047 "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1048 max_attempts
1049 );
1050 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1051 break;
1052 }
1053
1054 inner.reconnect_attempt_count += 1;
1055 match inner.reconnect().await {
1056 Ok(()) => {
1057 tracing::debug!("Reconnected successfully");
1058 inner.backoff.reset();
1059 inner.reconnect_attempt_count = 0; if ConnectionMode::from_atomic(&connection_mode).is_active() {
1062 if let Some(ref handler) = post_reconnection {
1063 handler();
1064 tracing::debug!("Called `post_reconnection` handler");
1065 }
1066 } else {
1067 tracing::debug!(
1068 "Skipping post_reconnection handlers due to disconnect state"
1069 );
1070 }
1071 }
1072 Err(e) => {
1073 let duration = inner.backoff.next_duration();
1074 tracing::warn!(
1075 "Reconnect attempt {} failed: {e}",
1076 inner.reconnect_attempt_count
1077 );
1078 if !duration.is_zero() {
1079 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1080 }
1081 tokio::time::sleep(duration).await;
1082 }
1083 }
1084 }
1085 }
1086 inner
1087 .connection_mode
1088 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1089
1090 log_task_stopped("controller");
1091 })
1092 }
1093}
1094
1095impl Drop for SocketClient {
1097 fn drop(&mut self) {
1098 if !self.controller_task.is_finished() {
1099 self.controller_task.abort();
1100 log_task_aborted("controller");
1101 }
1102 }
1103}
1104
#[cfg(test)]
#[cfg(feature = "python")]
// NOTE(review): the reason these tests are restricted to Linux is not visible here — confirm
#[cfg(target_os = "linux")]
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::Python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an OS-assigned localhost port and returns the port with it.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes each `\r\n`-terminated line back to the client. The connection is shut down
    /// when the line `close` is received.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Strip the trailing `\r\n`
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Give the server time to echo the messages
        sleep(Duration::from_millis(100)).await;

        // Ask the echo server to shut the connection down
        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        // The client itself was never closed by the caller
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        // Free the port so nothing is listening on it
        drop(listener);

        // Wait until connections to the port are actually refused
        wait_until_async(
            || async {
                TcpStream::connect(format!("127.0.0.1:{port}"))
                    .await
                    .is_err()
            },
            Duration::from_secs(2),
        )
        .await;

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(100),
            reconnect_delay_initial_ms: Some(50),
            reconnect_backoff_factor: Some(1.0),
            reconnect_delay_max_ms: Some(50),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        };

        // With a single connection attempt allowed, connect must fail
        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        // Server accepts, does a single non-blocking read, then idles forever
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;
        // Lines received by the server, shared with the assertion below
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        // Nothing to read yet (or a transient error): poll again shortly
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat `ping` every second
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Long enough for at least two heartbeat intervals to elapse
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        Python::initialize();

        let (port, listener) = bind_test_server().await;

        // Server: accept, drop the first connection after 500 ms, then accept a second
        // connection and echo on it
        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            reconnect_max_attempts: None,
            connection_max_retries: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        // NOTE(review): the client is already ACTIVE at this point, so this wait returns
        // immediately; it does not by itself show that a reconnect took place
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1416
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod rust_tests {
    //! Pure-Rust tests for `SocketClient` and `SocketClientInner::parse_socket_url`,
    //! run against real loopback listeners (disabled under the `turmoil` feature).

    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Builds a plain-TCP, `\r\n`-delimited config pointed at `url` with short,
    /// jitter-free reconnect timings so loopback tests run quickly and deterministically.
    ///
    /// Tests that need different timings override fields via struct-update syntax.
    fn fast_reconnect_config(url: String) -> SocketConfig {
        SocketConfig {
            url,
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            connection_max_retries: Some(1),
            reconnect_max_attempts: None,
            certs_dir: None,
        }
    }

    /// Closing a client while it is reconnecting must leave it in the closed state.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_then_close() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Accept the initial connection, then close our write half (and drop the
            // socket) so the client reads EOF. The shutdown future must be awaited:
            // futures are lazy, so dropping it unpolled performs no shutdown at all.
            if let Ok((mut sock, _)) = listener.accept().await {
                let _ = sock.shutdown().await;
            }
            sleep(Duration::from_secs(1)).await;
        });

        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));
        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }

    /// When the server drops the connection, the client must report that it is
    /// reconnecting.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_state_flips_when_reader_stops() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Drop the accepted socket straight away to close the connection.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));
        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        client.close().await;
        server.abort();
    }

    /// A bare `host:port` is used verbatim as the socket address; the request URL
    /// gets a `wss://` or `ws://` scheme chosen from the mode.
    #[rstest]
    fn test_parse_socket_url_raw_address() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:6130");
        assert_eq!(request_url, "wss://example.com:6130");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    /// For scheme-qualified URLs the socket address is `host:port` (path dropped)
    /// and the request URL is passed through unchanged.
    #[rstest]
    fn test_parse_socket_url_with_scheme() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");
        assert_eq!(request_url, "wss://example.com:443/path");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "localhost:8080");
        assert_eq!(request_url, "ws://localhost:8080");
    }

    /// Without an explicit port, `wss`/`https` default to 443 and `ws`/`http` to 80.
    #[rstest]
    fn test_parse_socket_url_default_ports() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    /// For an unrecognised scheme the default port comes from the mode
    /// (TLS -> 443, plain -> 80).
    #[rstest]
    fn test_parse_socket_url_unknown_scheme_uses_mode() {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
        assert_eq!(socket_addr, "example.com:443");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "example.com:80");
    }

    /// Bracketed IPv6 literals keep their brackets in the socket address.
    #[rstest]
    fn test_parse_socket_url_ipv6() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
        assert_eq!(request_url, "ws://[::1]:8080");

        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[::1]:8080");
    }

    /// End-to-end: connecting with a bare `host:port` URL succeeds.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_raw_socket_address() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        // Raw socket address, no scheme.
        let config = fast_reconnect_config(format!("127.0.0.1:{port}"));
        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect with raw socket address format"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    /// End-to-end: connecting with a `ws://` scheme-qualified URL succeeds.
    #[rstest]
    #[tokio::test]
    async fn test_url_parsing_with_scheme() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            sleep(Duration::from_millis(50)).await;
        });

        // Scheme-qualified URL.
        let config = fast_reconnect_config(format!("ws://127.0.0.1:{port}"));
        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect with URL scheme format"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    /// IPv6 literals with a zone id (`%iface`) are preserved in both outputs.
    #[rstest]
    fn test_parse_socket_url_ipv6_with_zone() {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");

        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
    }

    /// End-to-end: the client can connect to an IPv6 loopback address.
    ///
    /// Skipped (not failed) on hosts where `[::1]` cannot be bound.
    #[rstest]
    #[tokio::test]
    async fn test_ipv6_loopback_connection() {
        // Bind once and reuse the listener, rather than probing with a throwaway
        // bind and then binding again (a second, separately-fallible step).
        let Ok(listener) = TcpListener::bind("[::1]:0").await else {
            eprintln!("IPv6 not available, skipping test");
            return;
        };
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Echo a single read back to the client.
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = vec![0u8; 1024];
                if let Ok(n) = sock.read(&mut buf).await {
                    let _ = sock.write_all(&buf[..n]).await;
                }
            }
            sleep(Duration::from_millis(50)).await;
        });

        let config = fast_reconnect_config(format!("[::1]:{port}"));
        let client = SocketClient::connect(config, None, None, None).await;
        assert!(
            client.is_ok(),
            "Client should connect to IPv6 loopback address"
        );

        if let Ok(client) = client {
            client.close().await;
        }
        server.abort();
    }

    /// A send issued while the client is reconnecting succeeds (within a bounded
    /// time) once the server accepts the client again.
    #[rstest]
    #[tokio::test]
    async fn test_send_waits_during_reconnection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Drop the first connection to push the client into reconnect.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }

            // Delay before accepting the client's reconnection.
            sleep(Duration::from_millis(500)).await;

            // Accept the reconnection and echo until EOF or a write error.
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = vec![0u8; 1024];
                while let Ok(n) = sock.read(&mut buf).await {
                    if n == 0 {
                        break;
                    }
                    if sock.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
            }
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(100),
            reconnect_delay_max_ms: Some(200),
            ..fast_reconnect_config(format!("127.0.0.1:{port}"))
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(2),
        )
        .await;

        // Bound the send so a hang fails the test instead of stalling it.
        let send_result = tokio::time::timeout(
            Duration::from_secs(3),
            client.send_bytes(b"test_message".to_vec()),
        )
        .await;

        assert!(
            matches!(send_result, Ok(Ok(_))),
            "Send should succeed after waiting for reconnection"
        );

        client.close().await;
        server.abort();
    }

    /// A send issued while the client cannot reconnect fails with
    /// `SendError::Timeout` after roughly the configured `reconnect_timeout_ms`.
    #[rstest]
    #[tokio::test]
    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Drop the first connection, then the listener itself, so reconnect
            // attempts have nothing to connect to.
            if let Ok((sock, _)) = listener.accept().await {
                drop(sock);
            }
            drop(listener);
            // Keep the task alive for the duration of the test.
            sleep(Duration::from_secs(60)).await;
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(200),
            reconnect_delay_max_ms: Some(200),
            ..fast_reconnect_config(format!("127.0.0.1:{port}"))
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        wait_until_async(
            || async { client.is_reconnecting() },
            Duration::from_secs(3),
        )
        .await;

        let start = std::time::Instant::now();
        let send_result = client.send_bytes(b"test".to_vec()).await;
        let elapsed = start.elapsed();

        assert!(
            send_result.is_err(),
            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
        );
        assert!(
            matches!(send_result, Err(crate::error::SendError::Timeout)),
            "Send should return Timeout error, was: {send_result:?}"
        );
        // 900 ms lower bound leaves slack for timer granularity below the 1 s timeout.
        assert!(
            elapsed >= Duration::from_millis(900),
            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
        );

        client.close().await;
        server.abort();
    }
}