1use std::{
20 path::Path,
21 sync::{
22 Arc,
23 atomic::{AtomicU8, Ordering},
24 },
25 time::Duration,
26};
27
28use nautilus_cryptography::providers::install_cryptographic_provider;
29use pyo3::prelude::*;
30use tokio::{
31 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
32 net::TcpStream,
33 sync::Mutex,
34};
35use tokio_tungstenite::{
36 MaybeTlsStream,
37 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
38};
39
40use crate::{
41 backoff::ExponentialBackoff,
42 mode::ConnectionMode,
43 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
44};
45
46type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
47type SharedTcpWriter = Arc<Mutex<WriteHalf<MaybeTlsStream<TcpStream>>>>;
48type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
49
/// Configuration for a TCP socket client connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// Address to connect to. It is passed to `TcpStream::connect` and converted into a
    /// client request for `tcp_tls` (the tests use `"127.0.0.1:<port>"`).
    pub url: String,
    /// Stream mode passed to `tcp_tls` (the tests use `Mode::Plain`).
    pub mode: Mode,
    /// Message delimiter. It is appended to every outgoing message and heartbeat, and is
    /// used to split incoming bytes into messages.
    pub suffix: Vec<u8>,
    /// Python callable invoked with each received message (suffix stripped) as bytes.
    pub handler: Arc<PyObject>,
    /// Optional heartbeat as `(interval_secs, message)`; the suffix is appended to the message.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Upper bound for one reconnection attempt in milliseconds (defaults to 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect delay in milliseconds passed to `ExponentialBackoff` (defaults to 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect delay in milliseconds passed to `ExponentialBackoff` (defaults to 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff factor passed to `ExponentialBackoff` (defaults to 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter in milliseconds passed to `ExponentialBackoff` (defaults to 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// Optional certificates directory used to build a rustls TLS configuration.
    pub certs_dir: Option<String>,
}
80
/// A single live connection: the shared writer half, the background read and
/// (optional) heartbeat tasks, and the settings needed to reconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Original configuration, reused when reconnecting.
    config: SocketConfig,
    /// rustls connector built from `config.certs_dir`, if one was given.
    connector: Option<Connector>,
    /// Task reading from the socket and dispatching messages to the Python handler.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Task periodically writing the heartbeat message, if a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Write half behind an async mutex; cloned into the heartbeat task and into the
    /// outer `SocketClient` at connect time.
    writer: SharedTcpWriter,
    /// Current `ConnectionMode` stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of one `reconnect` call.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnection attempts.
    backoff: ExponentialBackoff,
}
110
111impl SocketClientInner {
112 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
113 install_cryptographic_provider();
114
115 let SocketConfig {
116 url,
117 mode,
118 heartbeat,
119 suffix,
120 handler,
121 reconnect_timeout_ms,
122 reconnect_delay_initial_ms,
123 reconnect_delay_max_ms,
124 reconnect_backoff_factor,
125 reconnect_jitter_ms,
126 certs_dir,
127 } = &config;
128 let connector = if let Some(dir) = certs_dir {
129 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
130 Some(Connector::Rustls(Arc::new(config)))
131 } else {
132 None
133 };
134
135 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
136 let writer = Arc::new(Mutex::new(writer));
137
138 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
139
140 let handler = Python::with_gil(|py| handler.clone_ref(py));
141 let read_task = Arc::new(Self::spawn_read_task(reader, handler, suffix.clone()));
142
143 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
145 Self::spawn_heartbeat_task(
146 connection_mode.clone(),
147 heartbeat.clone(),
148 writer.clone(),
149 suffix.clone(),
150 )
151 });
152
153 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
154 let backoff = ExponentialBackoff::new(
155 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
156 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
157 reconnect_backoff_factor.unwrap_or(1.5),
158 reconnect_jitter_ms.unwrap_or(100),
159 true, );
161
162 Ok(Self {
163 config,
164 connector,
165 read_task,
166 heartbeat_task,
167 writer,
168 connection_mode,
169 reconnect_timeout,
170 backoff,
171 })
172 }
173
174 pub async fn tls_connect_with_server(
175 url: &str,
176 mode: Mode,
177 connector: Option<Connector>,
178 ) -> Result<(TcpReader, TcpWriter), Error> {
179 tracing::debug!("Connecting to server");
180 let stream = TcpStream::connect(url).await?;
181 tracing::debug!("Making TLS connection");
182 let request = url.into_client_request()?;
183 tcp_tls(&request, mode, stream, connector)
184 .await
185 .map(tokio::io::split)
186 }
187
188 async fn reconnect(&mut self) -> Result<(), Error> {
193 tracing::debug!("Reconnecting");
194
195 tokio::time::timeout(self.reconnect_timeout, async {
196 shutdown(
198 self.read_task.clone(),
199 self.heartbeat_task.take(),
200 self.writer.clone(),
201 )
202 .await;
203
204 let SocketConfig {
205 url,
206 mode,
207 heartbeat,
208 suffix,
209 handler,
210 reconnect_timeout_ms: _,
211 reconnect_delay_initial_ms: _,
212 reconnect_backoff_factor: _,
213 reconnect_delay_max_ms: _,
214 reconnect_jitter_ms: _,
215 certs_dir: _,
216 } = &self.config;
217 let connector = self.connector.clone();
219 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
220 let writer = Arc::new(Mutex::new(writer));
221 self.writer = writer.clone();
222
223 let handler_for_read = Python::with_gil(|py| handler.clone_ref(py));
225 self.read_task = Arc::new(Self::spawn_read_task(
226 reader,
227 handler_for_read,
228 suffix.clone(),
229 ));
230
231 self.heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233 Self::spawn_heartbeat_task(
234 self.connection_mode.clone(),
235 heartbeat.clone(),
236 writer.clone(),
237 suffix.clone(),
238 )
239 });
240
241 self.connection_mode
242 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244 tracing::debug!("Reconnect succeeded");
245 Ok(())
246 })
247 .await
248 .map_err(|_| {
249 Error::Io(std::io::Error::new(
250 std::io::ErrorKind::TimedOut,
251 format!(
252 "reconnection timed out after {}s",
253 self.reconnect_timeout.as_secs_f64()
254 ),
255 ))
256 })?
257 }
258
259 #[inline]
266 #[must_use]
267 pub fn is_alive(&self) -> bool {
268 !self.read_task.is_finished()
269 }
270
271 #[must_use]
272 fn spawn_read_task(
273 mut reader: TcpReader,
274 handler: PyObject,
275 suffix: Vec<u8>,
276 ) -> tokio::task::JoinHandle<()> {
277 tracing::debug!("Started task 'read'");
278
279 tokio::task::spawn(async move {
280 let mut buf = Vec::new();
281
282 loop {
283 match reader.read_buf(&mut buf).await {
284 Ok(0) => {
286 tracing::debug!("Connection closed by server");
287 break;
288 }
289 Err(e) => {
290 tracing::debug!("Connection ended: {e}");
291 break;
292 }
293 Ok(bytes) => {
295 tracing::trace!("Received <binary> {bytes} bytes");
296
297 while let Some((i, _)) = &buf
300 .windows(suffix.len())
301 .enumerate()
302 .find(|(_, pair)| pair.eq(&suffix))
303 {
304 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
305 data.truncate(data.len() - suffix.len());
306
307 if let Err(e) =
308 Python::with_gil(|py| handler.call1(py, (data.as_slice(),)))
309 {
310 tracing::error!("Call to handler failed: {e}");
311 break;
312 }
313 }
314 }
315 };
316 }
317
318 tracing::debug!("Completed task 'read'");
319 })
320 }
321
322 fn spawn_heartbeat_task(
323 connection_state: Arc<AtomicU8>,
324 heartbeat: (u64, Vec<u8>),
325 writer: SharedTcpWriter,
326 suffix: Vec<u8>,
327 ) -> tokio::task::JoinHandle<()> {
328 tracing::debug!("Started task 'heartbeat'");
329 let (interval_secs, mut message) = heartbeat;
330
331 tokio::task::spawn(async move {
332 let interval = Duration::from_secs(interval_secs);
333 message.extend(suffix);
334
335 loop {
336 tokio::time::sleep(interval).await;
337
338 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
339 ConnectionMode::Active => {
340 let mut guard = writer.lock().await;
341 match guard.write_all(&message).await {
342 Ok(()) => tracing::trace!("Sent heartbeat"),
343 Err(e) => tracing::error!("Failed to send heartbeat: {e}"),
344 }
345 }
346 ConnectionMode::Reconnect => continue,
347 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
348 }
349 }
350
351 tracing::debug!("Completed task 'heartbeat'");
352 })
353 }
354}
355
356async fn shutdown(
363 read_task: Arc<tokio::task::JoinHandle<()>>,
364 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
365 writer: SharedTcpWriter,
366) {
367 tracing::debug!("Shutting down inner client");
368
369 let timeout = Duration::from_secs(5);
370 if tokio::time::timeout(timeout, async {
371 let mut writer = writer.lock().await;
373 if let Err(e) = writer.shutdown().await {
374 tracing::error!("Error on shutdown: {e}");
375 }
376 drop(writer);
377
378 tokio::time::sleep(Duration::from_millis(100)).await;
379
380 if !read_task.is_finished() {
382 read_task.abort();
383 tracing::debug!("Aborted read task");
384 }
385 if let Some(task) = heartbeat_task {
386 if !task.is_finished() {
387 task.abort();
388 tracing::debug!("Aborted heartbeat task");
389 }
390 }
391 })
392 .await
393 .is_err()
394 {
395 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
396 }
397
398 tracing::debug!("Closed");
399}
400
401impl Drop for SocketClientInner {
402 fn drop(&mut self) {
403 if !self.read_task.is_finished() {
404 self.read_task.abort();
405 }
406
407 if let Some(ref handle) = self.heartbeat_task.take() {
409 if !handle.is_finished() {
410 handle.abort();
411 }
412 }
413 }
414}
415
/// TCP socket client whose connection is owned by a background controller task,
/// which handles disconnect requests and reconnection.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Writer used by `send_bytes`; cloned from the inner client's shared writer at connect time.
    pub(crate) writer: SharedTcpWriter,
    /// Controller task that owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode (a `ConnectionMode` stored as `u8`) shared with the inner client.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Delimiter appended to every message written by `send_bytes`.
    pub(crate) suffix: Vec<u8>,
}
426
427impl SocketClient {
428 pub async fn connect(
434 config: SocketConfig,
435 post_connection: Option<PyObject>,
436 post_reconnection: Option<PyObject>,
437 post_disconnection: Option<PyObject>,
438 ) -> anyhow::Result<Self> {
439 let suffix = config.suffix.clone();
440 let inner = SocketClientInner::connect_url(config).await?;
441 let writer = inner.writer.clone();
442 let connection_mode = inner.connection_mode.clone();
443
444 let controller_task = Self::spawn_controller_task(
445 inner,
446 connection_mode.clone(),
447 post_reconnection,
448 post_disconnection,
449 );
450
451 if let Some(handler) = post_connection {
452 Python::with_gil(|py| match handler.call0(py) {
453 Ok(_) => tracing::debug!("Called `post_connection` handler"),
454 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
455 });
456 }
457
458 Ok(Self {
459 writer,
460 controller_task,
461 connection_mode,
462 suffix,
463 })
464 }
465
466 #[must_use]
468 pub fn connection_mode(&self) -> ConnectionMode {
469 ConnectionMode::from_atomic(&self.connection_mode)
470 }
471
472 #[inline]
477 #[must_use]
478 pub fn is_active(&self) -> bool {
479 self.connection_mode().is_active()
480 }
481
482 #[inline]
487 #[must_use]
488 pub fn is_reconnecting(&self) -> bool {
489 self.connection_mode().is_reconnect()
490 }
491
492 #[inline]
496 #[must_use]
497 pub fn is_disconnecting(&self) -> bool {
498 self.connection_mode().is_disconnect()
499 }
500
501 #[inline]
507 #[must_use]
508 pub fn is_closed(&self) -> bool {
509 self.connection_mode().is_closed()
510 }
511
512 pub async fn close(&self) {
517 self.connection_mode
518 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
519
520 match tokio::time::timeout(Duration::from_secs(5), async {
521 while !self.is_closed() {
522 tokio::time::sleep(Duration::from_millis(10)).await;
523 }
524
525 if !self.controller_task.is_finished() {
526 self.controller_task.abort();
527 tracing::debug!("Aborted controller task");
528 }
529 })
530 .await
531 {
532 Ok(()) => {
533 tracing::debug!("Controller task finished");
534 }
535 Err(_) => {
536 tracing::error!("Timeout waiting for controller task to finish");
537 }
538 }
539 }
540
541 pub async fn send_bytes(&self, data: &[u8]) -> Result<(), std::io::Error> {
547 if self.is_closed() {
548 return Err(std::io::Error::new(
549 std::io::ErrorKind::NotConnected,
550 "Not connected",
551 ));
552 }
553
554 let timeout = Duration::from_secs(2);
555 let check_interval = Duration::from_millis(1);
556
557 if !self.is_active() {
558 tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
559 match tokio::time::timeout(timeout, async {
560 while !self.is_active() {
561 if matches!(
562 self.connection_mode(),
563 ConnectionMode::Disconnect | ConnectionMode::Closed
564 ) {
565 return Err("Client disconnected waiting to send");
566 }
567
568 tokio::time::sleep(check_interval).await;
569 }
570
571 Ok(())
572 })
573 .await
574 {
575 Ok(Ok(())) => tracing::debug!("Client now active"),
576 Ok(Err(e)) => {
577 tracing::error!("Cannot send data ({}): {e}", String::from_utf8_lossy(data));
578 return Ok(());
579 }
580 Err(_) => {
581 tracing::error!(
582 "Cannot send data ({}): timeout waiting to become ACTIVE",
583 String::from_utf8_lossy(data)
584 );
585 return Ok(());
586 }
587 }
588 }
589
590 let mut writer = self.writer.lock().await;
591 writer.write_all(data).await?;
592 writer.write_all(&self.suffix).await
593 }
594
595 fn spawn_controller_task(
596 mut inner: SocketClientInner,
597 connection_mode: Arc<AtomicU8>,
598 post_reconnection: Option<PyObject>,
599 post_disconnection: Option<PyObject>,
600 ) -> tokio::task::JoinHandle<()> {
601 tokio::task::spawn(async move {
602 tracing::debug!("Started task 'controller'");
603
604 let check_interval = Duration::from_millis(10);
605
606 loop {
607 tokio::time::sleep(check_interval).await;
608 let mode = ConnectionMode::from_atomic(&connection_mode);
609
610 if mode.is_disconnect() {
611 tracing::debug!("Disconnecting");
612 shutdown(
613 inner.read_task.clone(),
614 inner.heartbeat_task.take(),
615 inner.writer.clone(),
616 )
617 .await;
618
619 if let Some(ref handler) = post_disconnection {
620 Python::with_gil(|py| match handler.call0(py) {
621 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
622 Err(e) => {
623 tracing::error!("Error calling `post_disconnection` handler: {e}");
624 }
625 });
626 }
627 break; }
629
630 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
631 match inner.reconnect().await {
632 Ok(()) => {
633 tracing::debug!("Reconnected successfully");
634 inner.backoff.reset();
635
636 if let Some(ref handler) = post_reconnection {
637 Python::with_gil(|py| match handler.call0(py) {
638 Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
639 Err(e) => tracing::error!(
640 "Error calling `post_reconnection` handler: {e}"
641 ),
642 });
643 }
644 }
645 Err(e) => {
646 let duration = inner.backoff.next_duration();
647 tracing::warn!("Reconnect attempt failed: {e}",);
648 if !duration.is_zero() {
649 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
650 }
651 tokio::time::sleep(duration).await;
652 }
653 }
654 }
655 }
656 inner
657 .connection_mode
658 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
659 })
660 }
661}
662
#[cfg(test)]
// These tests bind real localhost TCP listeners; the cfg below limits them to Linux.
#[cfg(target_os = "linux")]
mod tests {
    use std::{ffi::CString, net::TcpListener};

    use nautilus_common::testing::wait_until_async;
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpStream,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Builds a Python `Counter` instance and returns its bound `handler` method.
    ///
    /// The handler increments `count` on `'ping'` and sets `check` on `'heartbeat message'`.
    fn create_handler() -> PyObject {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);

            counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py)
        })
    }

    /// Binds a std listener on an ephemeral localhost port and returns `(port, listener)`.
    fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes every `\r\n`-terminated line back to the client.
    /// A line reading `close` shuts the socket down and returns instead.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Strip the trailing `\r\n`.
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let server_task = task::spawn(async move {
            let (socket, _) = tokio::net::TcpListener::from_std(listener)
                .unwrap()
                .accept()
                .await
                .unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler: Arc::new(create_handler()),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello").await.unwrap();
        client.send_bytes(b"World").await.unwrap();

        // Give the echo server time to respond.
        sleep(Duration::from_millis(100)).await;

        // `close` makes the echo server shut its socket down and return.
        client.send_bytes(b"close").await.unwrap();
        server_task.await.unwrap();
        // NOTE(review): only checks the client was not closed by the server; the client
        // itself is never closed here.
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        // Drop the listener immediately so nothing is listening on `port`.
        drop(listener);
        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler: Arc::new(create_handler()),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        // Server accepts one connection, does a single non-blocking read, then idles forever.
        let server_task = task::spawn(async move {
            let (socket, _) = tokio::net::TcpListener::from_std(listener)
                .unwrap()
                .accept()
                .await
                .unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler: Arc::new(create_handler()),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        // Lines received by the server, with the `\r\n` suffix stripped.
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = tokio::net::TcpListener::from_std(listener)
                .unwrap()
                .accept()
                .await
                .unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    // Nothing to read yet (or a read error): poll again shortly.
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat `ping` every 1 second.
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler: Arc::new(create_handler()),
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Allow roughly three heartbeat intervals to elapse.
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let server_task = task::spawn(async move {
            let (socket, _) = tokio::net::TcpListener::from_std(listener)
                .unwrap()
                .accept()
                .await
                .unwrap();
            run_echo_server(socket).await;
        });

        // Handler that raises on any message containing "ERR".
        let code_raw = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();

        let handler = Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let func = pymod.getattr("handler").unwrap();
            Arc::new(func.into_py_any_unwrap(py))
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello").await.unwrap();
        sleep(Duration::from_millis(100)).await;

        // The echoed "ERR" makes the Python handler raise; the client must stay ACTIVE.
        client.send_bytes(b"ERR").await.unwrap();
        sleep(Duration::from_secs(1)).await;

        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let listener = tokio::net::TcpListener::from_std(listener).unwrap();

        // Server: accept, drop the first connection after 500ms, then accept a second
        // connection and echo on it.
        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler: Arc::new(create_handler()),
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        // NOTE(review): `client.is_active()` is already true here, so this wait returns
        // immediately. The send below then happens before the server drops the first
        // connection, so the reconnect path is not actually exercised by this test.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect")
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}