1use std::{
32 path::Path,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use bytes::Bytes;
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use pyo3::prelude::*;
43use tokio::{
44 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
45 net::TcpStream,
46};
47use tokio_tungstenite::{
48 MaybeTlsStream,
49 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
50};
51
52use crate::{
53 backoff::ExponentialBackoff,
54 fix::process_fix_buffer,
55 mode::ConnectionMode,
56 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
57};
58
/// Write half of the (plain or TLS) TCP stream; owned by the write task.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of the (plain or TLS) TCP stream; owned by the read task.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Rust callback for inbound bytes. When one is supplied, the read task hands its buffer to
/// `process_fix_buffer` with this handler instead of splitting on the configured suffix.
pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
62
/// Connection settings for a [`SocketClient`].
///
/// Unset `reconnect_*` options are given defaults when the connection is created.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// Server address passed to `TcpStream::connect`, e.g. `"127.0.0.1:5000"`.
    pub url: String,
    /// Stream mode (e.g. `Mode::Plain`) passed to `tcp_tls`.
    pub mode: Mode,
    /// Delimiter appended to every outbound message and heartbeat. When no Rust handler is
    /// set, it is also used to split inbound bytes into messages.
    pub suffix: Vec<u8>,
    /// Python callable invoked with each inbound message, with the suffix stripped.
    pub py_handler: Option<Arc<PyObject>>,
    /// Optional heartbeat as `(interval in seconds, payload)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Timeout for one reconnect attempt, in milliseconds (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect delay passed to `ExponentialBackoff`, in milliseconds (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect delay passed to `ExponentialBackoff`, in milliseconds (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff factor passed to `ExponentialBackoff` (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to `ExponentialBackoff`, in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// Directory of certificates used to build a custom rustls configuration.
    pub certs_dir: Option<String>,
}
93
/// Commands consumed by the write task.
#[derive(Debug)]
pub enum WriterCommand {
    /// Replace the active writer with a freshly connected one (sent by a reconnect).
    Update(TcpWriter),
    /// Write these bytes, followed by the configured suffix.
    Send(Bytes),
}
102
/// Owns one TCP/TLS connection and the background tasks that service it.
///
/// The stream is split into two halves:
/// - The read half goes to a task that turns inbound bytes into messages for a handler.
/// - The write half goes to a task that receives [`WriterCommand`]s over `writer_tx`.
///
/// An optional heartbeat task periodically enqueues a fixed payload on the same channel.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration used for the initial connection and for every reconnect.
    config: SocketConfig,
    /// Connector built from `config.certs_dir`, if one was given.
    connector: Option<Connector>,
    /// Handle of the task reading from the server; replaced on each reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the task that drains `writer_tx` and writes to the server.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender feeding the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode` (stored as `u8`) read and written by all tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Optional Rust message handler (see [`TcpMessageHandler`]).
    handler: Option<Arc<TcpMessageHandler>>,
}
134
135impl SocketClientInner {
136 pub async fn connect_url(
137 config: SocketConfig,
138 handler: Option<Arc<TcpMessageHandler>>,
139 ) -> anyhow::Result<Self> {
140 install_cryptographic_provider();
141
142 let SocketConfig {
143 url,
144 mode,
145 heartbeat,
146 suffix,
147 py_handler,
148 reconnect_timeout_ms,
149 reconnect_delay_initial_ms,
150 reconnect_delay_max_ms,
151 reconnect_backoff_factor,
152 reconnect_jitter_ms,
153 certs_dir,
154 } = &config;
155 let connector = if let Some(dir) = certs_dir {
156 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
157 Some(Connector::Rustls(Arc::new(config)))
158 } else {
159 None
160 };
161
162 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
163 tracing::debug!("Connected");
164
165 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
166
167 let read_task = Arc::new(Self::spawn_read_task(
168 connection_mode.clone(),
169 reader,
170 handler.clone(),
171 py_handler.clone(),
172 suffix.clone(),
173 ));
174
175 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
176
177 let write_task =
178 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
179
180 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
182 Self::spawn_heartbeat_task(
183 connection_mode.clone(),
184 heartbeat.clone(),
185 writer_tx.clone(),
186 )
187 });
188
189 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
190 let backoff = ExponentialBackoff::new(
191 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
192 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
193 reconnect_backoff_factor.unwrap_or(1.5),
194 reconnect_jitter_ms.unwrap_or(100),
195 true, );
197
198 Ok(Self {
199 config,
200 connector,
201 read_task,
202 write_task,
203 writer_tx,
204 heartbeat_task,
205 connection_mode,
206 reconnect_timeout,
207 backoff,
208 handler,
209 })
210 }
211
212 pub async fn tls_connect_with_server(
213 url: &str,
214 mode: Mode,
215 connector: Option<Connector>,
216 ) -> Result<(TcpReader, TcpWriter), Error> {
217 tracing::debug!("Connecting to {url}");
218 let tcp_result = TcpStream::connect(url).await;
219
220 match tcp_result {
221 Ok(stream) => {
222 tracing::debug!("TCP connection established, proceeding with TLS");
223 let request = url.into_client_request()?;
224 tcp_tls(&request, mode, stream, connector)
225 .await
226 .map(tokio::io::split)
227 }
228 Err(e) => {
229 tracing::error!("TCP connection failed: {e:?}");
230 Err(Error::Io(e))
231 }
232 }
233 }
234
235 async fn reconnect(&mut self) -> Result<(), Error> {
240 tracing::debug!("Reconnecting");
241
242 tokio::time::timeout(self.reconnect_timeout, async {
243 let SocketConfig {
244 url,
245 mode,
246 heartbeat: _,
247 suffix,
248 py_handler,
249 reconnect_timeout_ms: _,
250 reconnect_delay_initial_ms: _,
251 reconnect_backoff_factor: _,
252 reconnect_delay_max_ms: _,
253 reconnect_jitter_ms: _,
254 certs_dir: _,
255 } = &self.config;
256 let connector = self.connector.clone();
258 let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
259 tracing::debug!("Connected");
260
261 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
262 tracing::error!("{e}");
263 }
264
265 tokio::time::sleep(Duration::from_millis(100)).await;
267
268 if !self.read_task.is_finished() {
269 self.read_task.abort();
270 tracing::debug!("Aborted task 'read'");
271 }
272
273 self.connection_mode
274 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
275
276 self.read_task = Arc::new(Self::spawn_read_task(
278 self.connection_mode.clone(),
279 reader,
280 self.handler.clone(),
281 py_handler.clone(),
282 suffix.clone(),
283 ));
284
285 tracing::debug!("Reconnect succeeded");
286 Ok(())
287 })
288 .await
289 .map_err(|_| {
290 Error::Io(std::io::Error::new(
291 std::io::ErrorKind::TimedOut,
292 format!(
293 "reconnection timed out after {}s",
294 self.reconnect_timeout.as_secs_f64()
295 ),
296 ))
297 })?
298 }
299
300 #[inline]
307 #[must_use]
308 pub fn is_alive(&self) -> bool {
309 !self.read_task.is_finished()
310 }
311
312 #[must_use]
313 fn spawn_read_task(
314 connection_state: Arc<AtomicU8>,
315 mut reader: TcpReader,
316 handler: Option<Arc<TcpMessageHandler>>,
317 py_handler: Option<Arc<PyObject>>,
318 suffix: Vec<u8>,
319 ) -> tokio::task::JoinHandle<()> {
320 tracing::debug!("Started task 'read'");
321
322 let check_interval = Duration::from_millis(10);
324
325 tokio::task::spawn(async move {
326 let mut buf = Vec::new();
327
328 loop {
329 if !ConnectionMode::from_atomic(&connection_state).is_active() {
330 break;
331 }
332
333 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
334 Ok(Ok(0)) => {
336 tracing::debug!("Connection closed by server");
337 break;
338 }
339 Ok(Err(e)) => {
340 tracing::debug!("Connection ended: {e}");
341 break;
342 }
343 Ok(Ok(bytes)) => {
345 tracing::trace!("Received <binary> {bytes} bytes");
346
347 if let Some(handler) = &handler {
348 process_fix_buffer(&mut buf, handler);
349 } else {
350 while let Some((i, _)) = &buf
351 .windows(suffix.len())
352 .enumerate()
353 .find(|(_, pair)| pair.eq(&suffix))
354 {
355 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
356 data.truncate(data.len() - suffix.len());
357
358 if let Some(handler) = &handler {
359 handler(&data);
360 }
361
362 if let Some(py_handler) = &py_handler {
363 if let Err(e) = Python::with_gil(|py| {
364 py_handler.call1(py, (data.as_slice(),))
365 }) {
366 tracing::error!("Call to handler failed: {e}");
367 break;
368 }
369 }
370 }
371 }
372 }
373 Err(_) => {
374 continue;
376 }
377 }
378 }
379
380 tracing::debug!("Completed task 'read'");
381 })
382 }
383
384 fn spawn_write_task(
385 connection_state: Arc<AtomicU8>,
386 writer: TcpWriter,
387 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
388 suffix: Vec<u8>,
389 ) -> tokio::task::JoinHandle<()> {
390 tracing::debug!("Started task 'write'");
391
392 let check_interval = Duration::from_millis(10);
394
395 tokio::task::spawn(async move {
396 let mut active_writer = writer;
397
398 loop {
399 if matches!(
400 ConnectionMode::from_atomic(&connection_state),
401 ConnectionMode::Disconnect | ConnectionMode::Closed
402 ) {
403 break;
404 }
405
406 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
407 Ok(Some(msg)) => {
408 let mode = ConnectionMode::from_atomic(&connection_state);
410 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
411 break;
412 }
413
414 match msg {
415 WriterCommand::Update(new_writer) => {
416 tracing::debug!("Received new writer");
417
418 tokio::time::sleep(Duration::from_millis(100)).await;
420
421 _ = active_writer.shutdown().await;
424
425 active_writer = new_writer;
426 tracing::debug!("Updated writer");
427 }
428 _ if mode.is_reconnect() => {
429 tracing::warn!("Skipping message while reconnecting, {msg:?}");
430 continue;
431 }
432 WriterCommand::Send(msg) => {
433 if let Err(e) = active_writer.write_all(&msg).await {
434 tracing::error!("Failed to send message: {e}");
435 tracing::warn!("Writer triggering reconnect");
437 connection_state
438 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
439 continue;
440 }
441 if let Err(e) = active_writer.write_all(&suffix).await {
442 tracing::error!("Failed to send message: {e}");
443 }
444 }
445 }
446 }
447 Ok(None) => {
448 tracing::debug!("Writer channel closed, terminating writer task");
450 break;
451 }
452 Err(_) => {
453 continue;
455 }
456 }
457 }
458
459 _ = active_writer.shutdown().await;
462
463 tracing::debug!("Completed task 'write'");
464 })
465 }
466
467 fn spawn_heartbeat_task(
468 connection_state: Arc<AtomicU8>,
469 heartbeat: (u64, Vec<u8>),
470 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
471 ) -> tokio::task::JoinHandle<()> {
472 tracing::debug!("Started task 'heartbeat'");
473 let (interval_secs, message) = heartbeat;
474
475 tokio::task::spawn(async move {
476 let interval = Duration::from_secs(interval_secs);
477
478 loop {
479 tokio::time::sleep(interval).await;
480
481 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
482 ConnectionMode::Active => {
483 let msg = WriterCommand::Send(message.clone().into());
484
485 match writer_tx.send(msg) {
486 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
487 Err(e) => {
488 tracing::error!("Failed to send heartbeat to writer task: {e}");
489 }
490 }
491 }
492 ConnectionMode::Reconnect => continue,
493 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
494 }
495 }
496
497 tracing::debug!("Completed task 'heartbeat'");
498 })
499 }
500}
501
502impl Drop for SocketClientInner {
503 fn drop(&mut self) {
504 if !self.read_task.is_finished() {
505 self.read_task.abort();
506 tracing::debug!("Aborted task 'read'");
507 }
508
509 if !self.write_task.is_finished() {
510 self.write_task.abort();
511 tracing::debug!("Aborted task 'write'");
512 }
513
514 if let Some(ref handle) = self.heartbeat_task.take() {
515 if !handle.is_finished() {
516 handle.abort();
517 tracing::debug!("Aborted task 'heartbeat'");
518 }
519 }
520 }
521}
522
/// TCP client handle.
///
/// Owns the controller task that supervises the connection, reconnecting or shutting down as
/// the connection mode changes. Also exposes the writer channel used to queue outbound
/// messages.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Task that owns the inner client and drives reconnects and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode` (stored as `u8`), the same atomic the inner tasks use.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender feeding the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
532
533impl SocketClient {
534 pub async fn connect(
540 config: SocketConfig,
541 handler: Option<Arc<TcpMessageHandler>>,
542 post_connection: Option<PyObject>,
543 post_reconnection: Option<PyObject>,
544 post_disconnection: Option<PyObject>,
545 ) -> anyhow::Result<Self> {
546 let inner = SocketClientInner::connect_url(config, handler).await?;
547 let writer_tx = inner.writer_tx.clone();
548 let connection_mode = inner.connection_mode.clone();
549
550 let controller_task = Self::spawn_controller_task(
551 inner,
552 connection_mode.clone(),
553 post_reconnection,
554 post_disconnection,
555 );
556
557 if let Some(handler) = post_connection {
558 Python::with_gil(|py| match handler.call0(py) {
559 Ok(_) => tracing::debug!("Called `post_connection` handler"),
560 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
561 });
562 }
563
564 Ok(Self {
565 controller_task,
566 connection_mode,
567 writer_tx,
568 })
569 }
570
571 #[must_use]
573 pub fn connection_mode(&self) -> ConnectionMode {
574 ConnectionMode::from_atomic(&self.connection_mode)
575 }
576
577 #[inline]
582 #[must_use]
583 pub fn is_active(&self) -> bool {
584 self.connection_mode().is_active()
585 }
586
587 #[inline]
592 #[must_use]
593 pub fn is_reconnecting(&self) -> bool {
594 self.connection_mode().is_reconnect()
595 }
596
597 #[inline]
601 #[must_use]
602 pub fn is_disconnecting(&self) -> bool {
603 self.connection_mode().is_disconnect()
604 }
605
606 #[inline]
612 #[must_use]
613 pub fn is_closed(&self) -> bool {
614 self.connection_mode().is_closed()
615 }
616
617 pub async fn close(&self) {
622 self.connection_mode
623 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
624
625 match tokio::time::timeout(Duration::from_secs(5), async {
626 while !self.is_closed() {
627 tokio::time::sleep(Duration::from_millis(10)).await;
628 }
629
630 if !self.controller_task.is_finished() {
631 self.controller_task.abort();
632 tracing::debug!("Aborted controller task");
633 }
634 })
635 .await
636 {
637 Ok(()) => {
638 tracing::debug!("Controller task finished");
639 }
640 Err(_) => {
641 tracing::error!("Timeout waiting for controller task to finish");
642 }
643 }
644 }
645
646 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), std::io::Error> {
652 if self.is_closed() {
653 return Err(std::io::Error::new(
654 std::io::ErrorKind::NotConnected,
655 "Not connected",
656 ));
657 }
658
659 let timeout = Duration::from_secs(2);
660 let check_interval = Duration::from_millis(1);
661
662 if !self.is_active() {
663 tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
664 match tokio::time::timeout(timeout, async {
665 while !self.is_active() {
666 if matches!(
667 self.connection_mode(),
668 ConnectionMode::Disconnect | ConnectionMode::Closed
669 ) {
670 return Err("Client disconnected waiting to send");
671 }
672
673 tokio::time::sleep(check_interval).await;
674 }
675
676 Ok(())
677 })
678 .await
679 {
680 Ok(Ok(())) => tracing::debug!("Client now active"),
681 Ok(Err(e)) => {
682 tracing::error!(
683 "Failed to send data ({}): {e}",
684 String::from_utf8_lossy(&data)
685 );
686 return Ok(());
687 }
688 Err(_) => {
689 tracing::error!(
690 "Failed to send data ({}): timeout waiting to become ACTIVE",
691 String::from_utf8_lossy(&data)
692 );
693 return Ok(());
694 }
695 }
696 }
697
698 let msg = WriterCommand::Send(data.into());
699 if let Err(e) = self.writer_tx.send(msg) {
700 tracing::error!("{e}");
701 }
702 Ok(())
703 }
704
705 fn spawn_controller_task(
706 mut inner: SocketClientInner,
707 connection_mode: Arc<AtomicU8>,
708 post_reconnection: Option<PyObject>,
709 post_disconnection: Option<PyObject>,
710 ) -> tokio::task::JoinHandle<()> {
711 tokio::task::spawn(async move {
712 tracing::debug!("Started task 'controller'");
713
714 let check_interval = Duration::from_millis(10);
715
716 loop {
717 tokio::time::sleep(check_interval).await;
718 let mode = ConnectionMode::from_atomic(&connection_mode);
719
720 if mode.is_disconnect() {
721 tracing::debug!("Disconnecting");
722
723 let timeout = Duration::from_secs(5);
724 if tokio::time::timeout(timeout, async {
725 tokio::time::sleep(Duration::from_millis(100)).await;
727
728 if !inner.read_task.is_finished() {
729 inner.read_task.abort();
730 tracing::debug!("Aborted task 'read'");
731 }
732
733 if let Some(task) = &inner.heartbeat_task {
734 if !task.is_finished() {
735 task.abort();
736 tracing::debug!("Aborted task 'heartbeat'");
737 }
738 }
739 })
740 .await
741 .is_err()
742 {
743 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
744 }
745
746 tracing::debug!("Closed");
747
748 if let Some(ref handler) = post_disconnection {
749 Python::with_gil(|py| match handler.call0(py) {
750 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
751 Err(e) => {
752 tracing::error!("Error calling `post_disconnection` handler: {e}");
753 }
754 });
755 }
756 break; }
758
759 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
760 match inner.reconnect().await {
761 Ok(()) => {
762 tracing::debug!("Reconnected successfully");
763 inner.backoff.reset();
764
765 if let Some(ref handler) = post_reconnection {
766 Python::with_gil(|py| match handler.call0(py) {
767 Ok(_) => {
768 tracing::debug!("Called `post_reconnection` handler");
769 }
770 Err(e) => tracing::error!(
771 "Error calling `post_reconnection` handler: {e}"
772 ),
773 });
774 }
775 }
776 Err(e) => {
777 let duration = inner.backoff.next_duration();
778 tracing::warn!("Reconnect attempt failed: {e}");
779 if !duration.is_zero() {
780 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
781 }
782 tokio::time::sleep(duration).await;
783 }
784 }
785 }
786 }
787 inner
788 .connection_mode
789 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
790
791 tracing::debug!("Completed task 'controller'");
792 })
793 }
794}
795
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::{
        ffi::CString,
        sync::atomic::{AtomicBool, Ordering},
    };

    use nautilus_common::testing::wait_until_async;
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Returns the bound `handler` method of a `Counter` defined in an inline Python module.
    fn create_handler() -> PyObject {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);

            counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py)
        })
    }

    /// Config for the local test server on `port`: plain TCP, CRLF framing, no heartbeat,
    /// default reconnect settings and no custom certificates.
    fn test_config(port: u16, py_handler: Option<Arc<PyObject>>) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    /// Binds a listener on an ephemeral localhost port and returns `(port, listener)`.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes CRLF-terminated lines back to the client; shuts down on a `close` line.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = test_config(port, Some(Arc::new(create_handler())));

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        // Free the port so nothing is listening on it.
        drop(listener);

        let config = test_config(port, Some(Arc::new(create_handler())));

        let client_res = SocketClient::connect(config, None, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = test_config(port, Some(Arc::new(create_handler())));

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        let config = SocketConfig {
            heartbeat: Some((1, b"ping".to_vec())),
            ..test_config(port, Some(Arc::new(create_handler())))
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line == &&b"ping".to_vec())
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let code_raw = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();

        let py_handler = Some(Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let func = pymod.getattr("handler").unwrap();
            Arc::new(func.into_py_any_unwrap(py))
        }));

        let config = test_config(port, py_handler);

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello".into()).await.unwrap();
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"ERR".into()).await.unwrap();
        sleep(Duration::from_secs(1)).await;

        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        // Set once the server accepts a second connection, proving the client reconnected.
        let reconnected = Arc::new(AtomicBool::new(false));
        let reconnected_server = reconnected.clone();

        let server_task = task::spawn(async move {
            // Accept the first connection, then shut it down to force a reconnect.
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            reconnected_server.store(true, Ordering::SeqCst);
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            ..test_config(port, Some(Arc::new(create_handler())))
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        wait_until_async(
            || async { reconnected.load(Ordering::SeqCst) && client.is_active() },
            Duration::from_secs(10),
        )
        .await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}