1use std::{
32 fmt::Debug,
33 path::Path,
34 sync::{
35 Arc,
36 atomic::{AtomicU8, Ordering},
37 },
38 time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_core::CleanDrop;
43use nautilus_cryptography::providers::install_cryptographic_provider;
44use tokio::{
45 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
46 net::TcpStream,
47};
48use tokio_tungstenite::{
49 MaybeTlsStream,
50 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
51};
52
53use crate::{
54 backoff::ExponentialBackoff,
55 error::SendError,
56 fix::process_fix_buffer,
57 logging::{log_task_aborted, log_task_started, log_task_stopped},
58 mode::ConnectionMode,
59 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
60};
61
/// Poll interval (ms) at which the read, write and controller loops re-check the connection state.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) before tearing tasks down on disconnect, and after handing a new writer over on reconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on graceful shutdown steps (old-writer shutdown, `close()` wait, disconnect teardown).
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Maximum time (s) `send_bytes` waits for a non-active client to become active.
const SEND_OPERATION_TIMEOUT_SECS: u64 = 2;
/// Poll interval (ms) used by `send_bytes` while waiting for the client to become active.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;
68
69type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
70type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
71pub type TcpMessageHandler = Arc<dyn Fn(&[u8]) + Send + Sync>;
72
73#[cfg_attr(
75 feature = "python",
76 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
77)]
78pub struct SocketConfig {
79 pub url: String,
81 pub mode: Mode,
83 pub suffix: Vec<u8>,
85 pub message_handler: Option<TcpMessageHandler>,
87 pub heartbeat: Option<(u64, Vec<u8>)>,
89 pub reconnect_timeout_ms: Option<u64>,
91 pub reconnect_delay_initial_ms: Option<u64>,
93 pub reconnect_delay_max_ms: Option<u64>,
95 pub reconnect_backoff_factor: Option<f64>,
97 pub reconnect_jitter_ms: Option<u64>,
99 pub certs_dir: Option<String>,
101}
102
103impl Debug for SocketConfig {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("SocketConfig")
106 .field("url", &self.url)
107 .field("mode", &self.mode)
108 .field("suffix", &self.suffix)
109 .field(
110 "message_handler",
111 &self.message_handler.as_ref().map(|_| "<function>"),
112 )
113 .field("heartbeat", &self.heartbeat)
114 .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
115 .field(
116 "reconnect_delay_initial_ms",
117 &self.reconnect_delay_initial_ms,
118 )
119 .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
120 .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
121 .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
122 .field("certs_dir", &self.certs_dir)
123 .finish()
124 }
125}
126
127impl Clone for SocketConfig {
128 fn clone(&self) -> Self {
129 Self {
130 url: self.url.clone(),
131 mode: self.mode,
132 suffix: self.suffix.clone(),
133 message_handler: self.message_handler.clone(),
134 heartbeat: self.heartbeat.clone(),
135 reconnect_timeout_ms: self.reconnect_timeout_ms,
136 reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
137 reconnect_delay_max_ms: self.reconnect_delay_max_ms,
138 reconnect_backoff_factor: self.reconnect_backoff_factor,
139 reconnect_jitter_ms: self.reconnect_jitter_ms,
140 certs_dir: self.certs_dir.clone(),
141 }
142 }
143}
144
145#[derive(Debug)]
147pub enum WriterCommand {
148 Update(TcpWriter),
150 Send(Bytes),
152}
153
154#[cfg_attr(
170 feature = "python",
171 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
172)]
173struct SocketClientInner {
174 config: SocketConfig,
175 connector: Option<Connector>,
176 read_task: Arc<tokio::task::JoinHandle<()>>,
177 write_task: tokio::task::JoinHandle<()>,
178 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
179 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
180 connection_mode: Arc<AtomicU8>,
181 reconnect_timeout: Duration,
182 backoff: ExponentialBackoff,
183 handler: Option<TcpMessageHandler>,
184}
185
186impl SocketClientInner {
187 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
193 install_cryptographic_provider();
194
195 let SocketConfig {
196 url,
197 mode,
198 heartbeat,
199 suffix,
200 message_handler,
201 reconnect_timeout_ms,
202 reconnect_delay_initial_ms,
203 reconnect_delay_max_ms,
204 reconnect_backoff_factor,
205 reconnect_jitter_ms,
206 certs_dir,
207 } = &config.clone();
208 let connector = if let Some(dir) = certs_dir {
209 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
210 Some(Connector::Rustls(Arc::new(config)))
211 } else {
212 None
213 };
214
215 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
216 tracing::debug!("Connected");
217
218 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
219
220 let read_task = Arc::new(Self::spawn_read_task(
221 connection_mode.clone(),
222 reader,
223 message_handler.clone(),
224 suffix.clone(),
225 ));
226
227 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
228
229 let write_task =
230 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
231
232 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
234 Self::spawn_heartbeat_task(
235 connection_mode.clone(),
236 heartbeat.clone(),
237 writer_tx.clone(),
238 )
239 });
240
241 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
242 let backoff = ExponentialBackoff::new(
243 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
244 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
245 reconnect_backoff_factor.unwrap_or(1.5),
246 reconnect_jitter_ms.unwrap_or(100),
247 true, )?;
249
250 Ok(Self {
251 config,
252 connector,
253 read_task,
254 write_task,
255 writer_tx,
256 heartbeat_task,
257 connection_mode,
258 reconnect_timeout,
259 backoff,
260 handler: message_handler.clone(),
261 })
262 }
263
264 pub async fn tls_connect_with_server(
270 url: &str,
271 mode: Mode,
272 connector: Option<Connector>,
273 ) -> Result<(TcpReader, TcpWriter), Error> {
274 tracing::debug!("Connecting to {url}");
275 let tcp_result = TcpStream::connect(url).await;
276
277 match tcp_result {
278 Ok(stream) => {
279 tracing::debug!("TCP connection established, proceeding with TLS");
280 let request = url.into_client_request()?;
281 tcp_tls(&request, mode, stream, connector)
282 .await
283 .map(tokio::io::split)
284 }
285 Err(e) => {
286 tracing::error!("TCP connection failed: {e:?}");
287 Err(Error::Io(e))
288 }
289 }
290 }
291
292 async fn reconnect(&mut self) -> Result<(), Error> {
297 tracing::debug!("Reconnecting");
298
299 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
300 tracing::debug!("Reconnect aborted due to disconnect state");
301 return Ok(());
302 }
303
304 tokio::time::timeout(self.reconnect_timeout, async {
305 let SocketConfig {
306 url,
307 mode,
308 heartbeat: _,
309 suffix,
310 message_handler: _,
311 reconnect_timeout_ms: _,
312 reconnect_delay_initial_ms: _,
313 reconnect_backoff_factor: _,
314 reconnect_delay_max_ms: _,
315 reconnect_jitter_ms: _,
316 certs_dir: _,
317 } = &self.config;
318 let connector = self.connector.clone();
320 let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
322
323 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
324 tracing::debug!("Reconnect aborted mid-flight (after connect)");
325 return Ok(());
326 }
327 tracing::debug!("Connected");
328
329 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
330 tracing::error!("{e}");
331 }
332
333 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
335
336 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
337 tracing::debug!("Reconnect aborted mid-flight (after delay)");
338 return Ok(());
339 }
340
341 if !self.read_task.is_finished() {
342 self.read_task.abort();
343 log_task_aborted("read");
344 }
345
346 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
348 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
349 return Ok(());
350 }
351
352 self.connection_mode
354 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
355
356 self.read_task = Arc::new(Self::spawn_read_task(
358 self.connection_mode.clone(),
359 reader,
360 self.handler.clone(),
361 suffix.clone(),
362 ));
363
364 tracing::debug!("Reconnect succeeded");
365 Ok(())
366 })
367 .await
368 .map_err(|_| {
369 Error::Io(std::io::Error::new(
370 std::io::ErrorKind::TimedOut,
371 format!(
372 "reconnection timed out after {}s",
373 self.reconnect_timeout.as_secs_f64()
374 ),
375 ))
376 })?
377 }
378
379 #[inline]
386 #[must_use]
387 pub fn is_alive(&self) -> bool {
388 !self.read_task.is_finished()
389 }
390
391 #[must_use]
392 fn spawn_read_task(
393 connection_state: Arc<AtomicU8>,
394 mut reader: TcpReader,
395 handler: Option<TcpMessageHandler>,
396 suffix: Vec<u8>,
397 ) -> tokio::task::JoinHandle<()> {
398 log_task_started("read");
399
400 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
402
403 tokio::task::spawn(async move {
404 let mut buf = Vec::new();
405
406 loop {
407 if !ConnectionMode::from_atomic(&connection_state).is_active() {
408 break;
409 }
410
411 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
412 Ok(Ok(0)) => {
414 tracing::debug!("Connection closed by server");
415 break;
416 }
417 Ok(Err(e)) => {
418 tracing::debug!("Connection ended: {e}");
419 break;
420 }
421 Ok(Ok(bytes)) => {
423 tracing::trace!("Received <binary> {bytes} bytes");
424
425 let is_fix = buf.len() >= 5 && buf.starts_with(b"8=FIX");
427
428 if is_fix && handler.is_some() {
429 if let Some(ref handler) = handler {
431 process_fix_buffer(&mut buf, handler);
432 }
433 } else {
434 while let Some((i, _)) = &buf
436 .windows(suffix.len())
437 .enumerate()
438 .find(|(_, pair)| pair.eq(&suffix))
439 {
440 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
441 data.truncate(data.len() - suffix.len());
442
443 if let Some(ref handler) = handler {
444 handler(&data);
445 }
446 }
447 }
448 }
449 Err(_) => {
450 continue;
452 }
453 }
454 }
455
456 log_task_stopped("read");
457 })
458 }
459
460 fn spawn_write_task(
461 connection_state: Arc<AtomicU8>,
462 writer: TcpWriter,
463 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
464 suffix: Vec<u8>,
465 ) -> tokio::task::JoinHandle<()> {
466 log_task_started("write");
467
468 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
470
471 tokio::task::spawn(async move {
472 let mut active_writer = writer;
473
474 loop {
475 if matches!(
476 ConnectionMode::from_atomic(&connection_state),
477 ConnectionMode::Disconnect | ConnectionMode::Closed
478 ) {
479 break;
480 }
481
482 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
483 Ok(Some(msg)) => {
484 let mode = ConnectionMode::from_atomic(&connection_state);
486 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
487 break;
488 }
489
490 match msg {
491 WriterCommand::Update(new_writer) => {
492 tracing::debug!("Received new writer");
493
494 tokio::time::sleep(Duration::from_millis(100)).await;
496
497 _ = tokio::time::timeout(
500 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
501 active_writer.shutdown(),
502 )
503 .await;
504
505 active_writer = new_writer;
506 tracing::debug!("Updated writer");
507 }
508 _ if mode.is_reconnect() => {
509 tracing::warn!("Skipping message while reconnecting, {msg:?}");
510 continue;
511 }
512 WriterCommand::Send(msg) => {
513 if let Err(e) = active_writer.write_all(&msg).await {
514 tracing::error!("Failed to send message: {e}");
515 tracing::warn!("Writer triggering reconnect");
517 connection_state
518 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
519 continue;
520 }
521 if let Err(e) = active_writer.write_all(&suffix).await {
522 tracing::error!("Failed to send message: {e}");
523 }
524 }
525 }
526 }
527 Ok(None) => {
528 tracing::debug!("Writer channel closed, terminating writer task");
530 break;
531 }
532 Err(_) => {
533 continue;
535 }
536 }
537 }
538
539 _ = tokio::time::timeout(
542 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
543 active_writer.shutdown(),
544 )
545 .await;
546
547 log_task_stopped("write");
548 })
549 }
550
551 fn spawn_heartbeat_task(
552 connection_state: Arc<AtomicU8>,
553 heartbeat: (u64, Vec<u8>),
554 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
555 ) -> tokio::task::JoinHandle<()> {
556 log_task_started("heartbeat");
557 let (interval_secs, message) = heartbeat;
558
559 tokio::task::spawn(async move {
560 let interval = Duration::from_secs(interval_secs);
561
562 loop {
563 tokio::time::sleep(interval).await;
564
565 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
566 ConnectionMode::Active => {
567 let msg = WriterCommand::Send(message.clone().into());
568
569 match writer_tx.send(msg) {
570 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
571 Err(e) => {
572 tracing::error!("Failed to send heartbeat to writer task: {e}");
573 }
574 }
575 }
576 ConnectionMode::Reconnect => continue,
577 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
578 }
579 }
580
581 log_task_stopped("heartbeat");
582 })
583 }
584}
585
586impl Drop for SocketClientInner {
587 fn drop(&mut self) {
588 self.clean_drop();
590 }
591}
592
593impl CleanDrop for SocketClientInner {
594 fn clean_drop(&mut self) {
595 if !self.read_task.is_finished() {
596 self.read_task.abort();
597 log_task_aborted("read");
598 }
599
600 if !self.write_task.is_finished() {
601 self.write_task.abort();
602 log_task_aborted("write");
603 }
604
605 if let Some(ref handle) = self.heartbeat_task.take()
606 && !handle.is_finished()
607 {
608 handle.abort();
609 log_task_aborted("heartbeat");
610 }
611
612 #[cfg(feature = "python")]
613 {
614 self.config.message_handler = None;
616 }
617 }
618}
619
620#[cfg_attr(
621 feature = "python",
622 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
623)]
624pub struct SocketClient {
625 pub(crate) controller_task: tokio::task::JoinHandle<()>,
626 pub(crate) connection_mode: Arc<AtomicU8>,
627 pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
628}
629
630impl Debug for SocketClient {
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 f.debug_struct(stringify!(SocketClient)).finish()
633 }
634}
635
636impl SocketClient {
637 pub async fn connect(
643 config: SocketConfig,
644 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
645 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
646 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
647 ) -> anyhow::Result<Self> {
648 let inner = SocketClientInner::connect_url(config).await?;
649 let writer_tx = inner.writer_tx.clone();
650 let connection_mode = inner.connection_mode.clone();
651
652 let controller_task = Self::spawn_controller_task(
653 inner,
654 connection_mode.clone(),
655 post_reconnection,
656 post_disconnection,
657 );
658
659 if let Some(handler) = post_connection {
660 handler();
661 tracing::debug!("Called `post_connection` handler");
662 }
663
664 Ok(Self {
665 controller_task,
666 connection_mode,
667 writer_tx,
668 })
669 }
670
671 #[must_use]
673 pub fn connection_mode(&self) -> ConnectionMode {
674 ConnectionMode::from_atomic(&self.connection_mode)
675 }
676
677 #[inline]
682 #[must_use]
683 pub fn is_active(&self) -> bool {
684 self.connection_mode().is_active()
685 }
686
687 #[inline]
692 #[must_use]
693 pub fn is_reconnecting(&self) -> bool {
694 self.connection_mode().is_reconnect()
695 }
696
697 #[inline]
701 #[must_use]
702 pub fn is_disconnecting(&self) -> bool {
703 self.connection_mode().is_disconnect()
704 }
705
706 #[inline]
712 #[must_use]
713 pub fn is_closed(&self) -> bool {
714 self.connection_mode().is_closed()
715 }
716
717 pub async fn close(&self) {
722 self.connection_mode
723 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
724
725 match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
726 while !self.is_closed() {
727 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
728 }
729
730 if !self.controller_task.is_finished() {
731 self.controller_task.abort();
732 log_task_aborted("controller");
733 }
734 })
735 .await
736 {
737 Ok(()) => {
738 log_task_stopped("controller");
739 }
740 Err(_) => {
741 tracing::error!("Timeout waiting for controller task to finish");
742 }
743 }
744 }
745
746 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
752 if self.is_closed() {
753 return Err(SendError::Closed);
754 }
755
756 let timeout = Duration::from_secs(SEND_OPERATION_TIMEOUT_SECS);
757 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
758
759 if !self.is_active() {
760 tracing::debug!("Waiting for client to become ACTIVE before sending...");
761
762 let inner = tokio::time::timeout(timeout, async {
763 loop {
764 if self.is_active() {
765 return Ok(());
766 }
767 if matches!(
768 self.connection_mode(),
769 ConnectionMode::Disconnect | ConnectionMode::Closed
770 ) {
771 return Err(());
772 }
773 tokio::time::sleep(check_interval).await;
774 }
775 })
776 .await
777 .map_err(|_| SendError::Timeout)?;
778 inner.map_err(|()| SendError::Closed)?;
779 }
780
781 let msg = WriterCommand::Send(data.into());
782 self.writer_tx
783 .send(msg)
784 .map_err(|e| SendError::BrokenPipe(e.to_string()))
785 }
786
787 fn spawn_controller_task(
788 mut inner: SocketClientInner,
789 connection_mode: Arc<AtomicU8>,
790 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
791 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
792 ) -> tokio::task::JoinHandle<()> {
793 tokio::task::spawn(async move {
794 log_task_started("controller");
795
796 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
797
798 loop {
799 tokio::time::sleep(check_interval).await;
800 let mode = ConnectionMode::from_atomic(&connection_mode);
801
802 if mode.is_disconnect() {
803 tracing::debug!("Disconnecting");
804
805 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
806 if tokio::time::timeout(timeout, async {
807 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
809
810 if !inner.read_task.is_finished() {
811 inner.read_task.abort();
812 log_task_aborted("read");
813 }
814
815 if let Some(task) = &inner.heartbeat_task
816 && !task.is_finished()
817 {
818 task.abort();
819 log_task_aborted("heartbeat");
820 }
821 })
822 .await
823 .is_err()
824 {
825 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
826 }
827
828 tracing::debug!("Closed");
829
830 if let Some(ref handler) = post_disconnection {
831 handler();
832 tracing::debug!("Called `post_disconnection` handler");
833 }
834 break; }
836
837 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
838 match inner.reconnect().await {
839 Ok(()) => {
840 tracing::debug!("Reconnected successfully");
841 inner.backoff.reset();
842 if ConnectionMode::from_atomic(&connection_mode).is_active() {
844 if let Some(ref handler) = post_reconnection {
845 handler();
846 tracing::debug!("Called `post_reconnection` handler");
847 }
848 } else {
849 tracing::debug!(
850 "Skipping post_reconnection handlers due to disconnect state"
851 );
852 }
853 }
854 Err(e) => {
855 let duration = inner.backoff.next_duration();
856 tracing::warn!("Reconnect attempt failed: {e}");
857 if !duration.is_zero() {
858 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
859 }
860 tokio::time::sleep(duration).await;
861 }
862 }
863 }
864 }
865 inner
866 .connection_mode
867 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
868
869 log_task_stopped("controller");
870 })
871 }
872}
873
874impl Drop for SocketClient {
876 fn drop(&mut self) {
877 if !self.controller_task.is_finished() {
878 self.controller_task.abort();
879 log_task_aborted("controller");
880 }
881 }
882}
883
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")]
mod tests {
    use nautilus_common::testing::wait_until_async;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::{Mutex, oneshot},
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Binds a listener on an ephemeral localhost port and returns `(port, listener)`.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes each `\r\n`-terminated line back to the client until EOF or an error.
    /// A `close` line shuts the socket down instead of being echoed.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    /// Plain-TCP config for `127.0.0.1:{port}` with a `\r\n` suffix and every optional
    /// setting unset. Tests override fields with struct-update syntax.
    fn plain_config(port: u16) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let client = SocketClient::connect(plain_config(port), None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// The initial connect must fail fast when nothing is listening on the port.
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        drop(listener);

        let client_res = SocketClient::connect(plain_config(port), None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let client = SocketClient::connect(plain_config(port), None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        let config = SocketConfig {
            heartbeat: Some((1, b"ping".to_vec())),
            ..plain_config(port)
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        // Fires once the server accepts the client's second connection, i.e. the reconnect.
        let (reconnected_tx, reconnected_rx) = oneshot::channel::<()>();

        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            let _ = reconnected_tx.send(());
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            ..plain_config(port)
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        // The client starts active, so waiting only for `is_active()` would pass before any
        // reconnect happens. Wait for the server's second accept to prove a reconnect occurred.
        tokio::time::timeout(Duration::from_secs(10), reconnected_rx)
            .await
            .expect("Client did not reconnect in time")
            .expect("Server task ended before the second accept");

        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1177
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// A client whose server drops the connection must still close cleanly while reconnecting.
    #[tokio::test]
    async fn test_reconnect_then_close() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // Accept once and shut the connection down so the client sees EOF and reconnects.
            // The shutdown future must be awaited; dropping it unpolled does nothing.
            if let Ok((mut sock, _)) = listener.accept().await {
                let _ = sock.shutdown().await;
            }
            sleep(Duration::from_secs(1)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            message_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        sleep(Duration::from_millis(100)).await;

        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }
}
1242}