1use std::{
20 path::Path,
21 sync::{
22 atomic::{AtomicU8, Ordering},
23 Arc,
24 },
25 time::Duration,
26};
27
28use nautilus_cryptography::providers::install_cryptographic_provider;
29use pyo3::prelude::*;
30use tokio::{
31 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
32 net::TcpStream,
33 sync::Mutex,
34};
35use tokio_tungstenite::{
36 tungstenite::{client::IntoClientRequest, stream::Mode, Error},
37 MaybeTlsStream,
38};
39
40use crate::{
41 backoff::ExponentialBackoff,
42 mode::ConnectionMode,
43 tls::{create_tls_config_from_certs_dir, tcp_tls, Connector},
44};
45
46type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
47type SharedTcpWriter = Arc<Mutex<WriteHalf<MaybeTlsStream<TcpStream>>>>;
48type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
49
/// Configuration for a socket client connecting over plain TCP or TLS.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// Address of the server, passed to `TcpStream::connect` (e.g. `"127.0.0.1:8080"`) and
    /// converted into the client request used for the connection handshake.
    pub url: String,
    /// Connection mode passed to `tcp_tls` (e.g. `Mode::Plain`).
    pub mode: Mode,
    /// Byte sequence delimiting frames: appended to every outgoing message and heartbeat,
    /// and used to split incoming bytes into the messages delivered to `handler`.
    pub suffix: Vec<u8>,
    /// Python callable invoked with the bytes of each received frame (suffix removed).
    pub handler: Arc<PyObject>,
    /// Optional heartbeat as `(interval_secs, message)`: `message` plus `suffix` is written
    /// every `interval_secs` seconds while the connection mode is ACTIVE.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Timeout in milliseconds for a single reconnect attempt (10_000 when `None`).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay in milliseconds (2_000 when `None`).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay in milliseconds (30_000 when `None`).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff factor passed to `ExponentialBackoff::new` (1.5 when `None`).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter in milliseconds passed to `ExponentialBackoff::new` (100 when `None`).
    pub reconnect_jitter_ms: Option<u64>,
    /// Directory of certificates used to build a rustls TLS configuration; when `None`, no
    /// custom connector is created.
    pub certs_dir: Option<String>,
}
80
/// Connection state owned and driven by the controller task of a [`SocketClient`].
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; re-read on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `config.certs_dir`, or `None` when no directory was given.
    connector: Option<Connector>,
    /// Task reading from the socket and dispatching frames to the handler. Held in an `Arc`
    /// so `shutdown` can receive a handle while this struct keeps its own.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Task writing periodic heartbeats; `None` when no heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared write half; cloned into the heartbeat task and into `SocketClient` on connect.
    writer: SharedTcpWriter,
    /// Current `ConnectionMode` stored as its `u8` value; shared with `SocketClient`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay generator consulted after each failed reconnect attempt.
    backoff: ExponentialBackoff,
}
110
111impl SocketClientInner {
112 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
113 install_cryptographic_provider();
114
115 let SocketConfig {
116 url,
117 mode,
118 heartbeat,
119 suffix,
120 handler,
121 reconnect_timeout_ms,
122 reconnect_delay_initial_ms,
123 reconnect_delay_max_ms,
124 reconnect_backoff_factor,
125 reconnect_jitter_ms,
126 certs_dir,
127 } = &config;
128 let connector = if let Some(dir) = certs_dir {
129 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
130 Some(Connector::Rustls(Arc::new(config)))
131 } else {
132 None
133 };
134
135 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
136 let writer = Arc::new(Mutex::new(writer));
137
138 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
139
140 let handler = Python::with_gil(|py| handler.clone_ref(py));
141 let read_task = Arc::new(Self::spawn_read_task(reader, handler, suffix.clone()));
142
143 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
145 Self::spawn_heartbeat_task(
146 connection_mode.clone(),
147 heartbeat.clone(),
148 writer.clone(),
149 suffix.clone(),
150 )
151 });
152
153 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
154 let backoff = ExponentialBackoff::new(
155 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
156 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
157 reconnect_backoff_factor.unwrap_or(1.5),
158 reconnect_jitter_ms.unwrap_or(100),
159 true, );
161
162 Ok(Self {
163 config,
164 connector,
165 read_task,
166 heartbeat_task,
167 writer,
168 connection_mode,
169 reconnect_timeout,
170 backoff,
171 })
172 }
173
174 pub async fn tls_connect_with_server(
175 url: &str,
176 mode: Mode,
177 connector: Option<Connector>,
178 ) -> Result<(TcpReader, TcpWriter), Error> {
179 tracing::debug!("Connecting to server");
180 let stream = TcpStream::connect(url).await?;
181 tracing::debug!("Making TLS connection");
182 let request = url.into_client_request()?;
183 tcp_tls(&request, mode, stream, connector)
184 .await
185 .map(tokio::io::split)
186 }
187
188 async fn reconnect(&mut self) -> Result<(), Error> {
193 tracing::debug!("Reconnecting");
194
195 tokio::time::timeout(self.reconnect_timeout, async {
196 shutdown(
198 self.read_task.clone(),
199 self.heartbeat_task.take(),
200 self.writer.clone(),
201 )
202 .await;
203
204 let SocketConfig {
205 url,
206 mode,
207 heartbeat,
208 suffix,
209 handler,
210 reconnect_timeout_ms: _,
211 reconnect_delay_initial_ms: _,
212 reconnect_backoff_factor: _,
213 reconnect_delay_max_ms: _,
214 reconnect_jitter_ms: _,
215 certs_dir: _,
216 } = &self.config;
217 let connector = self.connector.clone();
219 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
220 let writer = Arc::new(Mutex::new(writer));
221 self.writer = writer.clone();
222
223 let handler_for_read = Python::with_gil(|py| handler.clone_ref(py));
225 self.read_task = Arc::new(Self::spawn_read_task(
226 reader,
227 handler_for_read,
228 suffix.clone(),
229 ));
230
231 self.heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
233 Self::spawn_heartbeat_task(
234 self.connection_mode.clone(),
235 heartbeat.clone(),
236 writer.clone(),
237 suffix.clone(),
238 )
239 });
240
241 self.connection_mode
242 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244 tracing::debug!("Reconnect succeeded");
245 Ok(())
246 })
247 .await
248 .map_err(|_| {
249 Error::Io(std::io::Error::new(
250 std::io::ErrorKind::TimedOut,
251 format!(
252 "reconnection timed out after {}s",
253 self.reconnect_timeout.as_secs_f64()
254 ),
255 ))
256 })?
257 }
258
259 #[inline]
266 #[must_use]
267 pub fn is_alive(&self) -> bool {
268 !self.read_task.is_finished()
269 }
270
271 #[must_use]
272 fn spawn_read_task(
273 mut reader: TcpReader,
274 handler: PyObject,
275 suffix: Vec<u8>,
276 ) -> tokio::task::JoinHandle<()> {
277 tracing::debug!("Started task 'read'");
278
279 tokio::task::spawn(async move {
280 let mut buf = Vec::new();
281
282 loop {
283 match reader.read_buf(&mut buf).await {
284 Ok(0) => {
286 tracing::debug!("Connection closed by server");
287 break;
288 }
289 Err(e) => {
290 tracing::debug!("Connection ended: {e}");
291 break;
292 }
293 Ok(bytes) => {
295 tracing::trace!("Received <binary> {bytes} bytes");
296
297 while let Some((i, _)) = &buf
300 .windows(suffix.len())
301 .enumerate()
302 .find(|(_, pair)| pair.eq(&suffix))
303 {
304 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
305 data.truncate(data.len() - suffix.len());
306
307 if let Err(e) =
308 Python::with_gil(|py| handler.call1(py, (data.as_slice(),)))
309 {
310 tracing::error!("Call to handler failed: {e}");
311 break;
312 }
313 }
314 }
315 };
316 }
317
318 tracing::debug!("Completed task 'read'");
319 })
320 }
321
322 fn spawn_heartbeat_task(
323 connection_state: Arc<AtomicU8>,
324 heartbeat: (u64, Vec<u8>),
325 writer: SharedTcpWriter,
326 suffix: Vec<u8>,
327 ) -> tokio::task::JoinHandle<()> {
328 tracing::debug!("Started task 'heartbeat'");
329 let (interval_secs, mut message) = heartbeat;
330
331 tokio::task::spawn(async move {
332 let interval = Duration::from_secs(interval_secs);
333 message.extend(suffix);
334
335 loop {
336 tokio::time::sleep(interval).await;
337
338 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
339 ConnectionMode::Active => {
340 let mut guard = writer.lock().await;
341 match guard.write_all(&message).await {
342 Ok(()) => tracing::trace!("Sent heartbeat"),
343 Err(e) => tracing::error!("Failed to send heartbeat: {e}"),
344 }
345 }
346 ConnectionMode::Reconnect => continue,
347 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
348 }
349 }
350
351 tracing::debug!("Completed task 'heartbeat'");
352 })
353 }
354}
355
356async fn shutdown(
363 read_task: Arc<tokio::task::JoinHandle<()>>,
364 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
365 writer: SharedTcpWriter,
366) {
367 tracing::debug!("Shutting down inner client");
368
369 let timeout = Duration::from_secs(5);
370 if tokio::time::timeout(timeout, async {
371 let mut writer = writer.lock().await;
373 if let Err(e) = writer.shutdown().await {
374 tracing::error!("Error on shutdown: {e}");
375 }
376 drop(writer);
377
378 tokio::time::sleep(Duration::from_millis(100)).await;
379
380 if !read_task.is_finished() {
382 read_task.abort();
383 tracing::debug!("Aborted read task");
384 }
385 if let Some(task) = heartbeat_task {
386 if !task.is_finished() {
387 task.abort();
388 tracing::debug!("Aborted heartbeat task");
389 }
390 }
391 })
392 .await
393 .is_err()
394 {
395 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
396 }
397
398 tracing::debug!("Closed");
399}
400
401impl Drop for SocketClientInner {
402 fn drop(&mut self) {
403 if !self.read_task.is_finished() {
404 self.read_task.abort();
405 }
406
407 if let Some(ref handle) = self.heartbeat_task.take() {
409 if !handle.is_finished() {
410 handle.abort();
411 }
412 }
413 }
414}
415
/// Handle to a socket connection whose lifecycle (reconnects, disconnect) is driven by a
/// background controller task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Shared write half, cloned from the inner client when connecting.
    pub(crate) writer: SharedTcpWriter,
    /// Task that supervises the inner client (reconnects and shutdown).
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Current `ConnectionMode` stored as its `u8` value; shared with the inner client.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Frame delimiter appended after every message sent with `send_bytes`.
    pub(crate) suffix: Vec<u8>,
}
426
427impl SocketClient {
428 pub async fn connect(
429 config: SocketConfig,
430 post_connection: Option<PyObject>,
431 post_reconnection: Option<PyObject>,
432 post_disconnection: Option<PyObject>,
433 ) -> anyhow::Result<Self> {
434 let suffix = config.suffix.clone();
435 let inner = SocketClientInner::connect_url(config).await?;
436 let writer = inner.writer.clone();
437 let connection_mode = inner.connection_mode.clone();
438
439 let controller_task = Self::spawn_controller_task(
440 inner,
441 connection_mode.clone(),
442 post_reconnection,
443 post_disconnection,
444 );
445
446 if let Some(handler) = post_connection {
447 Python::with_gil(|py| match handler.call0(py) {
448 Ok(_) => tracing::debug!("Called `post_connection` handler"),
449 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
450 });
451 }
452
453 Ok(Self {
454 writer,
455 controller_task,
456 connection_mode,
457 suffix,
458 })
459 }
460
461 pub fn connection_mode(&self) -> ConnectionMode {
463 ConnectionMode::from_atomic(&self.connection_mode)
464 }
465
466 #[inline]
471 #[must_use]
472 pub fn is_active(&self) -> bool {
473 self.connection_mode().is_active()
474 }
475
476 #[inline]
481 #[must_use]
482 pub fn is_reconnecting(&self) -> bool {
483 self.connection_mode().is_reconnect()
484 }
485
486 #[inline]
490 #[must_use]
491 pub fn is_disconnecting(&self) -> bool {
492 self.connection_mode().is_disconnect()
493 }
494
495 #[inline]
501 #[must_use]
502 pub fn is_closed(&self) -> bool {
503 self.connection_mode().is_closed()
504 }
505
506 pub async fn close(&self) {
511 self.connection_mode
512 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
513
514 match tokio::time::timeout(Duration::from_secs(5), async {
515 while !self.is_closed() {
516 tokio::time::sleep(Duration::from_millis(10)).await;
517 }
518
519 if !self.controller_task.is_finished() {
520 self.controller_task.abort();
521 tracing::debug!("Aborted controller task");
522 }
523 })
524 .await
525 {
526 Ok(()) => {
527 tracing::debug!("Controller task finished");
528 }
529 Err(_) => {
530 tracing::error!("Timeout waiting for controller task to finish");
531 }
532 }
533 }
534
535 pub async fn send_bytes(&self, data: &[u8]) -> Result<(), std::io::Error> {
536 if self.is_closed() {
537 return Err(std::io::Error::new(
538 std::io::ErrorKind::NotConnected,
539 "Not connected",
540 ));
541 }
542
543 let timeout = Duration::from_secs(2);
544 let check_interval = Duration::from_millis(1);
545
546 if !self.is_active() {
547 tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
548 match tokio::time::timeout(timeout, async {
549 while !self.is_active() {
550 if matches!(
551 self.connection_mode(),
552 ConnectionMode::Disconnect | ConnectionMode::Closed
553 ) {
554 return Err("Client disconnected waiting to send");
555 }
556
557 tokio::time::sleep(check_interval).await;
558 }
559
560 Ok(())
561 })
562 .await
563 {
564 Ok(Ok(())) => tracing::debug!("Client now active"),
565 Ok(Err(e)) => {
566 tracing::error!("Cannot send data ({}): {e}", String::from_utf8_lossy(data));
567 return Ok(());
568 }
569 Err(_) => {
570 tracing::error!(
571 "Cannot send data ({}): timeout waiting to become ACTIVE",
572 String::from_utf8_lossy(data)
573 );
574 return Ok(());
575 }
576 }
577 }
578
579 let mut writer = self.writer.lock().await;
580 writer.write_all(data).await?;
581 writer.write_all(&self.suffix).await
582 }
583
584 fn spawn_controller_task(
585 mut inner: SocketClientInner,
586 connection_mode: Arc<AtomicU8>,
587 post_reconnection: Option<PyObject>,
588 post_disconnection: Option<PyObject>,
589 ) -> tokio::task::JoinHandle<()> {
590 tokio::task::spawn(async move {
591 tracing::debug!("Starting task 'controller'");
592
593 let check_interval = Duration::from_millis(10);
594
595 loop {
596 tokio::time::sleep(check_interval).await;
597 let mode = ConnectionMode::from_atomic(&connection_mode);
598
599 if mode.is_disconnect() {
600 tracing::debug!("Disconnecting");
601 shutdown(
602 inner.read_task.clone(),
603 inner.heartbeat_task.take(),
604 inner.writer.clone(),
605 )
606 .await;
607
608 if let Some(ref handler) = post_disconnection {
609 Python::with_gil(|py| match handler.call0(py) {
610 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
611 Err(e) => {
612 tracing::error!("Error calling `post_disconnection` handler: {e}")
613 }
614 });
615 }
616 break; }
618
619 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
620 match inner.reconnect().await {
621 Ok(()) => {
622 tracing::debug!("Reconnected successfully");
623 inner.backoff.reset();
624
625 if let Some(ref handler) = post_reconnection {
626 Python::with_gil(|py| match handler.call0(py) {
627 Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
628 Err(e) => tracing::error!(
629 "Error calling `post_reconnection` handler: {e}"
630 ),
631 });
632 }
633 }
634 Err(e) => {
635 let duration = inner.backoff.next_duration();
636 tracing::warn!("Reconnect attempt failed: {e}",);
637 if !duration.is_zero() {
638 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
639 }
640 tokio::time::sleep(duration).await;
641 }
642 }
643 }
644 }
645 inner
646 .connection_mode
647 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
648 })
649 }
650}
651
#[cfg(test)]
#[cfg(target_os = "linux")] // NOTE(review): reason for restricting these tests to Linux is not visible here
mod tests {
    use std::{ffi::CString, net::TcpListener};

    use nautilus_common::testing::wait_until_async;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpStream,
        task,
        time::{sleep, Duration},
    };

    use super::*;

    /// Builds a Python `Counter` instance and returns its bound `handler` method.
    fn create_handler() -> PyObject {
        let code_raw = r#"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let counter = pymod.getattr("counter").unwrap().into_py(py);
            let handler = counter.getattr(py, "handler").unwrap().into_py(py);
            handler
        })
    }

    /// Binds an ephemeral localhost port and returns it with its listener.
    ///
    /// The listener is put in non-blocking mode because tokio's `TcpListener::from_std`
    /// requires a non-blocking socket (recent tokio versions assert this in debug builds).
    fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind ephemeral port");
        listener
            .set_nonblocking(true)
            .expect("Failed to set listener non-blocking");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Converts `listener` to tokio and accepts exactly one connection.
    async fn accept_once(listener: TcpListener) -> TcpStream {
        let (socket, _) = tokio::net::TcpListener::from_std(listener)
            .unwrap()
            .accept()
            .await
            .unwrap();
        socket
    }

    /// Plain-TCP config for `127.0.0.1:{port}` with CRLF framing, no heartbeat and default
    /// reconnect settings. Tests override individual fields with struct-update syntax.
    fn plain_config(port: u16, handler: Arc<PyObject>) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            handler,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    /// Echoes each CRLF-terminated line back to the client; a `close` line shuts the
    /// socket down and ends the server.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let server_task = task::spawn(async move {
            let socket = accept_once(listener).await;
            run_echo_server(socket).await;
        });

        let config = plain_config(port, Arc::new(create_handler()));
        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello").await.unwrap();
        client.send_bytes(b"World").await.unwrap();

        // Give the echo server time to respond before asking it to close.
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close").await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    /// The initial connect must fail when nothing listens on the port (despite the name,
    /// no reconnect is exercised here).
    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        // Release the port so the connection attempt is refused.
        drop(listener);

        let config = plain_config(port, Arc::new(create_handler()));
        let client_res = SocketClient::connect(config, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let server_task = task::spawn(async move {
            let socket = accept_once(listener).await;
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = plain_config(port, Arc::new(create_handler()));
        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        // Records every CRLF-terminated line the client sends.
        let server_task = task::spawn(async move {
            let socket = accept_once(listener).await;

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        // Heartbeat "ping" every second.
        let config = SocketConfig {
            heartbeat: Some((1, b"ping".to_vec())),
            ..plain_config(port, Arc::new(create_handler()))
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .unwrap();

        // Long enough for at least two heartbeats.
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let server_task = task::spawn(async move {
            let socket = accept_once(listener).await;
            run_echo_server(socket).await;
        });

        let code_raw = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();

        let handler = Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let func = pymod.getattr("handler").unwrap();
            Arc::new(func.into_py(py))
        });

        let config = plain_config(port, handler);
        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello").await.unwrap();
        sleep(Duration::from_millis(100)).await;

        // The echoed "ERR" makes the Python handler raise; the client must stay active.
        client.send_bytes(b"ERR").await.unwrap();
        sleep(Duration::from_secs(1)).await;

        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server();
        let listener = tokio::net::TcpListener::from_std(listener).unwrap();

        // Accept, drop the first connection after 500ms, then serve an echo server on the
        // second connection.
        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            ..plain_config(port, Arc::new(create_handler()))
        };

        let client = SocketClient::connect(config, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        // NOTE(review): the mode stays ACTIVE across a dead-read-task reconnect, so this wait
        // returns immediately and does not by itself prove a reconnect happened.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect")
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}