1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50 config::WebSocketConfig,
51 consts::{
52 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54 },
55 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60 RECONNECTED,
61 backoff::ExponentialBackoff,
62 error::SendError,
63 logging::{log_task_aborted, log_task_started, log_task_stopped},
64 mode::ConnectionMode,
65 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
/// Internal state of a WebSocket connection and the handles of its background tasks.
///
/// Owns the `read` task (handler mode only), the `write` task, the optional
/// `heartbeat` task and the reconnection settings derived from [`WebSocketConfig`].
pub struct WebSocketClientInner {
    /// Configuration the connection was created from (URL, headers, handlers, reconnect settings).
    config: WebSocketConfig,
    /// Task forwarding incoming frames to `config.message_handler`; `None` when no handler
    /// is configured.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task owning the socket writer; it consumes `WriterCommand`s from `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Sending side of the command channel consumed by `write_task`.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Periodic heartbeat task, present only when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode` encoded as a `u8`, shared with every background task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of one `reconnect` attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnection attempts.
    backoff: ExponentialBackoff,
    /// `true` when the caller consumes the reader directly; `reconnect` does not
    /// re-establish the connection in this mode.
    is_stream_mode: bool,
    /// Consecutive reconnection attempts allowed before giving up (`None` = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Consecutive reconnection attempts made since the last successful reconnect.
    reconnection_attempt_count: u32,
}
101
102impl WebSocketClientInner {
103 pub async fn new_with_writer(
109 config: WebSocketConfig,
110 writer: MessageWriter,
111 ) -> Result<Self, Error> {
112 install_cryptographic_provider();
113
114 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
115
116 let read_task = None;
118
119 let backoff = ExponentialBackoff::new(
120 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
121 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
122 config.reconnect_backoff_factor.unwrap_or(1.5),
123 config.reconnect_jitter_ms.unwrap_or(100),
124 true, )
126 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
127
128 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
129 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
130
131 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
132 Some(Self::spawn_heartbeat_task(
133 connection_mode.clone(),
134 heartbeat_interval,
135 config.heartbeat_msg.clone(),
136 writer_tx.clone(),
137 ))
138 } else {
139 None
140 };
141
142 Ok(Self {
143 config: config.clone(),
144 writer_tx,
145 connection_mode,
146 reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
147 heartbeat_task,
148 read_task,
149 write_task,
150 backoff,
151 is_stream_mode: true,
152 reconnect_max_attempts: config.reconnect_max_attempts,
153 reconnection_attempt_count: 0,
154 })
155 }
156
157 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
165 install_cryptographic_provider();
166
167 let WebSocketConfig {
168 url,
169 message_handler,
170 heartbeat,
171 headers,
172 heartbeat_msg,
173 ping_handler,
174 reconnect_timeout_ms,
175 reconnect_delay_initial_ms,
176 reconnect_delay_max_ms,
177 reconnect_backoff_factor,
178 reconnect_jitter_ms,
179 reconnect_max_attempts,
180 } = &config;
181
182 let is_stream_mode = message_handler.is_none();
184 let reconnect_max_attempts = *reconnect_max_attempts;
185
186 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
187
188 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
189
190 let read_task = if message_handler.is_some() {
191 Some(Self::spawn_message_handler_task(
192 connection_mode.clone(),
193 reader,
194 message_handler.as_ref(),
195 ping_handler.as_ref(),
196 ))
197 } else {
198 None
199 };
200
201 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
202 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
203
204 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
206 Self::spawn_heartbeat_task(
207 connection_mode.clone(),
208 *heartbeat_secs,
209 heartbeat_msg.clone(),
210 writer_tx.clone(),
211 )
212 });
213
214 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
215 let backoff = ExponentialBackoff::new(
216 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
217 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
218 reconnect_backoff_factor.unwrap_or(1.5),
219 reconnect_jitter_ms.unwrap_or(100),
220 true, )
222 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224 Ok(Self {
225 config,
226 read_task,
227 write_task,
228 writer_tx,
229 heartbeat_task,
230 connection_mode,
231 reconnect_timeout,
232 backoff,
233 is_stream_mode,
235 reconnect_max_attempts,
236 reconnection_attempt_count: 0,
237 })
238 }
239
240 #[inline]
250 #[cfg(not(feature = "turmoil"))]
251 pub async fn connect_with_server(
252 url: &str,
253 headers: Vec<(String, String)>,
254 ) -> Result<(MessageWriter, MessageReader), Error> {
255 let mut request = url.into_client_request()?;
256 let req_headers = request.headers_mut();
257
258 let mut header_names: Vec<HeaderName> = Vec::new();
259 for (key, val) in headers {
260 let header_value = HeaderValue::from_str(&val)?;
261 let header_name: HeaderName = key.parse()?;
262 header_names.push(header_name.clone());
263 req_headers.insert(header_name, header_value);
264 }
265
266 connect_async_with_config(request, None, true)
267 .await
268 .map(|resp| resp.0.split())
269 }
270
271 #[inline]
284 #[cfg(feature = "turmoil")]
285 pub async fn connect_with_server(
286 url: &str,
287 headers: Vec<(String, String)>,
288 ) -> Result<(MessageWriter, MessageReader), Error> {
289 use rustls::ClientConfig;
290 use tokio_rustls::TlsConnector;
291
292 let mut request = url.into_client_request()?;
293 let req_headers = request.headers_mut();
294
295 let mut header_names: Vec<HeaderName> = Vec::new();
296 for (key, val) in headers {
297 let header_value = HeaderValue::from_str(&val)?;
298 let header_name: HeaderName = key.parse()?;
299 header_names.push(header_name.clone());
300 req_headers.insert(header_name, header_value);
301 }
302
303 let uri = request.uri();
304 let scheme = uri.scheme_str().unwrap_or("ws");
305 let host = uri.host().ok_or_else(|| {
306 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
307 })?;
308
309 let port = uri
311 .port_u16()
312 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
313
314 let addr = format!("{host}:{port}");
315
316 let connector = crate::net::RealTcpConnector;
318 let tcp_stream = connector.connect(&addr).await?;
319 if let Err(e) = tcp_stream.set_nodelay(true) {
320 tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
321 }
322
323 let maybe_tls_stream = if scheme == "wss" {
325 let mut root_store = rustls::RootCertStore::empty();
327 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
328
329 let config = ClientConfig::builder()
330 .with_root_certificates(root_store)
331 .with_no_client_auth();
332
333 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
334 let domain =
335 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
336 Error::Io(std::io::Error::new(
337 std::io::ErrorKind::InvalidInput,
338 format!("Invalid DNS name: {e}"),
339 ))
340 })?;
341
342 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
343 MaybeTlsStream::Rustls(tls_stream)
344 } else {
345 MaybeTlsStream::Plain(tcp_stream)
346 };
347
348 client_async(request, maybe_tls_stream)
350 .await
351 .map(|resp| resp.0.split())
352 }
353
354 pub async fn reconnect(&mut self) -> Result<(), Error> {
369 tracing::debug!("Reconnecting");
370
371 if self.is_stream_mode {
372 tracing::warn!(
373 "Auto-reconnect disabled for stream-based WebSocket client; \
374 stream users must manually reconnect by creating a new connection"
375 );
376 self.connection_mode
378 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
379 return Ok(());
380 }
381
382 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
383 tracing::debug!("Reconnect aborted due to disconnect state");
384 return Ok(());
385 }
386
387 tokio::time::timeout(self.reconnect_timeout, async {
388 let (new_writer, reader) =
390 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
391
392 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
393 tracing::debug!("Reconnect aborted mid-flight (after connect)");
394 return Ok(());
395 }
396
397 let (tx, rx) = tokio::sync::oneshot::channel();
401 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
402 tracing::error!("{e}");
403 return Err(Error::Io(std::io::Error::new(
404 std::io::ErrorKind::BrokenPipe,
405 format!("Failed to send update command: {e}"),
406 )));
407 }
408
409 match rx.await {
411 Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
412 Ok(false) => {
413 tracing::warn!("Writer failed to drain buffer, aborting reconnect");
414 return Err(Error::Io(std::io::Error::other(
416 "Failed to drain reconnection buffer",
417 )));
418 }
419 Err(e) => {
420 tracing::error!("Writer dropped update channel: {e}");
421 return Err(Error::Io(std::io::Error::new(
422 std::io::ErrorKind::BrokenPipe,
423 "Writer task dropped response channel",
424 )));
425 }
426 }
427
428 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
430
431 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
432 tracing::debug!("Reconnect aborted mid-flight (after delay)");
433 return Ok(());
434 }
435
436 if let Some(ref read_task) = self.read_task.take()
437 && !read_task.is_finished()
438 {
439 read_task.abort();
440 log_task_aborted("read");
441 }
442
443 if self
446 .connection_mode
447 .compare_exchange(
448 ConnectionMode::Reconnect.as_u8(),
449 ConnectionMode::Active.as_u8(),
450 Ordering::SeqCst,
451 Ordering::SeqCst,
452 )
453 .is_err()
454 {
455 tracing::debug!("Reconnect aborted (state changed during reconnect)");
456 return Ok(());
457 }
458
459 self.read_task = if self.config.message_handler.is_some() {
460 Some(Self::spawn_message_handler_task(
461 self.connection_mode.clone(),
462 reader,
463 self.config.message_handler.as_ref(),
464 self.config.ping_handler.as_ref(),
465 ))
466 } else {
467 None
468 };
469
470 tracing::debug!("Reconnect succeeded");
471 Ok(())
472 })
473 .await
474 .map_err(|_| {
475 Error::Io(std::io::Error::new(
476 std::io::ErrorKind::TimedOut,
477 format!(
478 "reconnection timed out after {}s",
479 self.reconnect_timeout.as_secs_f64()
480 ),
481 ))
482 })?
483 }
484
485 #[inline]
493 #[must_use]
494 pub fn is_alive(&self) -> bool {
495 match &self.read_task {
496 Some(read_task) => !read_task.is_finished(),
497 None => true, }
499 }
500
501 fn spawn_message_handler_task(
502 connection_state: Arc<AtomicU8>,
503 mut reader: MessageReader,
504 message_handler: Option<&MessageHandler>,
505 ping_handler: Option<&PingHandler>,
506 ) -> tokio::task::JoinHandle<()> {
507 tracing::debug!("Started message handler task 'read'");
508
509 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
510
511 let message_handler = message_handler.cloned();
513 let ping_handler = ping_handler.cloned();
514
515 tokio::task::spawn(async move {
516 loop {
517 if !ConnectionMode::from_atomic(&connection_state).is_active() {
518 break;
519 }
520
521 match tokio::time::timeout(check_interval, reader.next()).await {
522 Ok(Some(Ok(Message::Binary(data)))) => {
523 tracing::trace!("Received message <binary> {} bytes", data.len());
524 if let Some(ref handler) = message_handler {
525 handler(Message::Binary(data));
526 }
527 }
528 Ok(Some(Ok(Message::Text(data)))) => {
529 tracing::trace!("Received message: {data}");
530 if let Some(ref handler) = message_handler {
531 handler(Message::Text(data));
532 }
533 }
534 Ok(Some(Ok(Message::Ping(ping_data)))) => {
535 tracing::trace!("Received ping: {ping_data:?}");
536 if let Some(ref handler) = ping_handler {
537 handler(ping_data.to_vec());
538 }
539 }
540 Ok(Some(Ok(Message::Pong(_)))) => {
541 tracing::trace!("Received pong");
542 }
543 Ok(Some(Ok(Message::Close(_)))) => {
544 tracing::debug!("Received close message - terminating");
545 break;
546 }
547 Ok(Some(Ok(_))) => (),
548 Ok(Some(Err(e))) => {
549 tracing::error!("Received error message - terminating: {e}");
550 break;
551 }
552 Ok(None) => {
553 tracing::debug!("No message received - terminating");
554 break;
555 }
556 Err(_) => {
557 continue;
559 }
560 }
561 }
562 })
563 }
564
565 async fn drain_reconnect_buffer(
570 buffer: &mut VecDeque<Message>,
571 writer: &mut MessageWriter,
572 ) -> bool {
573 if buffer.is_empty() {
574 return false;
575 }
576
577 let initial_buffer_len = buffer.len();
578 tracing::info!(
579 "Sending {} buffered messages after reconnection",
580 initial_buffer_len
581 );
582
583 let mut send_error_occurred = false;
584
585 while let Some(buffered_msg) = buffer.front() {
586 let msg_to_send = buffered_msg.clone();
588
589 if let Err(e) = writer.send(msg_to_send).await {
590 tracing::error!(
591 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
592 buffer.len()
593 );
594 send_error_occurred = true;
595 break; }
597
598 buffer.pop_front();
600 }
601
602 if buffer.is_empty() {
603 tracing::info!(
604 "Successfully sent all {} buffered messages",
605 initial_buffer_len
606 );
607 }
608
609 send_error_occurred
610 }
611
612 fn spawn_write_task(
613 connection_state: Arc<AtomicU8>,
614 writer: MessageWriter,
615 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
616 ) -> tokio::task::JoinHandle<()> {
617 log_task_started("write");
618
619 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
621
622 tokio::task::spawn(async move {
623 let mut active_writer = writer;
624 let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
627
628 loop {
629 match ConnectionMode::from_atomic(&connection_state) {
630 ConnectionMode::Disconnect => {
631 if !reconnect_buffer.is_empty() {
633 tracing::warn!(
634 "Discarding {} buffered messages due to disconnect",
635 reconnect_buffer.len()
636 );
637 reconnect_buffer.clear();
638 }
639
640 _ = tokio::time::timeout(
643 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
644 active_writer.close(),
645 )
646 .await;
647 break;
648 }
649 ConnectionMode::Closed => {
650 if !reconnect_buffer.is_empty() {
652 tracing::warn!(
653 "Discarding {} buffered messages due to closed connection",
654 reconnect_buffer.len()
655 );
656 reconnect_buffer.clear();
657 }
658 break;
659 }
660 _ => {}
661 }
662
663 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
664 Ok(Some(msg)) => {
665 let mode = ConnectionMode::from_atomic(&connection_state);
667 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
668 break;
669 }
670
671 match msg {
672 WriterCommand::Update(new_writer, tx) => {
673 tracing::debug!("Received new writer");
674
675 tokio::time::sleep(Duration::from_millis(100)).await;
677
678 _ = tokio::time::timeout(
681 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
682 active_writer.close(),
683 )
684 .await;
685
686 active_writer = new_writer;
687 tracing::debug!("Updated writer");
688
689 let send_error = Self::drain_reconnect_buffer(
690 &mut reconnect_buffer,
691 &mut active_writer,
692 )
693 .await;
694
695 if let Err(e) = tx.send(!send_error) {
696 tracing::error!(
697 "Failed to report drain status to controller: {e:?}"
698 );
699 }
700 }
701 WriterCommand::Send(msg) if mode.is_reconnect() => {
702 tracing::debug!(
704 "Buffering message during reconnection (buffer size: {})",
705 reconnect_buffer.len() + 1
706 );
707 reconnect_buffer.push_back(msg);
708 }
709 WriterCommand::Send(msg) => {
710 if let Err(e) = active_writer.send(msg.clone()).await {
711 tracing::error!("Failed to send message: {e}");
712 tracing::warn!("Writer triggering reconnect");
713 reconnect_buffer.push_back(msg);
714 connection_state
715 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
716 }
717 }
718 }
719 }
720 Ok(None) => {
721 tracing::debug!("Writer channel closed, terminating writer task");
723 break;
724 }
725 Err(_) => {
726 continue;
728 }
729 }
730 }
731
732 _ = tokio::time::timeout(
735 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
736 active_writer.close(),
737 )
738 .await;
739
740 log_task_stopped("write");
741 })
742 }
743
744 fn spawn_heartbeat_task(
745 connection_state: Arc<AtomicU8>,
746 heartbeat_secs: u64,
747 message: Option<String>,
748 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
749 ) -> tokio::task::JoinHandle<()> {
750 log_task_started("heartbeat");
751
752 tokio::task::spawn(async move {
753 let interval = Duration::from_secs(heartbeat_secs);
754
755 loop {
756 tokio::time::sleep(interval).await;
757
758 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
759 ConnectionMode::Active => {
760 let msg = match &message {
761 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
762 None => WriterCommand::Send(Message::Ping(vec![].into())),
763 };
764
765 match writer_tx.send(msg) {
766 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
767 Err(e) => {
768 tracing::error!("Failed to send heartbeat to writer task: {e}");
769 }
770 }
771 }
772 ConnectionMode::Reconnect => continue,
773 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
774 }
775 }
776
777 log_task_stopped("heartbeat");
778 })
779 }
780}
781
782impl Drop for WebSocketClientInner {
783 fn drop(&mut self) {
784 self.clean_drop();
786 }
787}
788
789impl CleanDrop for WebSocketClientInner {
791 fn clean_drop(&mut self) {
792 if let Some(ref read_task) = self.read_task.take()
793 && !read_task.is_finished()
794 {
795 read_task.abort();
796 log_task_aborted("read");
797 }
798
799 if !self.write_task.is_finished() {
800 self.write_task.abort();
801 log_task_aborted("write");
802 }
803
804 if let Some(ref handle) = self.heartbeat_task.take()
805 && !handle.is_finished()
806 {
807 handle.abort();
808 log_task_aborted("heartbeat");
809 }
810
811 self.config.message_handler = None;
813 self.config.ping_handler = None;
814 }
815}
816
817impl Debug for WebSocketClientInner {
818 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819 f.debug_struct("WebSocketClientInner")
820 .field("config", &self.config)
821 .field(
822 "connection_mode",
823 &ConnectionMode::from_atomic(&self.connection_mode),
824 )
825 .field("reconnect_timeout", &self.reconnect_timeout)
826 .field("is_stream_mode", &self.is_stream_mode)
827 .finish()
828 }
829}
830
/// Handle to a WebSocket connection driven by a background controller task.
///
/// Outgoing messages are queued to the write task through `writer_tx`, optionally after
/// waiting on per-key rate limits. The controller task owns the inner connection state.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task supervising the connection (reconnects and shutdown); owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Current `ConnectionMode` as a `u8`, shared with the controller and its tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time a send waits for the connection to become `Active`.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text`/`send_bytes` for the supplied keys.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Command channel into the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
846
847impl Debug for WebSocketClient {
848 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849 f.debug_struct(stringify!(WebSocketClient)).finish()
850 }
851}
852
853impl WebSocketClient {
854 #[allow(clippy::too_many_arguments)]
870 pub async fn connect_stream(
871 config: WebSocketConfig,
872 keyed_quotas: Vec<(String, Quota)>,
873 default_quota: Option<Quota>,
874 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
875 ) -> Result<(MessageReader, Self), Error> {
876 install_cryptographic_provider();
877
878 let (writer, reader) =
880 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
881
882 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
884
885 let connection_mode = inner.connection_mode.clone();
886 let reconnect_timeout = inner.reconnect_timeout;
887 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
888 let writer_tx = inner.writer_tx.clone();
889
890 let controller_task =
891 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
892
893 Ok((
894 reader,
895 Self {
896 controller_task,
897 connection_mode,
898 reconnect_timeout,
899 rate_limiter,
900 writer_tx,
901 },
902 ))
903 }
904
905 pub async fn connect(
922 config: WebSocketConfig,
923 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
924 keyed_quotas: Vec<(String, Quota)>,
925 default_quota: Option<Quota>,
926 ) -> Result<Self, Error> {
927 if config.message_handler.is_none() {
929 return Err(Error::Io(std::io::Error::new(
930 std::io::ErrorKind::InvalidInput,
931 "Handler mode requires config.message_handler to be set. Use connect_stream() for stream mode without a handler.",
932 )));
933 }
934
935 tracing::debug!("Connecting");
936 let inner = WebSocketClientInner::connect_url(config).await?;
937 let connection_mode = inner.connection_mode.clone();
938 let writer_tx = inner.writer_tx.clone();
939 let reconnect_timeout = inner.reconnect_timeout;
940
941 let controller_task =
942 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
943
944 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
945
946 Ok(Self {
947 controller_task,
948 connection_mode,
949 reconnect_timeout,
950 rate_limiter,
951 writer_tx,
952 })
953 }
954
955 #[must_use]
957 pub fn connection_mode(&self) -> ConnectionMode {
958 ConnectionMode::from_atomic(&self.connection_mode)
959 }
960
961 #[must_use]
966 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
967 Arc::clone(&self.connection_mode)
968 }
969
970 #[inline]
975 #[must_use]
976 pub fn is_active(&self) -> bool {
977 self.connection_mode().is_active()
978 }
979
980 #[must_use]
982 pub fn is_disconnected(&self) -> bool {
983 self.controller_task.is_finished()
984 }
985
986 #[inline]
991 #[must_use]
992 pub fn is_reconnecting(&self) -> bool {
993 self.connection_mode().is_reconnect()
994 }
995
996 #[inline]
1000 #[must_use]
1001 pub fn is_disconnecting(&self) -> bool {
1002 self.connection_mode().is_disconnect()
1003 }
1004
1005 #[inline]
1011 #[must_use]
1012 pub fn is_closed(&self) -> bool {
1013 self.connection_mode().is_closed()
1014 }
1015
1016 async fn wait_for_active(&self) -> Result<(), SendError> {
1020 if self.is_closed() {
1021 return Err(SendError::Closed);
1022 }
1023
1024 let timeout = self.reconnect_timeout;
1025 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1026
1027 if !self.is_active() {
1028 tracing::debug!("Waiting for client to become ACTIVE before sending...");
1029
1030 let inner = tokio::time::timeout(timeout, async {
1031 loop {
1032 if self.is_active() {
1033 return Ok(());
1034 }
1035 if matches!(
1036 self.connection_mode(),
1037 ConnectionMode::Disconnect | ConnectionMode::Closed
1038 ) {
1039 return Err(());
1040 }
1041 tokio::time::sleep(check_interval).await;
1042 }
1043 })
1044 .await
1045 .map_err(|_| SendError::Timeout)?;
1046 inner.map_err(|()| SendError::Closed)?;
1047 }
1048
1049 Ok(())
1050 }
1051
1052 pub async fn disconnect(&self) {
1057 tracing::debug!("Disconnecting");
1058 self.connection_mode
1059 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1060
1061 if let Ok(()) =
1062 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1063 while !self.is_disconnected() {
1064 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1065 .await;
1066 }
1067
1068 if !self.controller_task.is_finished() {
1069 self.controller_task.abort();
1070 log_task_aborted("controller");
1071 }
1072 })
1073 .await
1074 {
1075 tracing::debug!("Controller task finished");
1076 } else {
1077 tracing::error!("Timeout waiting for controller task to finish");
1078 if !self.controller_task.is_finished() {
1079 self.controller_task.abort();
1080 log_task_aborted("controller");
1081 }
1082 }
1083 }
1084
1085 #[allow(unused_variables)]
1091 pub async fn send_text(
1092 &self,
1093 data: String,
1094 keys: Option<Vec<String>>,
1095 ) -> Result<(), SendError> {
1096 if self.is_closed() || self.is_disconnecting() {
1098 return Err(SendError::Closed);
1099 }
1100
1101 self.rate_limiter.await_keys_ready(keys).await;
1102 self.wait_for_active().await?;
1103
1104 tracing::trace!("Sending text: {data:?}");
1105
1106 let msg = Message::Text(data.into());
1107 self.writer_tx
1108 .send(WriterCommand::Send(msg))
1109 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1110 }
1111
1112 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1118 self.wait_for_active().await?;
1119
1120 tracing::trace!("Sending pong frame ({} bytes)", data.len());
1121
1122 let msg = Message::Pong(data.into());
1123 self.writer_tx
1124 .send(WriterCommand::Send(msg))
1125 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1126 }
1127
1128 #[allow(unused_variables)]
1134 pub async fn send_bytes(
1135 &self,
1136 data: Vec<u8>,
1137 keys: Option<Vec<String>>,
1138 ) -> Result<(), SendError> {
1139 if self.is_closed() || self.is_disconnecting() {
1141 return Err(SendError::Closed);
1142 }
1143
1144 self.rate_limiter.await_keys_ready(keys).await;
1145 self.wait_for_active().await?;
1146
1147 tracing::trace!("Sending bytes: {data:?}");
1148
1149 let msg = Message::Binary(data.into());
1150 self.writer_tx
1151 .send(WriterCommand::Send(msg))
1152 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1153 }
1154
1155 pub async fn send_close_message(&self) -> Result<(), SendError> {
1161 self.wait_for_active().await?;
1162
1163 let msg = Message::Close(None);
1164 self.writer_tx
1165 .send(WriterCommand::Send(msg))
1166 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1167 }
1168
1169 fn spawn_controller_task(
1170 mut inner: WebSocketClientInner,
1171 connection_mode: Arc<AtomicU8>,
1172 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1173 ) -> tokio::task::JoinHandle<()> {
1174 tokio::task::spawn(async move {
1175 log_task_started("controller");
1176
1177 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1178
1179 loop {
1180 tokio::time::sleep(check_interval).await;
1181 let mut mode = ConnectionMode::from_atomic(&connection_mode);
1182
1183 if mode.is_disconnect() {
1184 tracing::debug!("Disconnecting");
1185
1186 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1187 if tokio::time::timeout(timeout, async {
1188 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1190
1191 if let Some(task) = &inner.read_task
1192 && !task.is_finished()
1193 {
1194 task.abort();
1195 log_task_aborted("read");
1196 }
1197
1198 if let Some(task) = &inner.heartbeat_task
1199 && !task.is_finished()
1200 {
1201 task.abort();
1202 log_task_aborted("heartbeat");
1203 }
1204 })
1205 .await
1206 .is_err()
1207 {
1208 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
1209 }
1210
1211 tracing::debug!("Closed");
1212 break; }
1214
1215 if mode.is_active() && !inner.is_alive() {
1216 if connection_mode
1217 .compare_exchange(
1218 ConnectionMode::Active.as_u8(),
1219 ConnectionMode::Reconnect.as_u8(),
1220 Ordering::SeqCst,
1221 Ordering::SeqCst,
1222 )
1223 .is_ok()
1224 {
1225 tracing::debug!("Detected dead read task, transitioning to RECONNECT");
1226 }
1227 mode = ConnectionMode::from_atomic(&connection_mode);
1228 }
1229
1230 if mode.is_reconnect() {
1231 if let Some(max_attempts) = inner.reconnect_max_attempts
1233 && inner.reconnection_attempt_count >= max_attempts
1234 {
1235 tracing::error!(
1236 "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
1237 max_attempts
1238 );
1239 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1240 break;
1241 }
1242
1243 inner.reconnection_attempt_count += 1;
1244 tracing::debug!(
1245 "Reconnection attempt {} of {}",
1246 inner.reconnection_attempt_count,
1247 inner
1248 .reconnect_max_attempts
1249 .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1250 );
1251
1252 match inner.reconnect().await {
1253 Ok(()) => {
1254 inner.backoff.reset();
1255 inner.reconnection_attempt_count = 0; if ConnectionMode::from_atomic(&connection_mode).is_active() {
1259 if let Some(ref handler) = inner.config.message_handler {
1260 let reconnected_msg =
1261 Message::Text(RECONNECTED.to_string().into());
1262 handler(reconnected_msg);
1263 tracing::debug!("Sent reconnected message to handler");
1264 }
1265
1266 if let Some(ref callback) = post_reconnection {
1268 callback();
1269 tracing::debug!("Called `post_reconnection` handler");
1270 }
1271
1272 tracing::debug!("Reconnected successfully");
1273 } else {
1274 tracing::debug!(
1275 "Skipping post_reconnection handlers due to disconnect state"
1276 );
1277 }
1278 }
1279 Err(e) => {
1280 let duration = inner.backoff.next_duration();
1281 tracing::warn!(
1282 "Reconnect attempt {} failed: {e}",
1283 inner.reconnection_attempt_count
1284 );
1285 if !duration.is_zero() {
1286 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
1287 }
1288 tokio::time::sleep(duration).await;
1289 }
1290 }
1291 }
1292 }
1293 inner
1294 .connection_mode
1295 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1296
1297 log_task_stopped("controller");
1298 })
1299 }
1300}
1301
1302impl Drop for WebSocketClient {
1304 fn drop(&mut self) {
1305 if !self.controller_task.is_finished() {
1306 self.controller_task.abort();
1307 log_task_aborted("controller");
1308 }
1309 }
1310}
1311
1312#[cfg(test)]
1317#[cfg(not(feature = "turmoil"))]
1318#[cfg(target_os = "linux")] mod tests {
1320 use std::{num::NonZeroU32, sync::Arc};
1321
1322 use futures_util::{SinkExt, StreamExt};
1323 use tokio::{
1324 net::TcpListener,
1325 task::{self, JoinHandle},
1326 };
1327 use tokio_tungstenite::{
1328 accept_hdr_async,
1329 tungstenite::{
1330 handshake::server::{self, Callback},
1331 http::HeaderValue,
1332 },
1333 };
1334
1335 use crate::{
1336 ratelimiter::quota::Quota,
1337 websocket::{WebSocketClient, WebSocketConfig},
1338 };
1339
    /// Local WebSocket echo server used by the tests; its accept loop is aborted on drop.
    struct TestServer {
        /// Task running the accept loop.
        task: JoinHandle<()>,
        /// Port the listener is bound to on 127.0.0.1.
        port: u16,
    }
1344
    /// Handshake callback asserting that the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present in the handshake request.
        key: String,
        /// Expected value of that header.
        value: HeaderValue,
    }
1350
1351 impl Callback for TestCallback {
1352 #[allow(clippy::panic_in_result_fn)]
1353 fn on_request(
1354 self,
1355 request: &server::Request,
1356 response: server::Response,
1357 ) -> Result<server::Response, server::ErrorResponse> {
1358 let _ = response;
1359 let value = request.headers().get(&self.key);
1360 assert!(value.is_some());
1361
1362 if let Some(value) = request.headers().get(&self.key) {
1363 assert_eq!(value, self.value);
1364 }
1365
1366 Ok(response)
1367 }
1368 }
1369
1370 impl TestServer {
1371 async fn setup() -> Self {
1372 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1373 let port = TcpListener::local_addr(&server).unwrap().port();
1374
1375 let header_key = "test".to_string();
1376 let header_value = "test".to_string();
1377
1378 let test_call_back = TestCallback {
1379 key: header_key,
1380 value: HeaderValue::from_str(&header_value).unwrap(),
1381 };
1382
1383 let task = task::spawn(async move {
1384 loop {
1386 let (conn, _) = server.accept().await.unwrap();
1387 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1388 .await
1389 .unwrap();
1390
1391 task::spawn(async move {
1392 while let Some(Ok(msg)) = websocket.next().await {
1393 match msg {
1394 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1395 if txt == "close-now" =>
1396 {
1397 tracing::debug!("Forcibly closing from server side");
1398 let _ = websocket.close(None).await;
1400 break;
1401 }
1402 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1404 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1405 if websocket.send(msg).await.is_err() {
1406 break;
1407 }
1408 }
1409 tokio_tungstenite::tungstenite::protocol::Message::Close(
1411 _frame,
1412 ) => {
1413 let _ = websocket.close(None).await;
1414 break;
1415 }
1416 _ => {}
1418 }
1419 }
1420 });
1421 }
1422 });
1423
1424 Self { task, port }
1425 }
1426 }
1427
1428 impl Drop for TestServer {
1429 fn drop(&mut self) {
1430 self.task.abort();
1431 }
1432 }
1433
1434 async fn setup_test_client(port: u16) -> WebSocketClient {
1435 let config = WebSocketConfig {
1436 url: format!("ws://127.0.0.1:{port}"),
1437 headers: vec![("test".into(), "test".into())],
1438 message_handler: Some(Arc::new(|_| {})),
1439 heartbeat: None,
1440 heartbeat_msg: None,
1441 ping_handler: None,
1442 reconnect_timeout_ms: None,
1443 reconnect_delay_initial_ms: None,
1444 reconnect_backoff_factor: None,
1445 reconnect_delay_max_ms: None,
1446 reconnect_jitter_ms: None,
1447 reconnect_max_attempts: None,
1448 };
1449 WebSocketClient::connect(config, None, vec![], None)
1450 .await
1451 .expect("Failed to connect")
1452 }
1453
1454 #[tokio::test]
1455 async fn test_websocket_basic() {
1456 let server = TestServer::setup().await;
1457 let client = setup_test_client(server.port).await;
1458
1459 assert!(!client.is_disconnected());
1460
1461 client.disconnect().await;
1462 assert!(client.is_disconnected());
1463 }
1464
1465 #[tokio::test]
1466 async fn test_websocket_heartbeat() {
1467 let server = TestServer::setup().await;
1468 let client = setup_test_client(server.port).await;
1469
1470 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1472
1473 client.disconnect().await;
1475 assert!(client.is_disconnected());
1476 }
1477
1478 #[tokio::test]
1479 async fn test_websocket_reconnect_exhausted() {
1480 let config = WebSocketConfig {
1481 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1483 message_handler: Some(Arc::new(|_| {})),
1484 heartbeat: None,
1485 heartbeat_msg: None,
1486 ping_handler: None,
1487 reconnect_timeout_ms: None,
1488 reconnect_delay_initial_ms: None,
1489 reconnect_backoff_factor: None,
1490 reconnect_delay_max_ms: None,
1491 reconnect_jitter_ms: None,
1492 reconnect_max_attempts: None,
1493 };
1494 let res = WebSocketClient::connect(config, None, vec![], None).await;
1495 assert!(res.is_err(), "Should fail quickly with no server");
1496 }
1497
1498 #[tokio::test]
1499 async fn test_websocket_forced_close_reconnect() {
1500 let server = TestServer::setup().await;
1501 let client = setup_test_client(server.port).await;
1502
1503 client.send_text("Hello".into(), None).await.unwrap();
1505
1506 client.send_text("close-now".into(), None).await.unwrap();
1508
1509 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1511
1512 assert!(!client.is_disconnected());
1514
1515 client.disconnect().await;
1517 assert!(client.is_disconnected());
1518 }
1519
1520 #[tokio::test]
1521 async fn test_rate_limiter() {
1522 let server = TestServer::setup().await;
1523 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1524
1525 let config = WebSocketConfig {
1526 url: format!("ws://127.0.0.1:{}", server.port),
1527 headers: vec![("test".into(), "test".into())],
1528 message_handler: Some(Arc::new(|_| {})),
1529 heartbeat: None,
1530 heartbeat_msg: None,
1531 ping_handler: None,
1532 reconnect_timeout_ms: None,
1533 reconnect_delay_initial_ms: None,
1534 reconnect_backoff_factor: None,
1535 reconnect_delay_max_ms: None,
1536 reconnect_jitter_ms: None,
1537 reconnect_max_attempts: None,
1538 };
1539
1540 let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1541 .await
1542 .unwrap();
1543
1544 client.send_text("test1".into(), None).await.unwrap();
1546 client.send_text("test2".into(), None).await.unwrap();
1547
1548 client.send_text("test3".into(), None).await.unwrap();
1550
1551 client.disconnect().await;
1553 assert!(client.is_disconnected());
1554 }
1555
1556 #[tokio::test]
1557 async fn test_concurrent_writers() {
1558 let server = TestServer::setup().await;
1559 let client = Arc::new(setup_test_client(server.port).await);
1560
1561 let mut handles = vec![];
1562 for i in 0..10 {
1563 let client = client.clone();
1564 handles.push(task::spawn(async move {
1565 client.send_text(format!("test{i}"), None).await.unwrap();
1566 }));
1567 }
1568
1569 for handle in handles {
1570 handle.await.unwrap();
1571 }
1572
1573 client.disconnect().await;
1575 assert!(client.is_disconnected());
1576 }
1577}
1578
1579#[cfg(test)]
1584#[cfg(not(feature = "turmoil"))]
1585mod rust_tests {
1586 use futures_util::StreamExt;
1587 use rstest::rstest;
1588 use tokio::{
1589 net::TcpListener,
1590 task,
1591 time::{Duration, sleep},
1592 };
1593 use tokio_tungstenite::accept_async;
1594
1595 use super::*;
1596 use crate::websocket::types::channel_message_handler;
1597
1598 #[rstest]
1599 #[tokio::test]
1600 async fn test_reconnect_then_disconnect() {
1601 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1603 let port = listener.local_addr().unwrap().port();
1604
1605 let server = task::spawn(async move {
1607 let (stream, _) = listener.accept().await.unwrap();
1608 let ws = accept_async(stream).await.unwrap();
1609 drop(ws);
1610 sleep(Duration::from_secs(1)).await;
1612 });
1613
1614 let (handler, _rx) = channel_message_handler();
1616
1617 let config = WebSocketConfig {
1619 url: format!("ws://127.0.0.1:{port}"),
1620 headers: vec![],
1621 message_handler: Some(handler),
1622 heartbeat: None,
1623 heartbeat_msg: None,
1624 ping_handler: None,
1625 reconnect_timeout_ms: Some(1_000),
1626 reconnect_delay_initial_ms: Some(50),
1627 reconnect_delay_max_ms: Some(100),
1628 reconnect_backoff_factor: Some(1.0),
1629 reconnect_jitter_ms: Some(0),
1630 reconnect_max_attempts: None,
1631 };
1632
1633 let client = WebSocketClient::connect(config, None, vec![], None)
1635 .await
1636 .unwrap();
1637
1638 sleep(Duration::from_millis(100)).await;
1640 client.disconnect().await;
1642 assert!(client.is_disconnected());
1643 server.abort();
1644 }
1645
1646 #[rstest]
1647 #[tokio::test]
1648 async fn test_reconnect_state_flips_when_reader_stops() {
1649 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1651 let port = listener.local_addr().unwrap().port();
1652
1653 let server = task::spawn(async move {
1654 if let Ok((stream, _)) = listener.accept().await
1655 && let Ok(ws) = accept_async(stream).await
1656 {
1657 drop(ws);
1658 }
1659 sleep(Duration::from_millis(50)).await;
1660 });
1661
1662 let (handler, _rx) = channel_message_handler();
1663
1664 let config = WebSocketConfig {
1665 url: format!("ws://127.0.0.1:{port}"),
1666 headers: vec![],
1667 message_handler: Some(handler),
1668 heartbeat: None,
1669 heartbeat_msg: None,
1670 ping_handler: None,
1671 reconnect_timeout_ms: Some(1_000),
1672 reconnect_delay_initial_ms: Some(50),
1673 reconnect_delay_max_ms: Some(100),
1674 reconnect_backoff_factor: Some(1.0),
1675 reconnect_jitter_ms: Some(0),
1676 reconnect_max_attempts: None,
1677 };
1678
1679 let client = WebSocketClient::connect(config, None, vec![], None)
1680 .await
1681 .unwrap();
1682
1683 tokio::time::timeout(Duration::from_secs(2), async {
1684 loop {
1685 if client.is_reconnecting() {
1686 break;
1687 }
1688 tokio::time::sleep(Duration::from_millis(10)).await;
1689 }
1690 })
1691 .await
1692 .expect("client did not enter RECONNECT state");
1693
1694 client.disconnect().await;
1695 server.abort();
1696 }
1697
1698 #[rstest]
1699 #[tokio::test]
1700 async fn test_stream_mode_disables_auto_reconnect() {
1701 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1704 let port = listener.local_addr().unwrap().port();
1705
1706 let server = task::spawn(async move {
1707 if let Ok((stream, _)) = listener.accept().await
1708 && let Ok(_ws) = accept_async(stream).await
1709 {
1710 sleep(Duration::from_millis(100)).await;
1712 }
1713 });
1714
1715 let config = WebSocketConfig {
1716 url: format!("ws://127.0.0.1:{port}"),
1717 headers: vec![],
1718 message_handler: None, heartbeat: None,
1720 heartbeat_msg: None,
1721 ping_handler: None,
1722 reconnect_timeout_ms: Some(1_000),
1723 reconnect_delay_initial_ms: Some(50),
1724 reconnect_delay_max_ms: Some(100),
1725 reconnect_backoff_factor: Some(1.0),
1726 reconnect_jitter_ms: Some(0),
1727 reconnect_max_attempts: None,
1728 };
1729
1730 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1732 .await
1733 .unwrap();
1734
1735 server.abort();
1743 }
1744
1745 #[rstest]
1746 #[tokio::test]
1747 async fn test_message_handler_mode_allows_auto_reconnect() {
1748 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1750 let port = listener.local_addr().unwrap().port();
1751
1752 let server = task::spawn(async move {
1753 if let Ok((stream, _)) = listener.accept().await
1755 && let Ok(ws) = accept_async(stream).await
1756 {
1757 drop(ws);
1758 }
1759 sleep(Duration::from_millis(50)).await;
1760 });
1761
1762 let (handler, _rx) = channel_message_handler();
1763
1764 let config = WebSocketConfig {
1765 url: format!("ws://127.0.0.1:{port}"),
1766 headers: vec![],
1767 message_handler: Some(handler), heartbeat: None,
1769 heartbeat_msg: None,
1770 ping_handler: None,
1771 reconnect_timeout_ms: Some(1_000),
1772 reconnect_delay_initial_ms: Some(50),
1773 reconnect_delay_max_ms: Some(100),
1774 reconnect_backoff_factor: Some(1.0),
1775 reconnect_jitter_ms: Some(0),
1776 reconnect_max_attempts: None,
1777 };
1778
1779 let client = WebSocketClient::connect(config, None, vec![], None)
1780 .await
1781 .unwrap();
1782
1783 tokio::time::timeout(Duration::from_secs(2), async {
1785 loop {
1786 if client.is_reconnecting() || client.is_closed() {
1787 break;
1788 }
1789 tokio::time::sleep(Duration::from_millis(10)).await;
1790 }
1791 })
1792 .await
1793 .expect("client should attempt reconnection or close");
1794
1795 assert!(
1798 client.is_reconnecting() || client.is_closed(),
1799 "Client with message handler should attempt reconnection"
1800 );
1801
1802 client.disconnect().await;
1803 server.abort();
1804 }
1805
1806 #[rstest]
1807 #[tokio::test]
1808 async fn test_handler_mode_reconnect_with_new_connection() {
1809 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1811 let port = listener.local_addr().unwrap().port();
1812
1813 let server = task::spawn(async move {
1814 if let Ok((stream, _)) = listener.accept().await
1816 && let Ok(ws) = accept_async(stream).await
1817 {
1818 drop(ws);
1819 }
1820
1821 sleep(Duration::from_millis(100)).await;
1823
1824 if let Ok((stream, _)) = listener.accept().await
1826 && let Ok(mut ws) = accept_async(stream).await
1827 {
1828 use futures_util::SinkExt;
1829 let _ = ws
1830 .send(Message::Text("reconnected".to_string().into()))
1831 .await;
1832 sleep(Duration::from_secs(1)).await;
1833 }
1834 });
1835
1836 let (handler, mut rx) = channel_message_handler();
1837
1838 let config = WebSocketConfig {
1839 url: format!("ws://127.0.0.1:{port}"),
1840 headers: vec![],
1841 message_handler: Some(handler),
1842 heartbeat: None,
1843 heartbeat_msg: None,
1844 ping_handler: None,
1845 reconnect_timeout_ms: Some(2_000),
1846 reconnect_delay_initial_ms: Some(50),
1847 reconnect_delay_max_ms: Some(200),
1848 reconnect_backoff_factor: Some(1.5),
1849 reconnect_jitter_ms: Some(10),
1850 reconnect_max_attempts: None,
1851 };
1852
1853 let client = WebSocketClient::connect(config, None, vec![], None)
1854 .await
1855 .unwrap();
1856
1857 let result = tokio::time::timeout(Duration::from_secs(5), async {
1859 loop {
1860 if let Ok(msg) = rx.try_recv()
1861 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1862 {
1863 return true;
1864 }
1865 tokio::time::sleep(Duration::from_millis(10)).await;
1866 }
1867 })
1868 .await;
1869
1870 assert!(
1871 result.is_ok(),
1872 "Should receive message after reconnection within timeout"
1873 );
1874
1875 client.disconnect().await;
1876 server.abort();
1877 }
1878
1879 #[rstest]
1880 #[tokio::test]
1881 async fn test_stream_mode_no_auto_reconnect() {
1882 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1885 let port = listener.local_addr().unwrap().port();
1886
1887 let server = task::spawn(async move {
1888 if let Ok((stream, _)) = listener.accept().await
1890 && let Ok(mut ws) = accept_async(stream).await
1891 {
1892 use futures_util::SinkExt;
1893 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1894 sleep(Duration::from_millis(50)).await;
1895 }
1897 });
1898
1899 let config = WebSocketConfig {
1900 url: format!("ws://127.0.0.1:{port}"),
1901 headers: vec![],
1902 message_handler: None, heartbeat: None,
1904 heartbeat_msg: None,
1905 ping_handler: None,
1906 reconnect_timeout_ms: Some(1_000),
1907 reconnect_delay_initial_ms: Some(50),
1908 reconnect_delay_max_ms: Some(100),
1909 reconnect_backoff_factor: Some(1.0),
1910 reconnect_jitter_ms: Some(0),
1911 reconnect_max_attempts: None,
1912 };
1913
1914 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1915 .await
1916 .unwrap();
1917
1918 assert!(client.is_active(), "Client should start as active");
1920
1921 let msg = reader.next().await;
1923 assert!(
1924 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1925 "Should receive initial message"
1926 );
1927
1928 while let Some(msg) = reader.next().await {
1930 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1931 break;
1932 }
1933 }
1934
1935 sleep(Duration::from_millis(200)).await;
1938
1939 assert!(
1942 client.is_active() || client.is_closed(),
1943 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1944 );
1945 assert!(
1946 !client.is_reconnecting(),
1947 "Stream mode client should never attempt reconnection"
1948 );
1949
1950 client.disconnect().await;
1951 server.abort();
1952 }
1953
1954 #[rstest]
1955 #[tokio::test]
1956 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1957 use nautilus_common::testing::wait_until_async;
1960
1961 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1962 let port = listener.local_addr().unwrap().port();
1963
1964 let server = task::spawn(async move {
1965 if let Ok((stream, _)) = listener.accept().await
1967 && let Ok(ws) = accept_async(stream).await
1968 {
1969 drop(ws);
1970 }
1971 sleep(Duration::from_secs(60)).await;
1973 });
1974
1975 let (handler, _rx) = channel_message_handler();
1976
1977 let config = WebSocketConfig {
1979 url: format!("ws://127.0.0.1:{port}"),
1980 headers: vec![],
1981 message_handler: Some(handler),
1982 heartbeat: None,
1983 heartbeat_msg: None,
1984 ping_handler: None,
1985 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1987 reconnect_delay_max_ms: Some(100),
1988 reconnect_backoff_factor: Some(1.0),
1989 reconnect_jitter_ms: Some(0),
1990 reconnect_max_attempts: None,
1991 };
1992
1993 let client = WebSocketClient::connect(config, None, vec![], None)
1994 .await
1995 .unwrap();
1996
1997 wait_until_async(
1999 || async { client.is_reconnecting() },
2000 Duration::from_secs(3),
2001 )
2002 .await;
2003
2004 let start = std::time::Instant::now();
2006 let send_result = client.send_text("test".to_string(), None).await;
2007 let elapsed = start.elapsed();
2008
2009 assert!(
2010 send_result.is_err(),
2011 "Send should fail when client stuck in RECONNECT"
2012 );
2013 assert!(
2014 matches!(send_result, Err(crate::error::SendError::Timeout)),
2015 "Send should return Timeout error, was: {:?}",
2016 send_result
2017 );
2018 assert!(
2021 elapsed >= Duration::from_millis(1800),
2022 "Send should timeout after at least 2s (configured timeout), took {:?}",
2023 elapsed
2024 );
2025
2026 client.disconnect().await;
2027 server.abort();
2028 }
2029
2030 #[rstest]
2031 #[tokio::test]
2032 async fn test_send_waits_during_reconnection() {
2033 use nautilus_common::testing::wait_until_async;
2035
2036 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2037 let port = listener.local_addr().unwrap().port();
2038
2039 let server = task::spawn(async move {
2040 if let Ok((stream, _)) = listener.accept().await
2042 && let Ok(ws) = accept_async(stream).await
2043 {
2044 drop(ws);
2045 }
2046
2047 sleep(Duration::from_millis(500)).await;
2049
2050 if let Ok((stream, _)) = listener.accept().await
2052 && let Ok(mut ws) = accept_async(stream).await
2053 {
2054 while let Some(Ok(msg)) = ws.next().await {
2056 if ws.send(msg).await.is_err() {
2057 break;
2058 }
2059 }
2060 }
2061 });
2062
2063 let (handler, _rx) = channel_message_handler();
2064
2065 let config = WebSocketConfig {
2066 url: format!("ws://127.0.0.1:{port}"),
2067 headers: vec![],
2068 message_handler: Some(handler),
2069 heartbeat: None,
2070 heartbeat_msg: None,
2071 ping_handler: None,
2072 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2074 reconnect_delay_max_ms: Some(200),
2075 reconnect_backoff_factor: Some(1.0),
2076 reconnect_jitter_ms: Some(0),
2077 reconnect_max_attempts: None,
2078 };
2079
2080 let client = WebSocketClient::connect(config, None, vec![], None)
2081 .await
2082 .unwrap();
2083
2084 wait_until_async(
2086 || async { client.is_reconnecting() },
2087 Duration::from_secs(2),
2088 )
2089 .await;
2090
2091 let send_result = tokio::time::timeout(
2093 Duration::from_secs(3),
2094 client.send_text("test_message".to_string(), None),
2095 )
2096 .await;
2097
2098 assert!(
2099 send_result.is_ok() && send_result.unwrap().is_ok(),
2100 "Send should succeed after waiting for reconnection"
2101 );
2102
2103 client.disconnect().await;
2104 server.abort();
2105 }
2106
2107 #[rstest]
2108 #[tokio::test]
2109 async fn test_rate_limiter_before_active_wait() {
2110 use std::{num::NonZeroU32, sync::Arc};
2115
2116 use nautilus_common::testing::wait_until_async;
2117
2118 use crate::ratelimiter::quota::Quota;
2119
2120 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2121 let port = listener.local_addr().unwrap().port();
2122
2123 let server = task::spawn(async move {
2124 if let Ok((stream, _)) = listener.accept().await
2126 && let Ok(mut ws) = accept_async(stream).await
2127 {
2128 if let Some(Ok(_)) = ws.next().await {
2130 drop(ws);
2131 }
2132 }
2133
2134 sleep(Duration::from_millis(500)).await;
2136
2137 if let Ok((stream, _)) = listener.accept().await
2139 && let Ok(mut ws) = accept_async(stream).await
2140 {
2141 while let Some(Ok(msg)) = ws.next().await {
2142 if ws.send(msg).await.is_err() {
2143 break;
2144 }
2145 }
2146 }
2147 });
2148
2149 let (handler, _rx) = channel_message_handler();
2150
2151 let config = WebSocketConfig {
2152 url: format!("ws://127.0.0.1:{port}"),
2153 headers: vec![],
2154 message_handler: Some(handler),
2155 heartbeat: None,
2156 heartbeat_msg: None,
2157 ping_handler: None,
2158 reconnect_timeout_ms: Some(5_000),
2159 reconnect_delay_initial_ms: Some(50),
2160 reconnect_delay_max_ms: Some(100),
2161 reconnect_backoff_factor: Some(1.0),
2162 reconnect_jitter_ms: Some(0),
2163 reconnect_max_attempts: None,
2164 };
2165
2166 let quota =
2168 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2169
2170 let client = Arc::new(
2171 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2172 .await
2173 .unwrap(),
2174 );
2175
2176 client
2178 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2179 .await
2180 .unwrap();
2181
2182 wait_until_async(
2184 || async { client.is_reconnecting() },
2185 Duration::from_secs(2),
2186 )
2187 .await;
2188
2189 let start = std::time::Instant::now();
2191 let send_result = client
2192 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2193 .await;
2194 let elapsed = start.elapsed();
2195
2196 assert!(
2198 send_result.is_ok(),
2199 "Send should succeed after rate limit + reconnection, was: {:?}",
2200 send_result
2201 );
2202 assert!(
2206 elapsed >= Duration::from_millis(850),
2207 "Should wait for rate limit (~1s), waited {:?}",
2208 elapsed
2209 );
2210
2211 client.disconnect().await;
2212 server.abort();
2213 }
2214
2215 #[rstest]
2216 #[tokio::test]
2217 async fn test_disconnect_during_reconnect_exits_cleanly() {
2218 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2221 let port = listener.local_addr().unwrap().port();
2222
2223 let server = task::spawn(async move {
2224 if let Ok((stream, _)) = listener.accept().await
2226 && let Ok(ws) = accept_async(stream).await
2227 {
2228 drop(ws);
2229 }
2230 sleep(Duration::from_secs(60)).await;
2232 });
2233
2234 let (handler, _rx) = channel_message_handler();
2235
2236 let config = WebSocketConfig {
2237 url: format!("ws://127.0.0.1:{port}"),
2238 headers: vec![],
2239 message_handler: Some(handler),
2240 heartbeat: None,
2241 heartbeat_msg: None,
2242 ping_handler: None,
2243 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2245 reconnect_delay_max_ms: Some(200),
2246 reconnect_backoff_factor: Some(1.0),
2247 reconnect_jitter_ms: Some(0),
2248 reconnect_max_attempts: None,
2249 };
2250
2251 let client = WebSocketClient::connect(config, None, vec![], None)
2252 .await
2253 .unwrap();
2254
2255 tokio::time::timeout(Duration::from_secs(2), async {
2257 while !client.is_reconnecting() {
2258 sleep(Duration::from_millis(10)).await;
2259 }
2260 })
2261 .await
2262 .expect("Client should enter RECONNECT state");
2263
2264 client.disconnect().await;
2266
2267 assert!(
2269 client.is_disconnected(),
2270 "Client should be cleanly disconnected"
2271 );
2272
2273 server.abort();
2274 }
2275
2276 #[rstest]
2277 #[tokio::test]
2278 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2279 use std::{num::NonZeroU32, sync::Arc};
2282
2283 use nautilus_common::testing::wait_until_async;
2284
2285 use crate::ratelimiter::quota::Quota;
2286
2287 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2288 let port = listener.local_addr().unwrap().port();
2289
2290 let server = task::spawn(async move {
2291 if let Ok((stream, _)) = listener.accept().await
2293 && let Ok(ws) = accept_async(stream).await
2294 {
2295 drop(ws);
2296 }
2297 sleep(Duration::from_secs(60)).await;
2298 });
2299
2300 let (handler, _rx) = channel_message_handler();
2301
2302 let config = WebSocketConfig {
2303 url: format!("ws://127.0.0.1:{port}"),
2304 headers: vec![],
2305 message_handler: Some(handler),
2306 heartbeat: None,
2307 heartbeat_msg: None,
2308 ping_handler: None,
2309 reconnect_timeout_ms: Some(5_000),
2310 reconnect_delay_initial_ms: Some(50),
2311 reconnect_delay_max_ms: Some(100),
2312 reconnect_backoff_factor: Some(1.0),
2313 reconnect_jitter_ms: Some(0),
2314 reconnect_max_attempts: None,
2315 };
2316
2317 let quota = Quota::with_period(Duration::from_secs(10))
2320 .unwrap()
2321 .allow_burst(NonZeroU32::new(1).unwrap());
2322
2323 let client = Arc::new(
2324 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2325 .await
2326 .unwrap(),
2327 );
2328
2329 wait_until_async(
2331 || async { client.is_reconnecting() || client.is_closed() },
2332 Duration::from_secs(2),
2333 )
2334 .await;
2335
2336 client.disconnect().await;
2338 assert!(
2339 !client.is_active(),
2340 "Client should not be active after disconnect"
2341 );
2342
2343 let start = std::time::Instant::now();
2345 let result = client
2346 .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2347 .await;
2348 let elapsed = start.elapsed();
2349
2350 assert!(result.is_err(), "Send should fail when client is closed");
2352 assert!(
2353 matches!(result, Err(crate::error::SendError::Closed)),
2354 "Send should return Closed error, was: {:?}",
2355 result
2356 );
2357
2358 assert!(
2360 elapsed < Duration::from_millis(100),
2361 "Send should fail fast without rate limiting, took {:?}",
2362 elapsed
2363 );
2364
2365 server.abort();
2366 }
2367
2368 #[rstest]
2369 #[tokio::test]
2370 async fn test_connect_rejects_config_without_message_handler() {
2371 let config = WebSocketConfig {
2375 url: "ws://127.0.0.1:9999".to_string(),
2376 headers: vec![],
2377 message_handler: None, heartbeat: None,
2379 heartbeat_msg: None,
2380 ping_handler: None,
2381 reconnect_timeout_ms: Some(1_000),
2382 reconnect_delay_initial_ms: Some(100),
2383 reconnect_delay_max_ms: Some(500),
2384 reconnect_backoff_factor: Some(1.5),
2385 reconnect_jitter_ms: Some(0),
2386 reconnect_max_attempts: None,
2387 };
2388
2389 let result = WebSocketClient::connect(config, None, vec![], None).await;
2390
2391 assert!(
2392 result.is_err(),
2393 "connect() should reject configs without message_handler"
2394 );
2395
2396 let err = result.unwrap_err();
2397 let err_msg = err.to_string();
2398 assert!(
2399 err_msg.contains("Handler mode requires config.message_handler"),
2400 "Error should mention missing message_handler, was: {err_msg}"
2401 );
2402 }
2403
2404 #[rstest]
2405 #[tokio::test]
2406 async fn test_client_without_handler_sets_stream_mode() {
2407 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2411 let port = listener.local_addr().unwrap().port();
2412
2413 let server = task::spawn(async move {
2414 if let Ok((stream, _)) = listener.accept().await
2416 && let Ok(ws) = accept_async(stream).await
2417 {
2418 drop(ws); }
2420 });
2421
2422 let config = WebSocketConfig {
2423 url: format!("ws://127.0.0.1:{port}"),
2424 headers: vec![],
2425 message_handler: None, heartbeat: None,
2427 heartbeat_msg: None,
2428 ping_handler: None,
2429 reconnect_timeout_ms: Some(1_000),
2430 reconnect_delay_initial_ms: Some(100),
2431 reconnect_delay_max_ms: Some(500),
2432 reconnect_backoff_factor: Some(1.5),
2433 reconnect_jitter_ms: Some(0),
2434 reconnect_max_attempts: None,
2435 };
2436
2437 let inner = WebSocketClientInner::connect_url(config).await.unwrap();
2439
2440 assert!(
2442 inner.is_stream_mode,
2443 "Client without handler should have is_stream_mode=true"
2444 );
2445
2446 server.abort();
2450 }
2451}