1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_core::CleanDrop;
46use nautilus_cryptography::providers::install_cryptographic_provider;
47#[cfg(feature = "turmoil")]
48use tokio_tungstenite::client_async;
49#[cfg(not(feature = "turmoil"))]
50use tokio_tungstenite::connect_async;
51use tokio_tungstenite::{
52 MaybeTlsStream, WebSocketStream,
53 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
54};
55
56#[cfg(feature = "turmoil")]
57use crate::net::TcpConnector;
58use crate::{
59 RECONNECTED,
60 backoff::ExponentialBackoff,
61 error::SendError,
62 logging::{log_task_aborted, log_task_started, log_task_stopped},
63 mode::ConnectionMode,
64 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
65};
66
/// Text payload `"ping"`. Not referenced in this part of the file; presumably used by
/// callers for text-based keep-alives (NOTE(review): confirm against callers).
pub const TEXT_PING: &str = "ping";
/// Text payload `"pong"`. Presumably the text reply paired with [`TEXT_PING`]
/// (NOTE(review): confirm against callers).
pub const TEXT_PONG: &str = "pong";

/// Interval (ms) at which the read, write and controller tasks, and `disconnect`,
/// re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) taken during reconnect and controller shutdown before tasks are torn down.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) for closing a writer and for waiting on the controller to shut down.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Interval (ms) at which a pending send polls for the client to become ACTIVE.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

/// Write half of a split WebSocket stream over a plain TCP or TLS connection.
#[cfg(not(feature = "turmoil"))]
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, Message>;
/// Read half of a split WebSocket stream; handed to callers of `connect_stream`.
#[cfg(not(feature = "turmoil"))]
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>;

/// Write half of a split WebSocket stream over the crate's (turmoil-capable) TCP stream.
#[cfg(feature = "turmoil")]
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<crate::net::TcpStream>>, Message>;
/// Read half of a split WebSocket stream over the crate's (turmoil-capable) TCP stream.
#[cfg(feature = "turmoil")]
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<crate::net::TcpStream>>>;

/// Callback invoked by the read task for each received text or binary message.
pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;

/// Callback invoked by the read task with the payload bytes of each received ping frame.
pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
99
100#[must_use]
104pub fn channel_message_handler() -> (
105 MessageHandler,
106 tokio::sync::mpsc::UnboundedReceiver<Message>,
107) {
108 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
109 let handler = Arc::new(move |msg: Message| {
110 if let Err(e) = tx.send(msg) {
111 tracing::debug!("Failed to send message to channel: {e}");
112 }
113 });
114 (handler, rx)
115}
116
/// Configuration for a [`WebSocketClient`] connection.
///
/// Unset reconnect fields fall back to the defaults applied by `WebSocketClientInner`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// URL to connect (and reconnect) to.
    pub url: String,
    /// Extra HTTP headers inserted into the handshake request. A header name that already
    /// exists on the request is overwritten, not appended.
    pub headers: Vec<(String, String)>,
    /// Callback for received text/binary messages. When `None`, no read task is spawned.
    /// After a successful reconnect it also receives a `RECONNECTED` text message.
    pub message_handler: Option<MessageHandler>,
    /// Heartbeat interval in seconds. When `Some`, a heartbeat is sent each interval while ACTIVE.
    pub heartbeat: Option<u64>,
    /// Text payload sent as the heartbeat. When `None`, an empty ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Callback invoked with the payload of each received ping frame.
    pub ping_handler: Option<PingHandler>,
    /// Upper bound for a single reconnect attempt, in milliseconds (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial delay between failed reconnect attempts, in milliseconds (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum delay between failed reconnect attempts, in milliseconds (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff factor passed to the reconnect backoff (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to the reconnect backoff, in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
179
180impl Debug for WebSocketConfig {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct(stringify!(WebSocketConfig))
183 .field("url", &self.url)
184 .field("headers", &self.headers)
185 .field(
186 "message_handler",
187 &self.message_handler.as_ref().map(|_| "<function>"),
188 )
189 .field("heartbeat", &self.heartbeat)
190 .field("heartbeat_msg", &self.heartbeat_msg)
191 .field(
192 "ping_handler",
193 &self.ping_handler.as_ref().map(|_| "<function>"),
194 )
195 .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
196 .field(
197 "reconnect_delay_initial_ms",
198 &self.reconnect_delay_initial_ms,
199 )
200 .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
201 .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
202 .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
203 .finish()
204 }
205}
206
207impl Clone for WebSocketConfig {
208 fn clone(&self) -> Self {
209 Self {
210 url: self.url.clone(),
211 headers: self.headers.clone(),
212 message_handler: self.message_handler.clone(),
213 heartbeat: self.heartbeat,
214 heartbeat_msg: self.heartbeat_msg.clone(),
215 ping_handler: self.ping_handler.clone(),
216 reconnect_timeout_ms: self.reconnect_timeout_ms,
217 reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
218 reconnect_delay_max_ms: self.reconnect_delay_max_ms,
219 reconnect_backoff_factor: self.reconnect_backoff_factor,
220 reconnect_jitter_ms: self.reconnect_jitter_ms,
221 }
222 }
223}
224
/// Commands accepted by the write task through its unbounded channel.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer (sent by `reconnect` once a new connection is open);
    /// the write task closes the previous writer before switching.
    Update(MessageWriter),
    /// Send a message on the active writer (skipped while the mode is RECONNECT).
    Send(Message),
}
233
/// State owned by the controller task for one WebSocket connection: its configuration,
/// the spawned I/O tasks and the reconnect policy.
struct WebSocketClientInner {
    /// Configuration the connection was created from; reused on every reconnect.
    config: WebSocketConfig,
    /// Task feeding received frames to `config.message_handler`. `None` when no handler
    /// is configured or in stream mode.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task that owns the active writer and drains `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel into the write task (outgoing messages and writer swaps).
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task; present when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode` (stored via `as_u8`), observed by every task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// `true` when the caller owns the reader (created via `new_with_writer`); auto-reconnect
    /// is then disabled.
    is_stream_mode: bool,
}
263
264impl WebSocketClientInner {
265 pub async fn new_with_writer(
267 config: WebSocketConfig,
268 writer: MessageWriter,
269 ) -> Result<Self, Error> {
270 install_cryptographic_provider();
271
272 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
273
274 let read_task = None;
276
277 let backoff = ExponentialBackoff::new(
278 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
279 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
280 config.reconnect_backoff_factor.unwrap_or(1.5),
281 config.reconnect_jitter_ms.unwrap_or(100),
282 true, )
284 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
285
286 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
287 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
288
289 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
290 Some(Self::spawn_heartbeat_task(
291 connection_mode.clone(),
292 heartbeat_interval,
293 config.heartbeat_msg.clone(),
294 writer_tx.clone(),
295 ))
296 } else {
297 None
298 };
299
300 Ok(Self {
301 config: config.clone(),
302 writer_tx,
303 connection_mode,
304 reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
305 heartbeat_task,
306 read_task,
307 write_task,
308 backoff,
309 is_stream_mode: true,
310 })
311 }
312
313 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
315 install_cryptographic_provider();
316
317 let WebSocketConfig {
318 url,
319 message_handler,
320 heartbeat,
321 headers,
322 heartbeat_msg,
323 ping_handler,
324 reconnect_timeout_ms,
325 reconnect_delay_initial_ms,
326 reconnect_delay_max_ms,
327 reconnect_backoff_factor,
328 reconnect_jitter_ms,
329 } = &config;
330 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
331
332 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
333
334 let read_task = if message_handler.is_some() {
335 Some(Self::spawn_message_handler_task(
336 connection_mode.clone(),
337 reader,
338 message_handler.as_ref(),
339 ping_handler.as_ref(),
340 ))
341 } else {
342 None
343 };
344
345 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
346 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
347
348 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
350 Self::spawn_heartbeat_task(
351 connection_mode.clone(),
352 *heartbeat_secs,
353 heartbeat_msg.clone(),
354 writer_tx.clone(),
355 )
356 });
357
358 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
359 let backoff = ExponentialBackoff::new(
360 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
361 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
362 reconnect_backoff_factor.unwrap_or(1.5),
363 reconnect_jitter_ms.unwrap_or(100),
364 true, )
366 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
367
368 Ok(Self {
369 config,
370 read_task,
371 write_task,
372 writer_tx,
373 heartbeat_task,
374 connection_mode,
375 reconnect_timeout,
376 backoff,
377 is_stream_mode: false,
378 })
379 }
380
381 #[inline]
384 #[cfg(not(feature = "turmoil"))]
385 pub async fn connect_with_server(
386 url: &str,
387 headers: Vec<(String, String)>,
388 ) -> Result<(MessageWriter, MessageReader), Error> {
389 let mut request = url.into_client_request()?;
390 let req_headers = request.headers_mut();
391
392 let mut header_names: Vec<HeaderName> = Vec::new();
393 for (key, val) in headers {
394 let header_value = HeaderValue::from_str(&val)?;
395 let header_name: HeaderName = key.parse()?;
396 header_names.push(header_name.clone());
397 req_headers.insert(header_name, header_value);
398 }
399
400 connect_async(request).await.map(|resp| resp.0.split())
401 }
402
403 #[inline]
406 #[cfg(feature = "turmoil")]
407 pub async fn connect_with_server(
408 url: &str,
409 headers: Vec<(String, String)>,
410 ) -> Result<(MessageWriter, MessageReader), Error> {
411 use rustls::ClientConfig;
412 use tokio_rustls::TlsConnector;
413
414 let mut request = url.into_client_request()?;
415 let req_headers = request.headers_mut();
416
417 let mut header_names: Vec<HeaderName> = Vec::new();
418 for (key, val) in headers {
419 let header_value = HeaderValue::from_str(&val)?;
420 let header_name: HeaderName = key.parse()?;
421 header_names.push(header_name.clone());
422 req_headers.insert(header_name, header_value);
423 }
424
425 let uri = request.uri();
426 let scheme = uri.scheme_str().unwrap_or("ws");
427 let host = uri.host().ok_or_else(|| {
428 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
429 })?;
430
431 let port = uri
433 .port_u16()
434 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
435
436 let addr = format!("{host}:{port}");
437
438 let connector = crate::net::RealTcpConnector;
440 let tcp_stream = connector.connect(&addr).await?;
441
442 let maybe_tls_stream = if scheme == "wss" {
444 let mut root_store = rustls::RootCertStore::empty();
446 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
447
448 let config = ClientConfig::builder()
449 .with_root_certificates(root_store)
450 .with_no_client_auth();
451
452 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
453 let domain =
454 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
455 Error::Io(std::io::Error::new(
456 std::io::ErrorKind::InvalidInput,
457 format!("Invalid DNS name: {e}"),
458 ))
459 })?;
460
461 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
462 MaybeTlsStream::Rustls(tls_stream)
463 } else {
464 MaybeTlsStream::Plain(tcp_stream)
465 };
466
467 client_async(request, maybe_tls_stream)
469 .await
470 .map(|resp| resp.0.split())
471 }
472
473 pub async fn reconnect(&mut self) -> Result<(), Error> {
482 tracing::debug!("Reconnecting");
483
484 if self.is_stream_mode {
485 tracing::warn!(
486 "Auto-reconnect disabled for stream-based WebSocket client; \
487 stream users must manually reconnect by creating a new connection"
488 );
489 self.connection_mode
491 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
492 return Ok(());
493 }
494
495 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
496 tracing::debug!("Reconnect aborted due to disconnect state");
497 return Ok(());
498 }
499
500 tokio::time::timeout(self.reconnect_timeout, async {
501 let (new_writer, reader) =
503 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
504
505 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
506 tracing::debug!("Reconnect aborted mid-flight (after connect)");
507 return Ok(());
508 }
509
510 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
511 tracing::error!("{e}");
512 }
513
514 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
516
517 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
518 tracing::debug!("Reconnect aborted mid-flight (after delay)");
519 return Ok(());
520 }
521
522 if let Some(ref read_task) = self.read_task.take()
523 && !read_task.is_finished()
524 {
525 read_task.abort();
526 log_task_aborted("read");
527 }
528
529 if self
532 .connection_mode
533 .compare_exchange(
534 ConnectionMode::Reconnect.as_u8(),
535 ConnectionMode::Active.as_u8(),
536 Ordering::SeqCst,
537 Ordering::SeqCst,
538 )
539 .is_err()
540 {
541 tracing::debug!("Reconnect aborted (state changed during reconnect)");
542 return Ok(());
543 }
544
545 self.read_task = if self.config.message_handler.is_some() {
546 Some(Self::spawn_message_handler_task(
547 self.connection_mode.clone(),
548 reader,
549 self.config.message_handler.as_ref(),
550 self.config.ping_handler.as_ref(),
551 ))
552 } else {
553 None
554 };
555
556 tracing::debug!("Reconnect succeeded");
557 Ok(())
558 })
559 .await
560 .map_err(|_| {
561 Error::Io(std::io::Error::new(
562 std::io::ErrorKind::TimedOut,
563 format!(
564 "reconnection timed out after {}s",
565 self.reconnect_timeout.as_secs_f64()
566 ),
567 ))
568 })?
569 }
570
571 #[inline]
579 #[must_use]
580 pub fn is_alive(&self) -> bool {
581 match &self.read_task {
582 Some(read_task) => !read_task.is_finished(),
583 None => true, }
585 }
586
587 fn spawn_message_handler_task(
588 connection_state: Arc<AtomicU8>,
589 mut reader: MessageReader,
590 message_handler: Option<&MessageHandler>,
591 ping_handler: Option<&PingHandler>,
592 ) -> tokio::task::JoinHandle<()> {
593 tracing::debug!("Started message handler task 'read'");
594
595 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
596
597 let message_handler = message_handler.cloned();
599 let ping_handler = ping_handler.cloned();
600
601 tokio::task::spawn(async move {
602 loop {
603 if !ConnectionMode::from_atomic(&connection_state).is_active() {
604 break;
605 }
606
607 match tokio::time::timeout(check_interval, reader.next()).await {
608 Ok(Some(Ok(Message::Binary(data)))) => {
609 tracing::trace!("Received message <binary> {} bytes", data.len());
610 if let Some(ref handler) = message_handler {
611 handler(Message::Binary(data));
612 }
613 }
614 Ok(Some(Ok(Message::Text(data)))) => {
615 tracing::trace!("Received message: {data}");
616 if let Some(ref handler) = message_handler {
617 handler(Message::Text(data));
618 }
619 }
620 Ok(Some(Ok(Message::Ping(ping_data)))) => {
621 tracing::trace!("Received ping: {ping_data:?}");
622 if let Some(ref handler) = ping_handler {
623 handler(ping_data.to_vec());
624 }
625 }
626 Ok(Some(Ok(Message::Pong(_)))) => {
627 tracing::trace!("Received pong");
628 }
629 Ok(Some(Ok(Message::Close(_)))) => {
630 tracing::debug!("Received close message - terminating");
631 break;
632 }
633 Ok(Some(Ok(_))) => (),
634 Ok(Some(Err(e))) => {
635 tracing::error!("Received error message - terminating: {e}");
636 break;
637 }
638 Ok(None) => {
639 tracing::debug!("No message received - terminating");
640 break;
641 }
642 Err(_) => {
643 continue;
645 }
646 }
647 }
648 })
649 }
650
651 fn spawn_write_task(
652 connection_state: Arc<AtomicU8>,
653 writer: MessageWriter,
654 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
655 ) -> tokio::task::JoinHandle<()> {
656 log_task_started("write");
657
658 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
660
661 tokio::task::spawn(async move {
662 let mut active_writer = writer;
663
664 loop {
665 match ConnectionMode::from_atomic(&connection_state) {
666 ConnectionMode::Disconnect => {
667 _ = tokio::time::timeout(
670 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
671 active_writer.close(),
672 )
673 .await;
674 break;
675 }
676 ConnectionMode::Closed => break,
677 _ => {}
678 }
679
680 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
681 Ok(Some(msg)) => {
682 let mode = ConnectionMode::from_atomic(&connection_state);
684 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
685 break;
686 }
687
688 match msg {
689 WriterCommand::Update(new_writer) => {
690 tracing::debug!("Received new writer");
691
692 tokio::time::sleep(Duration::from_millis(100)).await;
694
695 _ = tokio::time::timeout(
698 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
699 active_writer.close(),
700 )
701 .await;
702
703 active_writer = new_writer;
704 tracing::debug!("Updated writer");
705 }
706 _ if mode.is_reconnect() => {
707 tracing::warn!("Skipping message while reconnecting, {msg:?}");
708 continue;
709 }
710 WriterCommand::Send(msg) => {
711 if let Err(e) = active_writer.send(msg).await {
712 tracing::error!("Failed to send message: {e}");
713 tracing::warn!("Writer triggering reconnect");
715 connection_state
716 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
717 }
718 }
719 }
720 }
721 Ok(None) => {
722 tracing::debug!("Writer channel closed, terminating writer task");
724 break;
725 }
726 Err(_) => {
727 continue;
729 }
730 }
731 }
732
733 _ = tokio::time::timeout(
736 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
737 active_writer.close(),
738 )
739 .await;
740
741 log_task_stopped("write");
742 })
743 }
744
745 fn spawn_heartbeat_task(
746 connection_state: Arc<AtomicU8>,
747 heartbeat_secs: u64,
748 message: Option<String>,
749 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
750 ) -> tokio::task::JoinHandle<()> {
751 log_task_started("heartbeat");
752
753 tokio::task::spawn(async move {
754 let interval = Duration::from_secs(heartbeat_secs);
755
756 loop {
757 tokio::time::sleep(interval).await;
758
759 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
760 ConnectionMode::Active => {
761 let msg = match &message {
762 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
763 None => WriterCommand::Send(Message::Ping(vec![].into())),
764 };
765
766 match writer_tx.send(msg) {
767 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
768 Err(e) => {
769 tracing::error!("Failed to send heartbeat to writer task: {e}");
770 }
771 }
772 }
773 ConnectionMode::Reconnect => continue,
774 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
775 }
776 }
777
778 log_task_stopped("heartbeat");
779 })
780 }
781}
782
783impl Drop for WebSocketClientInner {
784 fn drop(&mut self) {
785 self.clean_drop();
787 }
788}
789
790impl CleanDrop for WebSocketClientInner {
792 fn clean_drop(&mut self) {
793 if let Some(ref read_task) = self.read_task.take()
794 && !read_task.is_finished()
795 {
796 read_task.abort();
797 log_task_aborted("read");
798 }
799
800 if !self.write_task.is_finished() {
801 self.write_task.abort();
802 log_task_aborted("write");
803 }
804
805 if let Some(ref handle) = self.heartbeat_task.take()
806 && !handle.is_finished()
807 {
808 handle.abort();
809 log_task_aborted("heartbeat");
810 }
811
812 self.config.message_handler = None;
814 self.config.ping_handler = None;
815 }
816}
817
/// Handle to a managed WebSocket connection.
///
/// A background controller task supervises the connection: it detects a dead read task,
/// drives reconnects and performs shutdown. Sends are rate limited and queued to a writer task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Background task supervising reconnects and shutdown; finished once disconnected.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode` (stored via `as_u8`), also observed by the background tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// How long a send waits for the client to become ACTIVE again.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text` / `send_bytes` for the given keys.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Channel into the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
833
834impl Debug for WebSocketClient {
835 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 f.debug_struct(stringify!(WebSocketClient)).finish()
837 }
838}
839
840impl WebSocketClient {
841 #[allow(clippy::too_many_arguments)]
857 pub async fn connect_stream(
858 config: WebSocketConfig,
859 keyed_quotas: Vec<(String, Quota)>,
860 default_quota: Option<Quota>,
861 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
862 ) -> Result<(MessageReader, Self), Error> {
863 install_cryptographic_provider();
864
865 let (writer, reader) =
867 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
868
869 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
871
872 let connection_mode = inner.connection_mode.clone();
873 let reconnect_timeout = inner.reconnect_timeout;
874 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
875 let writer_tx = inner.writer_tx.clone();
876
877 let controller_task =
878 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
879
880 Ok((
881 reader,
882 Self {
883 controller_task,
884 connection_mode,
885 reconnect_timeout,
886 rate_limiter,
887 writer_tx,
888 },
889 ))
890 }
891
892 pub async fn connect(
909 config: WebSocketConfig,
910 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
911 keyed_quotas: Vec<(String, Quota)>,
912 default_quota: Option<Quota>,
913 ) -> Result<Self, Error> {
914 tracing::debug!("Connecting");
915 let inner = WebSocketClientInner::connect_url(config).await?;
916 let connection_mode = inner.connection_mode.clone();
917 let writer_tx = inner.writer_tx.clone();
918 let reconnect_timeout = inner.reconnect_timeout;
919
920 let controller_task =
921 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
922
923 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
924
925 Ok(Self {
926 controller_task,
927 connection_mode,
928 reconnect_timeout,
929 rate_limiter,
930 writer_tx,
931 })
932 }
933
934 #[must_use]
936 pub fn connection_mode(&self) -> ConnectionMode {
937 ConnectionMode::from_atomic(&self.connection_mode)
938 }
939
940 #[inline]
945 #[must_use]
946 pub fn is_active(&self) -> bool {
947 self.connection_mode().is_active()
948 }
949
950 #[must_use]
952 pub fn is_disconnected(&self) -> bool {
953 self.controller_task.is_finished()
954 }
955
956 #[inline]
961 #[must_use]
962 pub fn is_reconnecting(&self) -> bool {
963 self.connection_mode().is_reconnect()
964 }
965
966 #[inline]
970 #[must_use]
971 pub fn is_disconnecting(&self) -> bool {
972 self.connection_mode().is_disconnect()
973 }
974
975 #[inline]
981 #[must_use]
982 pub fn is_closed(&self) -> bool {
983 self.connection_mode().is_closed()
984 }
985
986 async fn wait_for_active(&self) -> Result<(), SendError> {
990 if self.is_closed() {
991 return Err(SendError::Closed);
992 }
993
994 let timeout = self.reconnect_timeout;
995 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
996
997 if !self.is_active() {
998 tracing::debug!("Waiting for client to become ACTIVE before sending...");
999
1000 let inner = tokio::time::timeout(timeout, async {
1001 loop {
1002 if self.is_active() {
1003 return Ok(());
1004 }
1005 if matches!(
1006 self.connection_mode(),
1007 ConnectionMode::Disconnect | ConnectionMode::Closed
1008 ) {
1009 return Err(());
1010 }
1011 tokio::time::sleep(check_interval).await;
1012 }
1013 })
1014 .await
1015 .map_err(|_| SendError::Timeout)?;
1016 inner.map_err(|()| SendError::Closed)?;
1017 }
1018
1019 Ok(())
1020 }
1021
1022 pub async fn disconnect(&self) {
1027 tracing::debug!("Disconnecting");
1028 self.connection_mode
1029 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1030
1031 if let Ok(()) =
1032 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1033 while !self.is_disconnected() {
1034 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1035 .await;
1036 }
1037
1038 if !self.controller_task.is_finished() {
1039 self.controller_task.abort();
1040 log_task_aborted("controller");
1041 }
1042 })
1043 .await
1044 {
1045 tracing::debug!("Controller task finished");
1046 } else {
1047 tracing::error!("Timeout waiting for controller task to finish");
1048 if !self.controller_task.is_finished() {
1049 self.controller_task.abort();
1050 log_task_aborted("controller");
1051 }
1052 }
1053 }
1054
1055 #[allow(unused_variables)]
1061 pub async fn send_text(
1062 &self,
1063 data: String,
1064 keys: Option<Vec<String>>,
1065 ) -> Result<(), SendError> {
1066 self.rate_limiter.await_keys_ready(keys).await;
1067 self.wait_for_active().await?;
1068
1069 tracing::trace!("Sending text: {data:?}");
1070
1071 let msg = Message::Text(data.into());
1072 self.writer_tx
1073 .send(WriterCommand::Send(msg))
1074 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1075 }
1076
1077 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1083 self.wait_for_active().await?;
1084
1085 tracing::trace!("Sending pong frame ({} bytes)", data.len());
1086
1087 let msg = Message::Pong(data.into());
1088 self.writer_tx
1089 .send(WriterCommand::Send(msg))
1090 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1091 }
1092
1093 #[allow(unused_variables)]
1099 pub async fn send_bytes(
1100 &self,
1101 data: Vec<u8>,
1102 keys: Option<Vec<String>>,
1103 ) -> Result<(), SendError> {
1104 self.rate_limiter.await_keys_ready(keys).await;
1105 self.wait_for_active().await?;
1106
1107 tracing::trace!("Sending bytes: {data:?}");
1108
1109 let msg = Message::Binary(data.into());
1110 self.writer_tx
1111 .send(WriterCommand::Send(msg))
1112 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1113 }
1114
1115 pub async fn send_close_message(&self) -> Result<(), SendError> {
1121 self.wait_for_active().await?;
1122
1123 let msg = Message::Close(None);
1124 self.writer_tx
1125 .send(WriterCommand::Send(msg))
1126 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1127 }
1128
1129 fn spawn_controller_task(
1130 mut inner: WebSocketClientInner,
1131 connection_mode: Arc<AtomicU8>,
1132 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1133 ) -> tokio::task::JoinHandle<()> {
1134 tokio::task::spawn(async move {
1135 log_task_started("controller");
1136
1137 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1138
1139 loop {
1140 tokio::time::sleep(check_interval).await;
1141 let mut mode = ConnectionMode::from_atomic(&connection_mode);
1142
1143 if mode.is_disconnect() {
1144 tracing::debug!("Disconnecting");
1145
1146 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1147 if tokio::time::timeout(timeout, async {
1148 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1150
1151 if let Some(task) = &inner.read_task
1152 && !task.is_finished()
1153 {
1154 task.abort();
1155 log_task_aborted("read");
1156 }
1157
1158 if let Some(task) = &inner.heartbeat_task
1159 && !task.is_finished()
1160 {
1161 task.abort();
1162 log_task_aborted("heartbeat");
1163 }
1164 })
1165 .await
1166 .is_err()
1167 {
1168 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1169 }
1170
1171 tracing::debug!("Closed");
1172 break; }
1174
1175 if mode.is_active() && !inner.is_alive() {
1176 if connection_mode
1177 .compare_exchange(
1178 ConnectionMode::Active.as_u8(),
1179 ConnectionMode::Reconnect.as_u8(),
1180 Ordering::SeqCst,
1181 Ordering::SeqCst,
1182 )
1183 .is_ok()
1184 {
1185 tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1186 }
1187 mode = ConnectionMode::from_atomic(&connection_mode);
1188 }
1189
1190 if mode.is_reconnect() {
1191 match inner.reconnect().await {
1192 Ok(()) => {
1193 inner.backoff.reset();
1194
1195 if ConnectionMode::from_atomic(&connection_mode).is_active() {
1197 if let Some(ref handler) = inner.config.message_handler {
1198 let reconnected_msg =
1199 Message::Text(RECONNECTED.to_string().into());
1200 handler(reconnected_msg);
1201 tracing::debug!("Sent reconnected message to handler");
1202 }
1203
1204 if let Some(ref callback) = post_reconnection {
1206 callback();
1207 tracing::debug!("Called `post_reconnection` handler");
1208 }
1209
1210 tracing::debug!("Reconnected successfully");
1211 } else {
1212 tracing::debug!(
1213 "Skipping post_reconnection handlers due to disconnect state"
1214 );
1215 }
1216 }
1217 Err(e) => {
1218 let duration = inner.backoff.next_duration();
1219 tracing::warn!("Reconnect attempt failed: {e}");
1220 if !duration.is_zero() {
1221 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1222 }
1223 tokio::time::sleep(duration).await;
1224 }
1225 }
1226 }
1227 }
1228 inner
1229 .connection_mode
1230 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1231
1232 log_task_stopped("controller");
1233 })
1234 }
1235}
1236
1237impl Drop for WebSocketClient {
1239 fn drop(&mut self) {
1240 if !self.controller_task.is_finished() {
1241 self.controller_task.abort();
1242 log_task_aborted("controller");
1243 }
1244 }
1245}
1246
1247#[cfg(test)]
1252#[cfg(not(feature = "turmoil"))]
1253#[cfg(target_os = "linux")] mod tests {
1255 use std::{num::NonZeroU32, sync::Arc};
1256
1257 use futures_util::{SinkExt, StreamExt};
1258 use tokio::{
1259 net::TcpListener,
1260 task::{self, JoinHandle},
1261 };
1262 use tokio_tungstenite::{
1263 accept_hdr_async,
1264 tungstenite::{
1265 handshake::server::{self, Callback},
1266 http::HeaderValue,
1267 },
1268 };
1269
1270 use crate::{
1271 ratelimiter::quota::Quota,
1272 websocket::{WebSocketClient, WebSocketConfig},
1273 };
1274
1275 struct TestServer {
1276 task: JoinHandle<()>,
1277 port: u16,
1278 }
1279
1280 #[derive(Debug, Clone)]
1281 struct TestCallback {
1282 key: String,
1283 value: HeaderValue,
1284 }
1285
1286 impl Callback for TestCallback {
1287 fn on_request(
1288 self,
1289 request: &server::Request,
1290 response: server::Response,
1291 ) -> Result<server::Response, server::ErrorResponse> {
1292 let _ = response;
1293 let value = request.headers().get(&self.key);
1294 assert!(value.is_some());
1295
1296 if let Some(value) = request.headers().get(&self.key) {
1297 assert_eq!(value, self.value);
1298 }
1299
1300 Ok(response)
1301 }
1302 }
1303
1304 impl TestServer {
1305 async fn setup() -> Self {
1306 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1307 let port = TcpListener::local_addr(&server).unwrap().port();
1308
1309 let header_key = "test".to_string();
1310 let header_value = "test".to_string();
1311
1312 let test_call_back = TestCallback {
1313 key: header_key,
1314 value: HeaderValue::from_str(&header_value).unwrap(),
1315 };
1316
1317 let task = task::spawn(async move {
1318 loop {
1320 let (conn, _) = server.accept().await.unwrap();
1321 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1322 .await
1323 .unwrap();
1324
1325 task::spawn(async move {
1326 while let Some(Ok(msg)) = websocket.next().await {
1327 match msg {
1328 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1329 if txt == "close-now" =>
1330 {
1331 tracing::debug!("Forcibly closing from server side");
1332 let _ = websocket.close(None).await;
1334 break;
1335 }
1336 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1338 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1339 if websocket.send(msg).await.is_err() {
1340 break;
1341 }
1342 }
1343 tokio_tungstenite::tungstenite::protocol::Message::Close(
1345 _frame,
1346 ) => {
1347 let _ = websocket.close(None).await;
1348 break;
1349 }
1350 _ => {}
1352 }
1353 }
1354 });
1355 }
1356 });
1357
1358 Self { task, port }
1359 }
1360 }
1361
1362 impl Drop for TestServer {
1363 fn drop(&mut self) {
1364 self.task.abort();
1365 }
1366 }
1367
1368 async fn setup_test_client(port: u16) -> WebSocketClient {
1369 let config = WebSocketConfig {
1370 url: format!("ws://127.0.0.1:{port}"),
1371 headers: vec![("test".into(), "test".into())],
1372 message_handler: None,
1373 heartbeat: None,
1374 heartbeat_msg: None,
1375 ping_handler: None,
1376 reconnect_timeout_ms: None,
1377 reconnect_delay_initial_ms: None,
1378 reconnect_backoff_factor: None,
1379 reconnect_delay_max_ms: None,
1380 reconnect_jitter_ms: None,
1381 };
1382 WebSocketClient::connect(config, None, vec![], None)
1383 .await
1384 .expect("Failed to connect")
1385 }
1386
1387 #[tokio::test]
1388 async fn test_websocket_basic() {
1389 let server = TestServer::setup().await;
1390 let client = setup_test_client(server.port).await;
1391
1392 assert!(!client.is_disconnected());
1393
1394 client.disconnect().await;
1395 assert!(client.is_disconnected());
1396 }
1397
1398 #[tokio::test]
1399 async fn test_websocket_heartbeat() {
1400 let server = TestServer::setup().await;
1401 let client = setup_test_client(server.port).await;
1402
1403 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1405
1406 client.disconnect().await;
1408 assert!(client.is_disconnected());
1409 }
1410
1411 #[tokio::test]
1412 async fn test_websocket_reconnect_exhausted() {
1413 let config = WebSocketConfig {
1414 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1416 message_handler: None,
1417 heartbeat: None,
1418 heartbeat_msg: None,
1419 ping_handler: None,
1420 reconnect_timeout_ms: None,
1421 reconnect_delay_initial_ms: None,
1422 reconnect_backoff_factor: None,
1423 reconnect_delay_max_ms: None,
1424 reconnect_jitter_ms: None,
1425 };
1426 let res = WebSocketClient::connect(config, None, vec![], None).await;
1427 assert!(res.is_err(), "Should fail quickly with no server");
1428 }
1429
1430 #[tokio::test]
1431 async fn test_websocket_forced_close_reconnect() {
1432 let server = TestServer::setup().await;
1433 let client = setup_test_client(server.port).await;
1434
1435 client.send_text("Hello".into(), None).await.unwrap();
1437
1438 client.send_text("close-now".into(), None).await.unwrap();
1440
1441 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1443
1444 assert!(!client.is_disconnected());
1446
1447 client.disconnect().await;
1449 assert!(client.is_disconnected());
1450 }
1451
1452 #[tokio::test]
1453 async fn test_rate_limiter() {
1454 let server = TestServer::setup().await;
1455 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1456
1457 let config = WebSocketConfig {
1458 url: format!("ws://127.0.0.1:{}", server.port),
1459 headers: vec![("test".into(), "test".into())],
1460 message_handler: None,
1461 heartbeat: None,
1462 heartbeat_msg: None,
1463 ping_handler: None,
1464 reconnect_timeout_ms: None,
1465 reconnect_delay_initial_ms: None,
1466 reconnect_backoff_factor: None,
1467 reconnect_delay_max_ms: None,
1468 reconnect_jitter_ms: None,
1469 };
1470
1471 let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1472 .await
1473 .unwrap();
1474
1475 client.send_text("test1".into(), None).await.unwrap();
1477 client.send_text("test2".into(), None).await.unwrap();
1478
1479 client.send_text("test3".into(), None).await.unwrap();
1481
1482 client.disconnect().await;
1484 assert!(client.is_disconnected());
1485 }
1486
1487 #[tokio::test]
1488 async fn test_concurrent_writers() {
1489 let server = TestServer::setup().await;
1490 let client = Arc::new(setup_test_client(server.port).await);
1491
1492 let mut handles = vec![];
1493 for i in 0..10 {
1494 let client = client.clone();
1495 handles.push(task::spawn(async move {
1496 client.send_text(format!("test{i}"), None).await.unwrap();
1497 }));
1498 }
1499
1500 for handle in handles {
1501 handle.await.unwrap();
1502 }
1503
1504 client.disconnect().await;
1506 assert!(client.is_disconnected());
1507 }
1508}
1509
1510#[cfg(test)]
1515#[cfg(not(feature = "turmoil"))]
1516mod rust_tests {
1517 use futures_util::StreamExt;
1518 use rstest::rstest;
1519 use tokio::{
1520 net::TcpListener,
1521 task,
1522 time::{Duration, sleep},
1523 };
1524 use tokio_tungstenite::accept_async;
1525
1526 use super::*;
1527
1528 #[rstest]
1529 #[tokio::test]
1530 async fn test_reconnect_then_disconnect() {
1531 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1533 let port = listener.local_addr().unwrap().port();
1534
1535 let server = task::spawn(async move {
1537 let (stream, _) = listener.accept().await.unwrap();
1538 let ws = accept_async(stream).await.unwrap();
1539 drop(ws);
1540 sleep(Duration::from_secs(1)).await;
1542 });
1543
1544 let (handler, _rx) = channel_message_handler();
1546
1547 let config = WebSocketConfig {
1549 url: format!("ws://127.0.0.1:{port}"),
1550 headers: vec![],
1551 message_handler: Some(handler),
1552 heartbeat: None,
1553 heartbeat_msg: None,
1554 ping_handler: None,
1555 reconnect_timeout_ms: Some(1_000),
1556 reconnect_delay_initial_ms: Some(50),
1557 reconnect_delay_max_ms: Some(100),
1558 reconnect_backoff_factor: Some(1.0),
1559 reconnect_jitter_ms: Some(0),
1560 };
1561
1562 let client = WebSocketClient::connect(config, None, vec![], None)
1564 .await
1565 .unwrap();
1566
1567 sleep(Duration::from_millis(100)).await;
1569 client.disconnect().await;
1571 assert!(client.is_disconnected());
1572 server.abort();
1573 }
1574
1575 #[rstest]
1576 #[tokio::test]
1577 async fn test_reconnect_state_flips_when_reader_stops() {
1578 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1580 let port = listener.local_addr().unwrap().port();
1581
1582 let server = task::spawn(async move {
1583 if let Ok((stream, _)) = listener.accept().await
1584 && let Ok(ws) = accept_async(stream).await
1585 {
1586 drop(ws);
1587 }
1588 sleep(Duration::from_millis(50)).await;
1589 });
1590
1591 let (handler, _rx) = channel_message_handler();
1592
1593 let config = WebSocketConfig {
1594 url: format!("ws://127.0.0.1:{port}"),
1595 headers: vec![],
1596 message_handler: Some(handler),
1597 heartbeat: None,
1598 heartbeat_msg: None,
1599 ping_handler: None,
1600 reconnect_timeout_ms: Some(1_000),
1601 reconnect_delay_initial_ms: Some(50),
1602 reconnect_delay_max_ms: Some(100),
1603 reconnect_backoff_factor: Some(1.0),
1604 reconnect_jitter_ms: Some(0),
1605 };
1606
1607 let client = WebSocketClient::connect(config, None, vec![], None)
1608 .await
1609 .unwrap();
1610
1611 tokio::time::timeout(Duration::from_secs(2), async {
1612 loop {
1613 if client.is_reconnecting() {
1614 break;
1615 }
1616 tokio::time::sleep(Duration::from_millis(10)).await;
1617 }
1618 })
1619 .await
1620 .expect("client did not enter RECONNECT state");
1621
1622 client.disconnect().await;
1623 server.abort();
1624 }
1625
1626 #[rstest]
1627 #[tokio::test]
1628 async fn test_stream_mode_disables_auto_reconnect() {
1629 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1632 let port = listener.local_addr().unwrap().port();
1633
1634 let server = task::spawn(async move {
1635 if let Ok((stream, _)) = listener.accept().await
1636 && let Ok(_ws) = accept_async(stream).await
1637 {
1638 sleep(Duration::from_millis(100)).await;
1640 }
1641 });
1642
1643 let config = WebSocketConfig {
1644 url: format!("ws://127.0.0.1:{port}"),
1645 headers: vec![],
1646 message_handler: None, heartbeat: None,
1648 heartbeat_msg: None,
1649 ping_handler: None,
1650 reconnect_timeout_ms: Some(1_000),
1651 reconnect_delay_initial_ms: Some(50),
1652 reconnect_delay_max_ms: Some(100),
1653 reconnect_backoff_factor: Some(1.0),
1654 reconnect_jitter_ms: Some(0),
1655 };
1656
1657 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1659 .await
1660 .unwrap();
1661
1662 server.abort();
1670 }
1671
1672 #[rstest]
1673 #[tokio::test]
1674 async fn test_message_handler_mode_allows_auto_reconnect() {
1675 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1677 let port = listener.local_addr().unwrap().port();
1678
1679 let server = task::spawn(async move {
1680 if let Ok((stream, _)) = listener.accept().await
1682 && let Ok(ws) = accept_async(stream).await
1683 {
1684 drop(ws);
1685 }
1686 sleep(Duration::from_millis(50)).await;
1687 });
1688
1689 let (handler, _rx) = channel_message_handler();
1690
1691 let config = WebSocketConfig {
1692 url: format!("ws://127.0.0.1:{port}"),
1693 headers: vec![],
1694 message_handler: Some(handler), heartbeat: None,
1696 heartbeat_msg: None,
1697 ping_handler: None,
1698 reconnect_timeout_ms: Some(1_000),
1699 reconnect_delay_initial_ms: Some(50),
1700 reconnect_delay_max_ms: Some(100),
1701 reconnect_backoff_factor: Some(1.0),
1702 reconnect_jitter_ms: Some(0),
1703 };
1704
1705 let client = WebSocketClient::connect(config, None, vec![], None)
1706 .await
1707 .unwrap();
1708
1709 tokio::time::timeout(Duration::from_secs(2), async {
1711 loop {
1712 if client.is_reconnecting() || client.is_closed() {
1713 break;
1714 }
1715 tokio::time::sleep(Duration::from_millis(10)).await;
1716 }
1717 })
1718 .await
1719 .expect("client should attempt reconnection or close");
1720
1721 assert!(
1724 client.is_reconnecting() || client.is_closed(),
1725 "Client with message handler should attempt reconnection"
1726 );
1727
1728 client.disconnect().await;
1729 server.abort();
1730 }
1731
1732 #[rstest]
1733 #[tokio::test]
1734 async fn test_handler_mode_reconnect_with_new_connection() {
1735 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1737 let port = listener.local_addr().unwrap().port();
1738
1739 let server = task::spawn(async move {
1740 if let Ok((stream, _)) = listener.accept().await
1742 && let Ok(ws) = accept_async(stream).await
1743 {
1744 drop(ws);
1745 }
1746
1747 sleep(Duration::from_millis(100)).await;
1749
1750 if let Ok((stream, _)) = listener.accept().await
1752 && let Ok(mut ws) = accept_async(stream).await
1753 {
1754 use futures_util::SinkExt;
1755 let _ = ws
1756 .send(Message::Text("reconnected".to_string().into()))
1757 .await;
1758 sleep(Duration::from_secs(1)).await;
1759 }
1760 });
1761
1762 let (handler, mut rx) = channel_message_handler();
1763
1764 let config = WebSocketConfig {
1765 url: format!("ws://127.0.0.1:{port}"),
1766 headers: vec![],
1767 message_handler: Some(handler),
1768 heartbeat: None,
1769 heartbeat_msg: None,
1770 ping_handler: None,
1771 reconnect_timeout_ms: Some(2_000),
1772 reconnect_delay_initial_ms: Some(50),
1773 reconnect_delay_max_ms: Some(200),
1774 reconnect_backoff_factor: Some(1.5),
1775 reconnect_jitter_ms: Some(10),
1776 };
1777
1778 let client = WebSocketClient::connect(config, None, vec![], None)
1779 .await
1780 .unwrap();
1781
1782 let result = tokio::time::timeout(Duration::from_secs(5), async {
1784 loop {
1785 if let Ok(msg) = rx.try_recv()
1786 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1787 {
1788 return true;
1789 }
1790 tokio::time::sleep(Duration::from_millis(10)).await;
1791 }
1792 })
1793 .await;
1794
1795 assert!(
1796 result.is_ok(),
1797 "Should receive message after reconnection within timeout"
1798 );
1799
1800 client.disconnect().await;
1801 server.abort();
1802 }
1803
1804 #[rstest]
1805 #[tokio::test]
1806 async fn test_stream_mode_no_auto_reconnect() {
1807 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1810 let port = listener.local_addr().unwrap().port();
1811
1812 let server = task::spawn(async move {
1813 if let Ok((stream, _)) = listener.accept().await
1815 && let Ok(mut ws) = accept_async(stream).await
1816 {
1817 use futures_util::SinkExt;
1818 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1819 sleep(Duration::from_millis(50)).await;
1820 }
1822 });
1823
1824 let config = WebSocketConfig {
1825 url: format!("ws://127.0.0.1:{port}"),
1826 headers: vec![],
1827 message_handler: None, heartbeat: None,
1829 heartbeat_msg: None,
1830 ping_handler: None,
1831 reconnect_timeout_ms: Some(1_000),
1832 reconnect_delay_initial_ms: Some(50),
1833 reconnect_delay_max_ms: Some(100),
1834 reconnect_backoff_factor: Some(1.0),
1835 reconnect_jitter_ms: Some(0),
1836 };
1837
1838 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1839 .await
1840 .unwrap();
1841
1842 assert!(client.is_active(), "Client should start as active");
1844
1845 let msg = reader.next().await;
1847 assert!(
1848 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1849 "Should receive initial message"
1850 );
1851
1852 while let Some(msg) = reader.next().await {
1854 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1855 break;
1856 }
1857 }
1858
1859 sleep(Duration::from_millis(200)).await;
1862
1863 assert!(
1866 client.is_active() || client.is_closed(),
1867 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1868 );
1869 assert!(
1870 !client.is_reconnecting(),
1871 "Stream mode client should never attempt reconnection"
1872 );
1873
1874 client.disconnect().await;
1875 server.abort();
1876 }
1877
1878 #[rstest]
1879 #[tokio::test]
1880 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1881 use nautilus_common::testing::wait_until_async;
1884
1885 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1886 let port = listener.local_addr().unwrap().port();
1887
1888 let server = task::spawn(async move {
1889 if let Ok((stream, _)) = listener.accept().await
1891 && let Ok(ws) = accept_async(stream).await
1892 {
1893 drop(ws);
1894 }
1895 sleep(Duration::from_secs(60)).await;
1897 });
1898
1899 let (handler, _rx) = channel_message_handler();
1900
1901 let config = WebSocketConfig {
1903 url: format!("ws://127.0.0.1:{port}"),
1904 headers: vec![],
1905 message_handler: Some(handler),
1906 heartbeat: None,
1907 heartbeat_msg: None,
1908 ping_handler: None,
1909 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1911 reconnect_delay_max_ms: Some(100),
1912 reconnect_backoff_factor: Some(1.0),
1913 reconnect_jitter_ms: Some(0),
1914 };
1915
1916 let client = WebSocketClient::connect(config, None, vec![], None)
1917 .await
1918 .unwrap();
1919
1920 wait_until_async(
1922 || async { client.is_reconnecting() },
1923 Duration::from_secs(3),
1924 )
1925 .await;
1926
1927 let start = std::time::Instant::now();
1929 let send_result = client.send_text("test".to_string(), None).await;
1930 let elapsed = start.elapsed();
1931
1932 assert!(
1933 send_result.is_err(),
1934 "Send should fail when client stuck in RECONNECT"
1935 );
1936 assert!(
1937 matches!(send_result, Err(crate::error::SendError::Timeout)),
1938 "Send should return Timeout error, got: {:?}",
1939 send_result
1940 );
1941 assert!(
1944 elapsed >= Duration::from_millis(1800),
1945 "Send should timeout after at least 2s (configured timeout), took {:?}",
1946 elapsed
1947 );
1948
1949 client.disconnect().await;
1950 server.abort();
1951 }
1952
1953 #[rstest]
1954 #[tokio::test]
1955 async fn test_send_waits_during_reconnection() {
1956 use nautilus_common::testing::wait_until_async;
1958
1959 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1960 let port = listener.local_addr().unwrap().port();
1961
1962 let server = task::spawn(async move {
1963 if let Ok((stream, _)) = listener.accept().await
1965 && let Ok(ws) = accept_async(stream).await
1966 {
1967 drop(ws);
1968 }
1969
1970 sleep(Duration::from_millis(500)).await;
1972
1973 if let Ok((stream, _)) = listener.accept().await
1975 && let Ok(mut ws) = accept_async(stream).await
1976 {
1977 while let Some(Ok(msg)) = ws.next().await {
1979 if ws.send(msg).await.is_err() {
1980 break;
1981 }
1982 }
1983 }
1984 });
1985
1986 let (handler, _rx) = channel_message_handler();
1987
1988 let config = WebSocketConfig {
1989 url: format!("ws://127.0.0.1:{port}"),
1990 headers: vec![],
1991 message_handler: Some(handler),
1992 heartbeat: None,
1993 heartbeat_msg: None,
1994 ping_handler: None,
1995 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
1997 reconnect_delay_max_ms: Some(200),
1998 reconnect_backoff_factor: Some(1.0),
1999 reconnect_jitter_ms: Some(0),
2000 };
2001
2002 let client = WebSocketClient::connect(config, None, vec![], None)
2003 .await
2004 .unwrap();
2005
2006 wait_until_async(
2008 || async { client.is_reconnecting() },
2009 Duration::from_secs(2),
2010 )
2011 .await;
2012
2013 let send_result = tokio::time::timeout(
2015 Duration::from_secs(3),
2016 client.send_text("test_message".to_string(), None),
2017 )
2018 .await;
2019
2020 assert!(
2021 send_result.is_ok() && send_result.unwrap().is_ok(),
2022 "Send should succeed after waiting for reconnection"
2023 );
2024
2025 client.disconnect().await;
2026 server.abort();
2027 }
2028
2029 #[rstest]
2030 #[tokio::test]
2031 async fn test_rate_limiter_before_active_wait() {
2032 use std::{num::NonZeroU32, sync::Arc};
2037
2038 use nautilus_common::testing::wait_until_async;
2039
2040 use crate::ratelimiter::quota::Quota;
2041
2042 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2043 let port = listener.local_addr().unwrap().port();
2044
2045 let server = task::spawn(async move {
2046 if let Ok((stream, _)) = listener.accept().await
2048 && let Ok(mut ws) = accept_async(stream).await
2049 {
2050 if let Some(Ok(_)) = ws.next().await {
2052 drop(ws);
2053 }
2054 }
2055
2056 sleep(Duration::from_millis(500)).await;
2058
2059 if let Ok((stream, _)) = listener.accept().await
2061 && let Ok(mut ws) = accept_async(stream).await
2062 {
2063 while let Some(Ok(msg)) = ws.next().await {
2064 if ws.send(msg).await.is_err() {
2065 break;
2066 }
2067 }
2068 }
2069 });
2070
2071 let (handler, _rx) = channel_message_handler();
2072
2073 let config = WebSocketConfig {
2074 url: format!("ws://127.0.0.1:{port}"),
2075 headers: vec![],
2076 message_handler: Some(handler),
2077 heartbeat: None,
2078 heartbeat_msg: None,
2079 ping_handler: None,
2080 reconnect_timeout_ms: Some(5_000),
2081 reconnect_delay_initial_ms: Some(50),
2082 reconnect_delay_max_ms: Some(100),
2083 reconnect_backoff_factor: Some(1.0),
2084 reconnect_jitter_ms: Some(0),
2085 };
2086
2087 let quota =
2089 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2090
2091 let client = Arc::new(
2092 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2093 .await
2094 .unwrap(),
2095 );
2096
2097 client
2099 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2100 .await
2101 .unwrap();
2102
2103 wait_until_async(
2105 || async { client.is_reconnecting() },
2106 Duration::from_secs(2),
2107 )
2108 .await;
2109
2110 let start = std::time::Instant::now();
2112 let send_result = client
2113 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2114 .await;
2115 let elapsed = start.elapsed();
2116
2117 assert!(
2119 send_result.is_ok(),
2120 "Send should succeed after rate limit + reconnection, got: {:?}",
2121 send_result
2122 );
2123 assert!(
2127 elapsed >= Duration::from_millis(850),
2128 "Should wait for rate limit (~1s), waited {:?}",
2129 elapsed
2130 );
2131
2132 client.disconnect().await;
2133 server.abort();
2134 }
2135
2136 #[rstest]
2137 #[tokio::test]
2138 async fn test_disconnect_during_reconnect_exits_cleanly() {
2139 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2142 let port = listener.local_addr().unwrap().port();
2143
2144 let server = task::spawn(async move {
2145 if let Ok((stream, _)) = listener.accept().await
2147 && let Ok(ws) = accept_async(stream).await
2148 {
2149 drop(ws);
2150 }
2151 sleep(Duration::from_secs(60)).await;
2153 });
2154
2155 let (handler, _rx) = channel_message_handler();
2156
2157 let config = WebSocketConfig {
2158 url: format!("ws://127.0.0.1:{port}"),
2159 headers: vec![],
2160 message_handler: Some(handler),
2161 heartbeat: None,
2162 heartbeat_msg: None,
2163 ping_handler: None,
2164 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2166 reconnect_delay_max_ms: Some(200),
2167 reconnect_backoff_factor: Some(1.0),
2168 reconnect_jitter_ms: Some(0),
2169 };
2170
2171 let client = WebSocketClient::connect(config, None, vec![], None)
2172 .await
2173 .unwrap();
2174
2175 tokio::time::timeout(Duration::from_secs(2), async {
2177 while !client.is_reconnecting() {
2178 sleep(Duration::from_millis(10)).await;
2179 }
2180 })
2181 .await
2182 .expect("Client should enter RECONNECT state");
2183
2184 client.disconnect().await;
2186
2187 assert!(
2189 client.is_disconnected(),
2190 "Client should be cleanly disconnected"
2191 );
2192
2193 server.abort();
2194 }
2195}