1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49 net::TcpStream,
50 sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53 MaybeTlsStream, WebSocketStream, connect_async,
54 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58 backoff::ExponentialBackoff,
59 error::SendError,
60 logging::{log_task_aborted, log_task_started, log_task_stopped},
61 mode::ConnectionMode,
62 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
/// Write half of a split (optionally TLS) WebSocket stream.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split (optionally TLS) WebSocket stream.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
/// Destination for messages received by the read task.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// Python callable invoked with each text/binary payload as `bytes`.
    /// When `None`, no read task is spawned for the connection.
    #[cfg(feature = "python")]
    Python(Option<Arc<PyObject>>),
    /// Channel sender that receives every successfully read message.
    Rust(Sender<Message>),
}
77
78impl Consumer {
79 #[must_use]
83 pub fn rust_consumer() -> (Self, Receiver<Message>) {
84 let (tx, rx) = mpsc::channel(100);
85 (Self::Rust(tx), rx)
86 }
87}
88
/// Configuration for a WebSocket connection and its reconnect behavior.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// URL to connect to (used for the initial connection and every reconnect).
    pub url: String,
    /// Extra `(name, value)` HTTP headers inserted into every handshake request.
    pub headers: Vec<(String, String)>,
    /// Where received messages are delivered.
    pub handler: Consumer,
    /// Heartbeat interval in seconds; `None` disables the heartbeat task.
    pub heartbeat: Option<u64>,
    /// Text payload for heartbeats; when `None` an empty ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Python callable invoked with the payload (as `bytes`) of each received ping.
    #[cfg(feature = "python")]
    pub ping_handler: Option<Arc<PyObject>>,
    /// Timeout for a single reconnect attempt in milliseconds (defaults to 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay in milliseconds (defaults to 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay in milliseconds (defaults to 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff factor passed to `ExponentialBackoff` (defaults to 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter in milliseconds passed to `ExponentialBackoff` (defaults to 100).
    pub reconnect_jitter_ms: Option<u64>,
}
119
/// Commands consumed by the write task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer (e.g. after a reconnect); the previous writer is closed.
    Update(MessageWriter),
    /// Send a message on the active writer.
    Send(Message),
}
128
/// Owns one WebSocket connection together with the background tasks servicing it.
struct WebSocketClientInner {
    /// Configuration used for the initial connection and every reconnect.
    config: WebSocketConfig,
    /// Task delivering received messages; `None` when the Python consumer has no handler.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task draining `writer_tx` and writing to the socket.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel feeding commands to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Optional heartbeat task (present only when a heartbeat interval is configured).
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state, stored as `ConnectionMode::as_u8()`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy between failed reconnect attempts (reset after a success).
    backoff: ExponentialBackoff,
}
154
155impl WebSocketClientInner {
156 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
158 install_cryptographic_provider();
159
160 #[allow(unused_variables)]
161 let WebSocketConfig {
162 url,
163 handler,
164 heartbeat,
165 headers,
166 heartbeat_msg,
167 #[cfg(feature = "python")]
168 ping_handler,
169 reconnect_timeout_ms,
170 reconnect_delay_initial_ms,
171 reconnect_delay_max_ms,
172 reconnect_backoff_factor,
173 reconnect_jitter_ms,
174 } = &config;
175 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
176
177 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
178
179 let read_task = match &handler {
180 #[cfg(feature = "python")]
181 Consumer::Python(handler) => handler.as_ref().map(|handler| {
182 Self::spawn_python_callback_task(
183 connection_mode.clone(),
184 reader,
185 handler.clone(),
186 ping_handler.clone(),
187 )
188 }),
189 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
190 connection_mode.clone(),
191 reader,
192 sender.clone(),
193 )),
194 };
195
196 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
197 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
198
199 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
201 Self::spawn_heartbeat_task(
202 connection_mode.clone(),
203 *heartbeat_secs,
204 heartbeat_msg.clone(),
205 writer_tx.clone(),
206 )
207 });
208
209 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
210 let backoff = ExponentialBackoff::new(
211 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
212 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
213 reconnect_backoff_factor.unwrap_or(1.5),
214 reconnect_jitter_ms.unwrap_or(100),
215 true, )
217 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
218
219 Ok(Self {
220 config,
221 read_task,
222 write_task,
223 writer_tx,
224 heartbeat_task,
225 connection_mode,
226 reconnect_timeout,
227 backoff,
228 })
229 }
230
231 #[inline]
233 pub async fn connect_with_server(
234 url: &str,
235 headers: Vec<(String, String)>,
236 ) -> Result<(MessageWriter, MessageReader), Error> {
237 let mut request = url.into_client_request()?;
238 let req_headers = request.headers_mut();
239
240 let mut header_names: Vec<HeaderName> = Vec::new();
241 for (key, val) in headers {
242 let header_value = HeaderValue::from_str(&val)?;
243 let header_name: HeaderName = key.parse()?;
244 header_names.push(header_name.clone());
245 req_headers.insert(header_name, header_value);
246 }
247
248 connect_async(request).await.map(|resp| resp.0.split())
249 }
250
251 pub async fn reconnect(&mut self) -> Result<(), Error> {
256 tracing::debug!("Reconnecting");
257
258 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
259 tracing::debug!("Reconnect aborted due to disconnect state");
260 return Ok(());
261 }
262
263 tokio::time::timeout(self.reconnect_timeout, async {
264 let (new_writer, reader) =
266 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
267
268 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
269 tracing::debug!("Reconnect aborted mid-flight (after connect)");
270 return Ok(());
271 }
272
273 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
274 tracing::error!("{e}");
275 }
276
277 tokio::time::sleep(Duration::from_millis(100)).await;
279
280 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
281 tracing::debug!("Reconnect aborted mid-flight (after delay)");
282 return Ok(());
283 }
284
285 if let Some(ref read_task) = self.read_task.take()
286 && !read_task.is_finished()
287 {
288 read_task.abort();
289 log_task_aborted("read");
290 }
291
292 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
294 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
295 return Ok(());
296 }
297
298 self.connection_mode
300 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
301
302 self.read_task = match &self.config.handler {
303 #[cfg(feature = "python")]
304 Consumer::Python(handler) => handler.as_ref().map(|handler| {
305 Self::spawn_python_callback_task(
306 self.connection_mode.clone(),
307 reader,
308 handler.clone(),
309 self.config.ping_handler.clone(),
310 )
311 }),
312 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
313 self.connection_mode.clone(),
314 reader,
315 sender.clone(),
316 )),
317 };
318
319 tracing::debug!("Reconnect succeeded");
320 Ok(())
321 })
322 .await
323 .map_err(|_| {
324 Error::Io(std::io::Error::new(
325 std::io::ErrorKind::TimedOut,
326 format!(
327 "reconnection timed out after {}s",
328 self.reconnect_timeout.as_secs_f64()
329 ),
330 ))
331 })?
332 }
333
334 #[inline]
342 #[must_use]
343 pub fn is_alive(&self) -> bool {
344 match &self.read_task {
345 Some(read_task) => !read_task.is_finished(),
346 None => true, }
348 }
349
350 fn spawn_rust_streaming_task(
351 connection_state: Arc<AtomicU8>,
352 mut reader: MessageReader,
353 sender: Sender<Message>,
354 ) -> tokio::task::JoinHandle<()> {
355 tracing::debug!("Started streaming task 'read'");
356
357 let check_interval = Duration::from_millis(10);
358
359 tokio::task::spawn(async move {
360 loop {
361 if !ConnectionMode::from_atomic(&connection_state).is_active() {
362 break;
363 }
364
365 match tokio::time::timeout(check_interval, reader.next()).await {
366 Ok(Some(Ok(message))) => {
367 if let Err(e) = sender.send(message).await {
368 tracing::error!("Failed to send message: {e}");
369 }
370 }
371 Ok(Some(Err(e))) => {
372 tracing::error!("Received error message - terminating: {e}");
373 break;
374 }
375 Ok(None) => {
376 tracing::debug!("No message received - terminating");
377 break;
378 }
379 Err(_) => {
380 continue;
382 }
383 }
384 }
385 })
386 }
387
388 #[cfg(feature = "python")]
389 fn spawn_python_callback_task(
390 connection_state: Arc<AtomicU8>,
391 mut reader: MessageReader,
392 handler: Arc<PyObject>,
393 ping_handler: Option<Arc<PyObject>>,
394 ) -> tokio::task::JoinHandle<()> {
395 log_task_started("read");
396
397 let check_interval = Duration::from_millis(10);
399
400 tokio::task::spawn(async move {
401 loop {
402 if !ConnectionMode::from_atomic(&connection_state).is_active() {
403 break;
404 }
405
406 match tokio::time::timeout(check_interval, reader.next()).await {
407 Ok(Some(Ok(Message::Binary(data)))) => {
408 tracing::trace!("Received message <binary> {} bytes", data.len());
409 if let Err(e) =
410 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
411 {
412 tracing::error!("Error calling handler: {e}");
413 break;
414 }
415 continue;
416 }
417 Ok(Some(Ok(Message::Text(data)))) => {
418 tracing::trace!("Received message: {data}");
419 if let Err(e) = Python::with_gil(|py| {
420 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
421 }) {
422 tracing::error!("Error calling handler: {e}");
423 break;
424 }
425 continue;
426 }
427 Ok(Some(Ok(Message::Ping(ping)))) => {
428 tracing::trace!("Received ping: {ping:?}");
429 if let Some(ref handler) = ping_handler
430 && let Err(e) =
431 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
432 {
433 tracing::error!("Error calling handler: {e}");
434 break;
435 }
436 continue;
437 }
438 Ok(Some(Ok(Message::Pong(_)))) => {
439 tracing::trace!("Received pong");
440 }
441 Ok(Some(Ok(Message::Close(_)))) => {
442 tracing::debug!("Received close message - terminating");
443 break;
444 }
445 Ok(Some(Ok(_))) => (),
446 Ok(Some(Err(e))) => {
447 tracing::error!("Received error message - terminating: {e}");
448 break;
449 }
450 Ok(None) => {
453 tracing::debug!("No message received - terminating");
454 break;
455 }
456 Err(_) => {
457 continue;
459 }
460 }
461 }
462 })
463 }
464
465 fn spawn_write_task(
466 connection_state: Arc<AtomicU8>,
467 writer: MessageWriter,
468 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
469 ) -> tokio::task::JoinHandle<()> {
470 log_task_started("write");
471
472 let check_interval = Duration::from_millis(10);
474
475 tokio::task::spawn(async move {
476 let mut active_writer = writer;
477
478 loop {
479 match ConnectionMode::from_atomic(&connection_state) {
480 ConnectionMode::Disconnect => {
481 _ = active_writer.close().await;
484 break;
485 }
486 ConnectionMode::Closed => break,
487 _ => {}
488 }
489
490 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
491 Ok(Some(msg)) => {
492 let mode = ConnectionMode::from_atomic(&connection_state);
494 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
495 break;
496 }
497
498 match msg {
499 WriterCommand::Update(new_writer) => {
500 tracing::debug!("Received new writer");
501
502 tokio::time::sleep(Duration::from_millis(100)).await;
504
505 _ = active_writer.close().await;
508
509 active_writer = new_writer;
510 tracing::debug!("Updated writer");
511 }
512 _ if mode.is_reconnect() => {
513 tracing::warn!("Skipping message while reconnecting, {msg:?}");
514 continue;
515 }
516 WriterCommand::Send(msg) => {
517 if let Err(e) = active_writer.send(msg).await {
518 tracing::error!("Failed to send message: {e}");
519 tracing::warn!("Writer triggering reconnect");
521 connection_state
522 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
523 }
524 }
525 }
526 }
527 Ok(None) => {
528 tracing::debug!("Writer channel closed, terminating writer task");
530 break;
531 }
532 Err(_) => {
533 continue;
535 }
536 }
537 }
538
539 _ = active_writer.close().await;
542
543 log_task_stopped("write");
544 })
545 }
546
547 fn spawn_heartbeat_task(
548 connection_state: Arc<AtomicU8>,
549 heartbeat_secs: u64,
550 message: Option<String>,
551 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
552 ) -> tokio::task::JoinHandle<()> {
553 log_task_started("heartbeat");
554
555 tokio::task::spawn(async move {
556 let interval = Duration::from_secs(heartbeat_secs);
557
558 loop {
559 tokio::time::sleep(interval).await;
560
561 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
562 ConnectionMode::Active => {
563 let msg = match &message {
564 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
565 None => WriterCommand::Send(Message::Ping(vec![].into())),
566 };
567
568 match writer_tx.send(msg) {
569 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
570 Err(e) => {
571 tracing::error!("Failed to send heartbeat to writer task: {e}");
572 }
573 }
574 }
575 ConnectionMode::Reconnect => continue,
576 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
577 }
578 }
579
580 log_task_stopped("heartbeat");
581 })
582 }
583}
584
585impl Drop for WebSocketClientInner {
586 fn drop(&mut self) {
587 if let Some(ref read_task) = self.read_task.take()
588 && !read_task.is_finished()
589 {
590 read_task.abort();
591 log_task_aborted("read");
592 }
593
594 if !self.write_task.is_finished() {
595 self.write_task.abort();
596 log_task_aborted("write");
597 }
598
599 if let Some(ref handle) = self.heartbeat_task.take()
600 && !handle.is_finished()
601 {
602 handle.abort();
603 log_task_aborted("heartbeat");
604 }
605 }
606}
607
/// WebSocket client whose connection, heartbeats and reconnects are supervised by a
/// background controller task; outgoing messages pass through a rate limiter.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task supervising reconnects and shutdown; finished once the client is disconnected.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection state, stored as `ConnectionMode::as_u8()`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Channel to the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited (per key) before each text/bytes send.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
622
623impl Debug for WebSocketClient {
624 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
625 f.debug_struct(stringify!(WebSocketClient)).finish()
626 }
627}
628
629impl WebSocketClient {
630 #[allow(clippy::too_many_arguments)]
636 pub async fn connect_stream(
637 config: WebSocketConfig,
638 keyed_quotas: Vec<(String, Quota)>,
639 default_quota: Option<Quota>,
640 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
641 ) -> Result<(MessageReader, Self), Error> {
642 install_cryptographic_provider();
643 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
644 let (writer, reader) = ws_stream.split();
645 let inner = WebSocketClientInner::connect_url(config).await?;
646
647 let connection_mode = inner.connection_mode.clone();
648
649 let writer_tx = inner.writer_tx.clone();
650 if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
651 tracing::error!("{e}");
652 }
653
654 let controller_task = Self::spawn_controller_task(
655 inner,
656 connection_mode.clone(),
657 post_reconnect,
658 #[cfg(feature = "python")]
659 None, #[cfg(feature = "python")]
661 None, );
663
664 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
665
666 Ok((
667 reader,
668 Self {
669 controller_task,
670 connection_mode,
671 writer_tx,
672 rate_limiter,
673 },
674 ))
675 }
676
677 pub async fn connect(
686 config: WebSocketConfig,
687 #[cfg(feature = "python")] post_connection: Option<PyObject>,
688 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
689 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
690 keyed_quotas: Vec<(String, Quota)>,
691 default_quota: Option<Quota>,
692 ) -> Result<Self, Error> {
693 tracing::debug!("Connecting");
694 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
695 let connection_mode = inner.connection_mode.clone();
696 let writer_tx = inner.writer_tx.clone();
697
698 let controller_task = Self::spawn_controller_task(
699 inner,
700 connection_mode.clone(),
701 None, #[cfg(feature = "python")]
703 post_reconnection, #[cfg(feature = "python")]
705 post_disconnection, );
707
708 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
709
710 #[cfg(feature = "python")]
711 if let Some(handler) = post_connection {
712 Python::with_gil(|py| match handler.call0(py) {
713 Ok(_) => tracing::debug!("Called `post_connection` handler"),
714 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
715 });
716 }
717
718 Ok(Self {
719 controller_task,
720 connection_mode,
721 writer_tx,
722 rate_limiter,
723 })
724 }
725
726 #[must_use]
728 pub fn connection_mode(&self) -> ConnectionMode {
729 ConnectionMode::from_atomic(&self.connection_mode)
730 }
731
732 #[inline]
737 #[must_use]
738 pub fn is_active(&self) -> bool {
739 self.connection_mode().is_active()
740 }
741
742 #[must_use]
744 pub fn is_disconnected(&self) -> bool {
745 self.controller_task.is_finished()
746 }
747
748 #[inline]
753 #[must_use]
754 pub fn is_reconnecting(&self) -> bool {
755 self.connection_mode().is_reconnect()
756 }
757
758 #[inline]
762 #[must_use]
763 pub fn is_disconnecting(&self) -> bool {
764 self.connection_mode().is_disconnect()
765 }
766
767 #[inline]
773 #[must_use]
774 pub fn is_closed(&self) -> bool {
775 self.connection_mode().is_closed()
776 }
777
778 pub async fn disconnect(&self) {
783 tracing::debug!("Disconnecting");
784 self.connection_mode
785 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
786
787 match tokio::time::timeout(Duration::from_secs(5), async {
788 while !self.is_disconnected() {
789 tokio::time::sleep(Duration::from_millis(10)).await;
790 }
791
792 if !self.controller_task.is_finished() {
793 self.controller_task.abort();
794 log_task_aborted("controller");
795 }
796 })
797 .await
798 {
799 Ok(()) => {
800 tracing::debug!("Controller task finished");
801 }
802 Err(_) => {
803 tracing::error!("Timeout waiting for controller task to finish");
804 }
805 }
806 }
807
808 #[allow(unused_variables)]
814 pub async fn send_text(
815 &self,
816 data: String,
817 keys: Option<Vec<String>>,
818 ) -> std::result::Result<(), SendError> {
819 self.rate_limiter.await_keys_ready(keys).await;
820
821 if !self.is_active() {
822 return Err(SendError::Closed);
823 }
824
825 tracing::trace!("Sending text: {data:?}");
826
827 let msg = Message::Text(data.into());
828 self.writer_tx
829 .send(WriterCommand::Send(msg))
830 .map_err(|e| SendError::BrokenPipe(e.to_string()))
831 }
832
833 #[allow(unused_variables)]
839 pub async fn send_bytes(
840 &self,
841 data: Vec<u8>,
842 keys: Option<Vec<String>>,
843 ) -> std::result::Result<(), SendError> {
844 self.rate_limiter.await_keys_ready(keys).await;
845
846 if !self.is_active() {
847 return Err(SendError::Closed);
848 }
849
850 tracing::trace!("Sending bytes: {data:?}");
851
852 let msg = Message::Binary(data.into());
853 self.writer_tx
854 .send(WriterCommand::Send(msg))
855 .map_err(|e| SendError::BrokenPipe(e.to_string()))
856 }
857
858 pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
864 if !self.is_active() {
865 return Err(SendError::Closed);
866 }
867
868 let msg = Message::Close(None);
869 self.writer_tx
870 .send(WriterCommand::Send(msg))
871 .map_err(|e| SendError::BrokenPipe(e.to_string()))
872 }
873
874 fn spawn_controller_task(
875 mut inner: WebSocketClientInner,
876 connection_mode: Arc<AtomicU8>,
877 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
878 #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, ) -> tokio::task::JoinHandle<()> {
881 tokio::task::spawn(async move {
882 log_task_started("controller");
883
884 let check_interval = Duration::from_millis(10);
885
886 loop {
887 tokio::time::sleep(check_interval).await;
888 let mode = ConnectionMode::from_atomic(&connection_mode);
889
890 if mode.is_disconnect() {
891 tracing::debug!("Disconnecting");
892
893 let timeout = Duration::from_secs(5);
894 if tokio::time::timeout(timeout, async {
895 tokio::time::sleep(Duration::from_millis(100)).await;
897
898 if let Some(task) = &inner.read_task
899 && !task.is_finished()
900 {
901 task.abort();
902 log_task_aborted("read");
903 }
904
905 if let Some(task) = &inner.heartbeat_task
906 && !task.is_finished()
907 {
908 task.abort();
909 log_task_aborted("heartbeat");
910 }
911 })
912 .await
913 .is_err()
914 {
915 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
916 }
917
918 tracing::debug!("Closed");
919
920 #[cfg(feature = "python")]
921 if let Some(ref handler) = py_post_disconnection {
922 Python::with_gil(|py| match handler.call0(py) {
923 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
924 Err(e) => {
925 tracing::error!("Error calling `post_disconnection` handler: {e}");
926 }
927 });
928 }
929 break; }
931
932 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
933 match inner.reconnect().await {
934 Ok(()) => {
935 inner.backoff.reset();
936
937 if ConnectionMode::from_atomic(&connection_mode).is_active() {
939 if let Some(ref callback) = post_reconnection {
940 callback();
941 }
942
943 #[cfg(feature = "python")]
945 if let Some(ref callback) = py_post_reconnection {
946 Python::with_gil(|py| match callback.call0(py) {
947 Ok(_) => {
948 tracing::debug!("Called `post_reconnection` handler");
949 }
950 Err(e) => tracing::error!(
951 "Error calling `post_reconnection` handler: {e}"
952 ),
953 });
954 }
955
956 tracing::debug!("Reconnected successfully");
957 } else {
958 tracing::debug!(
959 "Skipping post_reconnection handlers due to disconnect state"
960 );
961 }
962 }
963 Err(e) => {
964 let duration = inner.backoff.next_duration();
965 tracing::warn!("Reconnect attempt failed: {e}");
966 if !duration.is_zero() {
967 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
968 }
969 tokio::time::sleep(duration).await;
970 }
971 }
972 }
973 }
974 inner
975 .connection_mode
976 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
977
978 log_task_stopped("controller");
979 })
980 }
981}
982
983impl Drop for WebSocketClient {
985 fn drop(&mut self) {
986 if !self.controller_task.is_finished() {
987 self.controller_task.abort();
988 log_task_aborted("controller");
989 }
990 }
991}
992
#[cfg(feature = "python")]
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{Consumer, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server used by the tests; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected `{}` header",
                self.key
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and accepts connections forever. Each
        /// connection echoes text/binary frames, and closes on the text `close-now` or on
        /// a close frame.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Config pointing at a local server on `port`, sending the `test: test` header the
    /// [`TestCallback`] expects.
    fn test_config(port: u16, heartbeat: Option<u64>) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    async fn setup_test_client_with_heartbeat(
        port: u16,
        heartbeat: Option<u64>,
    ) -> WebSocketClient {
        WebSocketClient::connect(test_config(port, heartbeat), None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        setup_test_client_with_heartbeat(port, None).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        // 1s heartbeat so several heartbeat pings are sent during the sleep below
        let client = setup_test_client_with_heartbeat(server.port, Some(1)).await;

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Heartbeats must not have torn the connection down
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve a free port, then release it so nothing is listening there
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port, None)
        };
        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to close its side of the connection
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // The controller task must still be running after a server-side close
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port, None),
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();
        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1261
/// Tests exercising the client with a Rust channel consumer (no Python required).
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// A server-side drop should push the client towards reconnecting; a subsequent
    /// `disconnect()` must still shut the controller task down.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // Ephemeral port so the test does not collide with other listeners
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept one handshake, drop the socket immediately, then idle without accepting
        // again (the listener stays alive until the sleep ends)
        let server = task::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            drop(ws);
            sleep(Duration::from_secs(1)).await;
        });

        // `_rx` is kept alive so the consumer channel stays open for the test's duration
        let (consumer, _rx) = Consumer::rust_consumer();

        // Short reconnect timings keep the test fast
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler: consumer,
            heartbeat: None,
            heartbeat_msg: None,
            #[cfg(feature = "python")]
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        // `connect` takes extra Python callback parameters when the feature is enabled
        let client = {
            #[cfg(feature = "python")]
            {
                WebSocketClient::connect(config.clone(), None, None, None, vec![], None)
                    .await
                    .unwrap()
            }
            #[cfg(not(feature = "python"))]
            {
                WebSocketClient::connect(config.clone(), vec![], None)
                    .await
                    .unwrap()
            }
        };

        // Let the client observe the dropped connection before disconnecting
        sleep(Duration::from_millis(100)).await;
        client.disconnect().await;
        assert!(client.is_disconnected());
        server.abort();
    }
}
1330}