1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50 config::WebSocketConfig,
51 consts::{
52 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54 },
55 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60 RECONNECTED,
61 backoff::ExponentialBackoff,
62 error::SendError,
63 logging::{log_task_aborted, log_task_started, log_task_stopped},
64 mode::ConnectionMode,
65 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
/// Internal state of a single WebSocket connection: its configuration, the
/// handles of its background tasks and the reconnection bookkeeping.
pub struct WebSocketClientInner {
    /// Configuration the connection was created with (URL, headers, handlers, ...).
    config: WebSocketConfig,
    /// Handle of the task consuming incoming messages; `None` when no handler
    /// task was spawned for this connection.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the socket writer.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel used to submit `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the periodic heartbeat task, present only when a heartbeat
    /// interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state, stored as the `u8` encoding of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single `reconnect` attempt.
    reconnect_timeout: Duration,
    /// Backoff policy consulted by the controller after a failed reconnect.
    backoff: ExponentialBackoff,
    /// When `true`, `reconnect` refuses to run and moves the state to `Closed`.
    is_stream_mode: bool,
    /// Maximum number of consecutive reconnect attempts; `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    /// Consecutive reconnect attempts made since the last successful reconnect.
    reconnection_attempt_count: u32,
}
101
102impl WebSocketClientInner {
103 pub async fn new_with_writer(
109 config: WebSocketConfig,
110 writer: MessageWriter,
111 ) -> Result<Self, Error> {
112 install_cryptographic_provider();
113
114 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
115
116 let read_task = None;
118
119 let backoff = ExponentialBackoff::new(
120 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
121 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
122 config.reconnect_backoff_factor.unwrap_or(1.5),
123 config.reconnect_jitter_ms.unwrap_or(100),
124 true, )
126 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
127
128 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
129 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
130
131 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
132 Some(Self::spawn_heartbeat_task(
133 connection_mode.clone(),
134 heartbeat_interval,
135 config.heartbeat_msg.clone(),
136 writer_tx.clone(),
137 ))
138 } else {
139 None
140 };
141
142 Ok(Self {
143 config: config.clone(),
144 writer_tx,
145 connection_mode,
146 reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
147 heartbeat_task,
148 read_task,
149 write_task,
150 backoff,
151 is_stream_mode: true,
152 reconnect_max_attempts: config.reconnect_max_attempts,
153 reconnection_attempt_count: 0,
154 })
155 }
156
157 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
165 install_cryptographic_provider();
166
167 let WebSocketConfig {
168 url,
169 message_handler,
170 heartbeat,
171 headers,
172 heartbeat_msg,
173 ping_handler,
174 reconnect_timeout_ms,
175 reconnect_delay_initial_ms,
176 reconnect_delay_max_ms,
177 reconnect_backoff_factor,
178 reconnect_jitter_ms,
179 reconnect_max_attempts,
180 } = &config;
181
182 let is_stream_mode = message_handler.is_none();
184 let reconnect_max_attempts = *reconnect_max_attempts;
185
186 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
187
188 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
189
190 let read_task = if message_handler.is_some() {
191 Some(Self::spawn_message_handler_task(
192 connection_mode.clone(),
193 reader,
194 message_handler.as_ref(),
195 ping_handler.as_ref(),
196 ))
197 } else {
198 None
199 };
200
201 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
202 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
203
204 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
206 Self::spawn_heartbeat_task(
207 connection_mode.clone(),
208 *heartbeat_secs,
209 heartbeat_msg.clone(),
210 writer_tx.clone(),
211 )
212 });
213
214 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
215 let backoff = ExponentialBackoff::new(
216 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
217 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
218 reconnect_backoff_factor.unwrap_or(1.5),
219 reconnect_jitter_ms.unwrap_or(100),
220 true, )
222 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224 Ok(Self {
225 config,
226 read_task,
227 write_task,
228 writer_tx,
229 heartbeat_task,
230 connection_mode,
231 reconnect_timeout,
232 backoff,
233 is_stream_mode,
235 reconnect_max_attempts,
236 reconnection_attempt_count: 0,
237 })
238 }
239
240 #[inline]
250 #[cfg(not(feature = "turmoil"))]
251 pub async fn connect_with_server(
252 url: &str,
253 headers: Vec<(String, String)>,
254 ) -> Result<(MessageWriter, MessageReader), Error> {
255 let mut request = url.into_client_request()?;
256 let req_headers = request.headers_mut();
257
258 let mut header_names: Vec<HeaderName> = Vec::new();
259 for (key, val) in headers {
260 let header_value = HeaderValue::from_str(&val)?;
261 let header_name: HeaderName = key.parse()?;
262 header_names.push(header_name.clone());
263 req_headers.insert(header_name, header_value);
264 }
265
266 connect_async_with_config(request, None, true)
267 .await
268 .map(|resp| resp.0.split())
269 }
270
/// Opens a WebSocket connection to `url` over a manually established TCP
/// stream (the `turmoil` build), sending `headers` with the handshake
/// request, and returns the split (writer, reader) halves.
///
/// For the `wss` scheme the TCP stream is wrapped in a rustls TLS session
/// that trusts the `webpki_roots` certificate set. The port defaults to 443
/// for `wss` and 80 otherwise when the URL does not specify one.
///
/// # Errors
///
/// Returns an error if `url` cannot be turned into a client request or has
/// no host, if a header name or value is invalid, if the TCP or TLS
/// connection fails, or if the WebSocket handshake fails.
#[inline]
#[cfg(feature = "turmoil")]
pub async fn connect_with_server(
    url: &str,
    headers: Vec<(String, String)>,
) -> Result<(MessageWriter, MessageReader), Error> {
    use rustls::ClientConfig;
    use tokio_rustls::TlsConnector;

    let mut request = url.into_client_request()?;
    let req_headers = request.headers_mut();

    for (key, val) in headers {
        let header_value = HeaderValue::from_str(&val)?;
        let header_name: HeaderName = key.parse()?;
        req_headers.insert(header_name, header_value);
    }

    let uri = request.uri();
    let scheme = uri.scheme_str().unwrap_or("ws");
    let host = uri.host().ok_or_else(|| {
        Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
    })?;

    let port = uri
        .port_u16()
        .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });

    let addr = format!("{host}:{port}");

    let connector = crate::net::RealTcpConnector;
    let tcp_stream = connector.connect(&addr).await?;
    // Failing to set TCP_NODELAY is not fatal; the connection is still usable.
    if let Err(e) = tcp_stream.set_nodelay(true) {
        tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
    }

    let maybe_tls_stream = if scheme == "wss" {
        let mut root_store = rustls::RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

        let config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();

        let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
        let domain =
            rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
                Error::Io(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid DNS name: {e}"),
                ))
            })?;

        let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
        MaybeTlsStream::Rustls(tls_stream)
    } else {
        MaybeTlsStream::Plain(tcp_stream)
    };

    client_async(request, maybe_tls_stream)
        .await
        .map(|resp| resp.0.split())
}
353
/// Re-establishes the connection and swaps the new writer and reader in.
///
/// Stream-mode clients are never reconnected: the shared state is set to
/// `Closed` and `Ok(())` is returned. The attempt is also abandoned with
/// `Ok(())` whenever the shared state is found in the disconnect state, or
/// is no longer `Reconnect` when the final transition to `Active` is tried.
///
/// The whole attempt is bounded by `self.reconnect_timeout`.
///
/// # Errors
///
/// Returns an error if connecting fails, if the write task cannot be reached
/// or reports that it failed to flush its buffered messages, or if the
/// attempt exceeds `self.reconnect_timeout`.
pub async fn reconnect(&mut self) -> Result<(), Error> {
    tracing::debug!("Reconnecting");

    if self.is_stream_mode {
        tracing::warn!(
            "Auto-reconnect disabled for stream-based WebSocket client; \
            stream users must manually reconnect by creating a new connection"
        );
        self.connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
        return Ok(());
    }

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        tracing::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    tokio::time::timeout(self.reconnect_timeout, async {
        let (new_writer, reader) =
            Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

        // The state may have changed while the connection was being established.
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }

        // Hand the new writer to the write task; `rx` receives whether the
        // write task managed to flush the messages it buffered meanwhile.
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            tracing::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                tracing::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // Stop the previous read task (if still running) before replacing it.
        if let Some(ref read_task) = self.read_task.take()
            && !read_task.is_finished()
        {
            read_task.abort();
            log_task_aborted("read");
        }

        // Only move Reconnect -> Active; any other current state means another
        // party changed it during the attempt and must not be overwritten.
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            tracing::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        self.read_task = if self.config.message_handler.is_some() {
            Some(Self::spawn_message_handler_task(
                self.connection_mode.clone(),
                reader,
                self.config.message_handler.as_ref(),
                self.config.ping_handler.as_ref(),
            ))
        } else {
            None
        };

        tracing::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    // The outer `Err` is the timeout elapsing; the trailing `?` then unwraps
    // to the inner `Result` produced by the async block.
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
484
485 #[inline]
493 #[must_use]
494 pub fn is_alive(&self) -> bool {
495 match &self.read_task {
496 Some(read_task) => !read_task.is_finished(),
497 None => true, }
499 }
500
501 fn spawn_message_handler_task(
502 connection_state: Arc<AtomicU8>,
503 mut reader: MessageReader,
504 message_handler: Option<&MessageHandler>,
505 ping_handler: Option<&PingHandler>,
506 ) -> tokio::task::JoinHandle<()> {
507 tracing::debug!("Started message handler task 'read'");
508
509 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
510
511 let message_handler = message_handler.cloned();
513 let ping_handler = ping_handler.cloned();
514
515 tokio::task::spawn(async move {
516 loop {
517 if !ConnectionMode::from_atomic(&connection_state).is_active() {
518 break;
519 }
520
521 match tokio::time::timeout(check_interval, reader.next()).await {
522 Ok(Some(Ok(Message::Binary(data)))) => {
523 tracing::trace!("Received message <binary> {} bytes", data.len());
524 if let Some(ref handler) = message_handler {
525 handler(Message::Binary(data));
526 }
527 }
528 Ok(Some(Ok(Message::Text(data)))) => {
529 tracing::trace!("Received message: {data}");
530 if let Some(ref handler) = message_handler {
531 handler(Message::Text(data));
532 }
533 }
534 Ok(Some(Ok(Message::Ping(ping_data)))) => {
535 tracing::trace!("Received ping: {ping_data:?}");
536 if let Some(ref handler) = ping_handler {
537 handler(ping_data.to_vec());
538 }
539 }
540 Ok(Some(Ok(Message::Pong(_)))) => {
541 tracing::trace!("Received pong");
542 }
543 Ok(Some(Ok(Message::Close(_)))) => {
544 tracing::debug!("Received close message - terminating");
545 break;
546 }
547 Ok(Some(Ok(_))) => (),
548 Ok(Some(Err(e))) => {
549 tracing::error!("Received error message - terminating: {e}");
550 break;
551 }
552 Ok(None) => {
553 tracing::debug!("No message received - terminating");
554 break;
555 }
556 Err(_) => {
557 continue;
559 }
560 }
561 }
562 })
563 }
564
565 async fn drain_reconnect_buffer(
570 buffer: &mut VecDeque<Message>,
571 writer: &mut MessageWriter,
572 ) -> bool {
573 if buffer.is_empty() {
574 return false;
575 }
576
577 let initial_buffer_len = buffer.len();
578 tracing::info!(
579 "Sending {} buffered messages after reconnection",
580 initial_buffer_len
581 );
582
583 let mut send_error_occurred = false;
584
585 while let Some(buffered_msg) = buffer.front() {
586 let msg_to_send = buffered_msg.clone();
588
589 if let Err(e) = writer.send(msg_to_send).await {
590 tracing::error!(
591 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
592 buffer.len()
593 );
594 send_error_occurred = true;
595 break; }
597
598 buffer.pop_front();
600 }
601
602 if buffer.is_empty() {
603 tracing::info!(
604 "Successfully sent all {} buffered messages",
605 initial_buffer_len
606 );
607 }
608
609 send_error_occurred
610 }
611
/// Spawns the task that owns the socket writer and executes `WriterCommand`s
/// received on `writer_rx`.
///
/// While the shared state is `Reconnect`, outgoing messages are buffered and
/// flushed once a replacement writer arrives via `WriterCommand::Update`.
/// The task ends when the state becomes `Disconnect` or `Closed`, or when
/// the command channel is closed; the writer is closed on the way out.
fn spawn_write_task(
    connection_state: Arc<AtomicU8>,
    writer: MessageWriter,
    mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
) -> tokio::task::JoinHandle<()> {
    log_task_started("write");

    let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

    tokio::task::spawn(async move {
        let mut active_writer = writer;
        // Messages that could not be sent while the connection was down.
        let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();

        loop {
            // Terminal states end the task; buffered messages are dropped.
            match ConnectionMode::from_atomic(&connection_state) {
                ConnectionMode::Disconnect => {
                    if !reconnect_buffer.is_empty() {
                        tracing::warn!(
                            "Discarding {} buffered messages due to disconnect",
                            reconnect_buffer.len()
                        );
                        reconnect_buffer.clear();
                    }

                    // Best-effort close; the result is deliberately ignored.
                    _ = tokio::time::timeout(
                        Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                        active_writer.close(),
                    )
                    .await;
                    break;
                }
                ConnectionMode::Closed => {
                    if !reconnect_buffer.is_empty() {
                        tracing::warn!(
                            "Discarding {} buffered messages due to closed connection",
                            reconnect_buffer.len()
                        );
                        reconnect_buffer.clear();
                    }
                    break;
                }
                _ => {}
            }

            // Bounded wait so the connection state is re-checked periodically.
            match tokio::time::timeout(check_interval, writer_rx.recv()).await {
                Ok(Some(msg)) => {
                    // Re-read the state: it may have changed while waiting.
                    let mode = ConnectionMode::from_atomic(&connection_state);
                    if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                        break;
                    }

                    match msg {
                        WriterCommand::Update(new_writer, tx) => {
                            tracing::debug!("Received new writer");

                            tokio::time::sleep(Duration::from_millis(100)).await;

                            // Best-effort close of the old writer before swapping.
                            _ = tokio::time::timeout(
                                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                                active_writer.close(),
                            )
                            .await;

                            active_writer = new_writer;
                            tracing::debug!("Updated writer");

                            let send_error = Self::drain_reconnect_buffer(
                                &mut reconnect_buffer,
                                &mut active_writer,
                            )
                            .await;

                            // Report success (`true`) or failure (`false`) of the drain.
                            if let Err(e) = tx.send(!send_error) {
                                tracing::error!(
                                    "Failed to report drain status to controller: {e:?}"
                                );
                            }
                        }
                        WriterCommand::Send(msg) if mode.is_reconnect() => {
                            tracing::debug!(
                                "Buffering message during reconnection (buffer size: {})",
                                reconnect_buffer.len() + 1
                            );
                            reconnect_buffer.push_back(msg);
                        }
                        WriterCommand::Send(msg) => {
                            // `send` consumes its argument, so a clone is sent and
                            // the original is kept for buffering on failure.
                            if let Err(e) = active_writer.send(msg.clone()).await {
                                tracing::error!("Failed to send message: {e}");
                                tracing::warn!("Writer triggering reconnect");
                                reconnect_buffer.push_back(msg);
                                connection_state
                                    .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                            }
                        }
                    }
                }
                Ok(None) => {
                    tracing::debug!("Writer channel closed, terminating writer task");
                    break;
                }
                Err(_) => {
                    // Timed out waiting for a command; loop to re-check the state.
                    continue;
                }
            }
        }

        // Final best-effort close on every exit path. On the `Disconnect`
        // path this is a second close attempt on the same writer.
        _ = tokio::time::timeout(
            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
            active_writer.close(),
        )
        .await;

        log_task_stopped("write");
    })
}
743
744 fn spawn_heartbeat_task(
745 connection_state: Arc<AtomicU8>,
746 heartbeat_secs: u64,
747 message: Option<String>,
748 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
749 ) -> tokio::task::JoinHandle<()> {
750 log_task_started("heartbeat");
751
752 tokio::task::spawn(async move {
753 let interval = Duration::from_secs(heartbeat_secs);
754
755 loop {
756 tokio::time::sleep(interval).await;
757
758 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
759 ConnectionMode::Active => {
760 let msg = match &message {
761 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
762 None => WriterCommand::Send(Message::Ping(vec![].into())),
763 };
764
765 match writer_tx.send(msg) {
766 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
767 Err(e) => {
768 tracing::error!("Failed to send heartbeat to writer task: {e}");
769 }
770 }
771 }
772 ConnectionMode::Reconnect => continue,
773 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
774 }
775 }
776
777 log_task_stopped("heartbeat");
778 })
779 }
780}
781
782impl Drop for WebSocketClientInner {
783 fn drop(&mut self) {
784 self.clean_drop();
786 }
787}
788
789impl CleanDrop for WebSocketClientInner {
791 fn clean_drop(&mut self) {
792 if let Some(ref read_task) = self.read_task.take()
793 && !read_task.is_finished()
794 {
795 read_task.abort();
796 log_task_aborted("read");
797 }
798
799 if !self.write_task.is_finished() {
800 self.write_task.abort();
801 log_task_aborted("write");
802 }
803
804 if let Some(ref handle) = self.heartbeat_task.take()
805 && !handle.is_finished()
806 {
807 handle.abort();
808 log_task_aborted("heartbeat");
809 }
810
811 self.config.message_handler = None;
813 self.config.ping_handler = None;
814 }
815}
816
817impl Debug for WebSocketClientInner {
818 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819 f.debug_struct("WebSocketClientInner")
820 .field("config", &self.config)
821 .field(
822 "connection_mode",
823 &ConnectionMode::from_atomic(&self.connection_mode),
824 )
825 .field("reconnect_timeout", &self.reconnect_timeout)
826 .field("is_stream_mode", &self.is_stream_mode)
827 .finish()
828 }
829}
830
/// Public WebSocket client handle.
///
/// Sending goes through a rate limiter and a channel to the write task; a
/// background controller task manages disconnects and reconnects.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task that owns the `WebSocketClientInner`.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection state shared with the background tasks, stored as the `u8`
    /// encoding of `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time a send operation waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Keyed rate limiter awaited before text and binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Channel used to submit `WriterCommand`s to the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
846
847impl Debug for WebSocketClient {
848 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849 f.debug_struct(stringify!(WebSocketClient)).finish()
850 }
851}
852
853impl WebSocketClient {
854 #[allow(clippy::too_many_arguments)]
870 pub async fn connect_stream(
871 config: WebSocketConfig,
872 keyed_quotas: Vec<(String, Quota)>,
873 default_quota: Option<Quota>,
874 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
875 ) -> Result<(MessageReader, Self), Error> {
876 install_cryptographic_provider();
877
878 let (writer, reader) =
880 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
881
882 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
884
885 let connection_mode = inner.connection_mode.clone();
886 let reconnect_timeout = inner.reconnect_timeout;
887 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
888 let writer_tx = inner.writer_tx.clone();
889
890 let controller_task =
891 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
892
893 Ok((
894 reader,
895 Self {
896 controller_task,
897 connection_mode,
898 reconnect_timeout,
899 rate_limiter,
900 writer_tx,
901 },
902 ))
903 }
904
905 pub async fn connect(
922 config: WebSocketConfig,
923 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
924 keyed_quotas: Vec<(String, Quota)>,
925 default_quota: Option<Quota>,
926 ) -> Result<Self, Error> {
927 if config.message_handler.is_none() {
929 return Err(Error::Io(std::io::Error::new(
930 std::io::ErrorKind::InvalidInput,
931 "Handler mode requires config.message_handler to be set. Use connect_stream() for stream mode without a handler.",
932 )));
933 }
934
935 tracing::debug!("Connecting");
936 let inner = WebSocketClientInner::connect_url(config).await?;
937 let connection_mode = inner.connection_mode.clone();
938 let writer_tx = inner.writer_tx.clone();
939 let reconnect_timeout = inner.reconnect_timeout;
940
941 let controller_task =
942 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
943
944 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
945
946 Ok(Self {
947 controller_task,
948 connection_mode,
949 reconnect_timeout,
950 rate_limiter,
951 writer_tx,
952 })
953 }
954
/// Returns the current connection mode, decoded from the shared atomic state.
#[must_use]
pub fn connection_mode(&self) -> ConnectionMode {
    ConnectionMode::from_atomic(&self.connection_mode)
}
960
961 #[must_use]
966 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
967 Arc::clone(&self.connection_mode)
968 }
969
970 #[inline]
975 #[must_use]
976 pub fn is_active(&self) -> bool {
977 self.connection_mode().is_active()
978 }
979
/// Returns `true` once the controller task has finished.
///
/// Note that this inspects the task handle, not the connection mode.
#[must_use]
pub fn is_disconnected(&self) -> bool {
    self.controller_task.is_finished()
}
985
986 #[inline]
991 #[must_use]
992 pub fn is_reconnecting(&self) -> bool {
993 self.connection_mode().is_reconnect()
994 }
995
996 #[inline]
1000 #[must_use]
1001 pub fn is_disconnecting(&self) -> bool {
1002 self.connection_mode().is_disconnect()
1003 }
1004
1005 #[inline]
1011 #[must_use]
1012 pub fn is_closed(&self) -> bool {
1013 self.connection_mode().is_closed()
1014 }
1015
1016 async fn wait_for_active(&self) -> Result<(), SendError> {
1020 if self.is_closed() {
1021 return Err(SendError::Closed);
1022 }
1023
1024 let timeout = self.reconnect_timeout;
1025 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1026
1027 if !self.is_active() {
1028 tracing::debug!("Waiting for client to become ACTIVE before sending...");
1029
1030 let inner = tokio::time::timeout(timeout, async {
1031 loop {
1032 if self.is_active() {
1033 return Ok(());
1034 }
1035 if matches!(
1036 self.connection_mode(),
1037 ConnectionMode::Disconnect | ConnectionMode::Closed
1038 ) {
1039 return Err(());
1040 }
1041 tokio::time::sleep(check_interval).await;
1042 }
1043 })
1044 .await
1045 .map_err(|_| SendError::Timeout)?;
1046 inner.map_err(|()| SendError::Closed)?;
1047 }
1048
1049 Ok(())
1050 }
1051
1052 pub async fn disconnect(&self) {
1057 tracing::debug!("Disconnecting");
1058 self.connection_mode
1059 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1060
1061 if let Ok(()) =
1062 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1063 while !self.is_disconnected() {
1064 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1065 .await;
1066 }
1067
1068 if !self.controller_task.is_finished() {
1069 self.controller_task.abort();
1070 log_task_aborted("controller");
1071 }
1072 })
1073 .await
1074 {
1075 tracing::debug!("Controller task finished");
1076 } else {
1077 tracing::error!("Timeout waiting for controller task to finish");
1078 if !self.controller_task.is_finished() {
1079 self.controller_task.abort();
1080 log_task_aborted("controller");
1081 }
1082 }
1083 }
1084
1085 #[allow(unused_variables)]
1091 pub async fn send_text(
1092 &self,
1093 data: String,
1094 keys: Option<Vec<String>>,
1095 ) -> Result<(), SendError> {
1096 if self.is_closed() || self.is_disconnecting() {
1098 return Err(SendError::Closed);
1099 }
1100
1101 self.rate_limiter.await_keys_ready(keys).await;
1102 self.wait_for_active().await?;
1103
1104 tracing::trace!("Sending text: {data:?}");
1105
1106 let msg = Message::Text(data.into());
1107 self.writer_tx
1108 .send(WriterCommand::Send(msg))
1109 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1110 }
1111
1112 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1118 self.wait_for_active().await?;
1119
1120 tracing::trace!("Sending pong frame ({} bytes)", data.len());
1121
1122 let msg = Message::Pong(data.into());
1123 self.writer_tx
1124 .send(WriterCommand::Send(msg))
1125 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1126 }
1127
1128 #[allow(unused_variables)]
1134 pub async fn send_bytes(
1135 &self,
1136 data: Vec<u8>,
1137 keys: Option<Vec<String>>,
1138 ) -> Result<(), SendError> {
1139 if self.is_closed() || self.is_disconnecting() {
1141 return Err(SendError::Closed);
1142 }
1143
1144 self.rate_limiter.await_keys_ready(keys).await;
1145 self.wait_for_active().await?;
1146
1147 tracing::trace!("Sending bytes: {data:?}");
1148
1149 let msg = Message::Binary(data.into());
1150 self.writer_tx
1151 .send(WriterCommand::Send(msg))
1152 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1153 }
1154
1155 pub async fn send_close_message(&self) -> Result<(), SendError> {
1161 self.wait_for_active().await?;
1162
1163 let msg = Message::Close(None);
1164 self.writer_tx
1165 .send(WriterCommand::Send(msg))
1166 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1167 }
1168
/// Spawns the controller task, which takes ownership of `inner` and drives
/// the connection lifecycle.
///
/// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task:
/// - on the disconnect state, aborts the read and heartbeat tasks and exits;
/// - on `Active` with a finished read task, moves the state to `Reconnect`;
/// - on `Reconnect`, attempts a reconnect (bounded by
///   `inner.reconnect_max_attempts`), notifying the message handler and
///   `post_reconnection` on success and backing off on failure.
///
/// On exit the shared state is set to `Closed`.
fn spawn_controller_task(
    mut inner: WebSocketClientInner,
    connection_mode: Arc<AtomicU8>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        log_task_started("controller");

        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        // NOTE(review): there is no branch for the `Closed` state inside this
        // loop, so once the state is `Closed` (e.g. set by a stream-mode
        // `reconnect`) the task appears to keep polling until a disconnect is
        // requested or the task is aborted - confirm this is intended.
        loop {
            tokio::time::sleep(check_interval).await;
            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                tracing::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    // Short delay before the remaining tasks are aborted.
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if let Some(task) = &inner.read_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                tracing::debug!("Closed");
                break;
            }

            if mode.is_active() && !inner.is_alive() {
                // Only move Active -> Reconnect; a concurrent state change
                // (e.g. a disconnect request) must not be overwritten.
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnection_attempt_count >= max_attempts
                {
                    tracing::error!(
                        "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                        max_attempts
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    break;
                }

                inner.reconnection_attempt_count += 1;
                tracing::debug!(
                    "Reconnection attempt {} of {}",
                    inner.reconnection_attempt_count,
                    inner
                        .reconnect_max_attempts
                        .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                );

                match inner.reconnect().await {
                    Ok(()) => {
                        inner.backoff.reset();
                        inner.reconnection_attempt_count = 0;

                        // `reconnect` also returns `Ok(())` when it was abandoned,
                        // so only notify if the client actually became active.
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = inner.config.message_handler {
                                let reconnected_msg =
                                    Message::Text(RECONNECTED.to_string().into());
                                handler(reconnected_msg);
                                tracing::debug!("Sent reconnected message to handler");
                            }

                            if let Some(ref callback) = post_reconnection {
                                callback();
                                tracing::debug!("Called `post_reconnection` handler");
                            }

                            tracing::debug!("Reconnected successfully");
                        } else {
                            tracing::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Err(e) => {
                        let duration = inner.backoff.next_duration();
                        tracing::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnection_attempt_count
                        );
                        if !duration.is_zero() {
                            tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                        }
                        tokio::time::sleep(duration).await;
                    }
                }
            }
        }
        // Whatever ended the loop, leave the shared state as `Closed`.
        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
1300}
1301
1302impl Drop for WebSocketClient {
1304 fn drop(&mut self) {
1305 if !self.controller_task.is_finished() {
1306 self.controller_task.abort();
1307 log_task_aborted("controller");
1308 }
1309 }
1310}
1311
1312#[cfg(test)]
1313#[cfg(not(feature = "turmoil"))]
1314#[cfg(target_os = "linux")] mod tests {
1316 use std::{num::NonZeroU32, sync::Arc};
1317
1318 use futures_util::{SinkExt, StreamExt};
1319 use tokio::{
1320 net::TcpListener,
1321 task::{self, JoinHandle},
1322 };
1323 use tokio_tungstenite::{
1324 accept_hdr_async,
1325 tungstenite::{
1326 handshake::server::{self, Callback},
1327 http::HeaderValue,
1328 },
1329 };
1330
1331 use crate::{
1332 ratelimiter::quota::Quota,
1333 websocket::{WebSocketClient, WebSocketConfig},
1334 };
1335
/// Local WebSocket echo server used by the tests.
struct TestServer {
    /// Handle of the accept-loop task; aborted when the server is dropped.
    task: JoinHandle<()>,
    /// Port the listener is bound to on 127.0.0.1.
    port: u16,
}
1340
/// Handshake callback asserting that the client sent an expected header.
#[derive(Debug, Clone)]
struct TestCallback {
    /// Name of the header that must be present in the handshake request.
    key: String,
    /// Value the header is expected to carry.
    value: HeaderValue,
}
1346
1347 impl Callback for TestCallback {
1348 #[allow(clippy::panic_in_result_fn)]
1349 fn on_request(
1350 self,
1351 request: &server::Request,
1352 response: server::Response,
1353 ) -> Result<server::Response, server::ErrorResponse> {
1354 let _ = response;
1355 let value = request.headers().get(&self.key);
1356 assert!(value.is_some());
1357
1358 if let Some(value) = request.headers().get(&self.key) {
1359 assert_eq!(value, self.value);
1360 }
1361
1362 Ok(response)
1363 }
1364 }
1365
1366 impl TestServer {
1367 async fn setup() -> Self {
1368 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1369 let port = TcpListener::local_addr(&server).unwrap().port();
1370
1371 let header_key = "test".to_string();
1372 let header_value = "test".to_string();
1373
1374 let test_call_back = TestCallback {
1375 key: header_key,
1376 value: HeaderValue::from_str(&header_value).unwrap(),
1377 };
1378
1379 let task = task::spawn(async move {
1380 loop {
1382 let (conn, _) = server.accept().await.unwrap();
1383 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1384 .await
1385 .unwrap();
1386
1387 task::spawn(async move {
1388 while let Some(Ok(msg)) = websocket.next().await {
1389 match msg {
1390 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1391 if txt == "close-now" =>
1392 {
1393 tracing::debug!("Forcibly closing from server side");
1394 let _ = websocket.close(None).await;
1396 break;
1397 }
1398 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1400 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1401 if websocket.send(msg).await.is_err() {
1402 break;
1403 }
1404 }
1405 tokio_tungstenite::tungstenite::protocol::Message::Close(
1407 _frame,
1408 ) => {
1409 let _ = websocket.close(None).await;
1410 break;
1411 }
1412 _ => {}
1414 }
1415 }
1416 });
1417 }
1418 });
1419
1420 Self { task, port }
1421 }
1422 }
1423
1424 impl Drop for TestServer {
1425 fn drop(&mut self) {
1426 self.task.abort();
1427 }
1428 }
1429
1430 async fn setup_test_client(port: u16) -> WebSocketClient {
1431 let config = WebSocketConfig {
1432 url: format!("ws://127.0.0.1:{port}"),
1433 headers: vec![("test".into(), "test".into())],
1434 message_handler: Some(Arc::new(|_| {})),
1435 heartbeat: None,
1436 heartbeat_msg: None,
1437 ping_handler: None,
1438 reconnect_timeout_ms: None,
1439 reconnect_delay_initial_ms: None,
1440 reconnect_backoff_factor: None,
1441 reconnect_delay_max_ms: None,
1442 reconnect_jitter_ms: None,
1443 reconnect_max_attempts: None,
1444 };
1445 WebSocketClient::connect(config, None, vec![], None)
1446 .await
1447 .expect("Failed to connect")
1448 }
1449
1450 #[tokio::test]
1451 async fn test_websocket_basic() {
1452 let server = TestServer::setup().await;
1453 let client = setup_test_client(server.port).await;
1454
1455 assert!(!client.is_disconnected());
1456
1457 client.disconnect().await;
1458 assert!(client.is_disconnected());
1459 }
1460
1461 #[tokio::test]
1462 async fn test_websocket_heartbeat() {
1463 let server = TestServer::setup().await;
1464 let client = setup_test_client(server.port).await;
1465
1466 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1468
1469 client.disconnect().await;
1471 assert!(client.is_disconnected());
1472 }
1473
1474 #[tokio::test]
1475 async fn test_websocket_reconnect_exhausted() {
1476 let config = WebSocketConfig {
1477 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1479 message_handler: Some(Arc::new(|_| {})),
1480 heartbeat: None,
1481 heartbeat_msg: None,
1482 ping_handler: None,
1483 reconnect_timeout_ms: None,
1484 reconnect_delay_initial_ms: None,
1485 reconnect_backoff_factor: None,
1486 reconnect_delay_max_ms: None,
1487 reconnect_jitter_ms: None,
1488 reconnect_max_attempts: None,
1489 };
1490 let res = WebSocketClient::connect(config, None, vec![], None).await;
1491 assert!(res.is_err(), "Should fail quickly with no server");
1492 }
1493
1494 #[tokio::test]
1495 async fn test_websocket_forced_close_reconnect() {
1496 let server = TestServer::setup().await;
1497 let client = setup_test_client(server.port).await;
1498
1499 client.send_text("Hello".into(), None).await.unwrap();
1501
1502 client.send_text("close-now".into(), None).await.unwrap();
1504
1505 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1507
1508 assert!(!client.is_disconnected());
1510
1511 client.disconnect().await;
1513 assert!(client.is_disconnected());
1514 }
1515
1516 #[tokio::test]
1517 async fn test_rate_limiter() {
1518 let server = TestServer::setup().await;
1519 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1520
1521 let config = WebSocketConfig {
1522 url: format!("ws://127.0.0.1:{}", server.port),
1523 headers: vec![("test".into(), "test".into())],
1524 message_handler: Some(Arc::new(|_| {})),
1525 heartbeat: None,
1526 heartbeat_msg: None,
1527 ping_handler: None,
1528 reconnect_timeout_ms: None,
1529 reconnect_delay_initial_ms: None,
1530 reconnect_backoff_factor: None,
1531 reconnect_delay_max_ms: None,
1532 reconnect_jitter_ms: None,
1533 reconnect_max_attempts: None,
1534 };
1535
1536 let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1537 .await
1538 .unwrap();
1539
1540 client.send_text("test1".into(), None).await.unwrap();
1542 client.send_text("test2".into(), None).await.unwrap();
1543
1544 client.send_text("test3".into(), None).await.unwrap();
1546
1547 client.disconnect().await;
1549 assert!(client.is_disconnected());
1550 }
1551
1552 #[tokio::test]
1553 async fn test_concurrent_writers() {
1554 let server = TestServer::setup().await;
1555 let client = Arc::new(setup_test_client(server.port).await);
1556
1557 let mut handles = vec![];
1558 for i in 0..10 {
1559 let client = client.clone();
1560 handles.push(task::spawn(async move {
1561 client.send_text(format!("test{i}"), None).await.unwrap();
1562 }));
1563 }
1564
1565 for handle in handles {
1566 handle.await.unwrap();
1567 }
1568
1569 client.disconnect().await;
1571 assert!(client.is_disconnected());
1572 }
1573}
1574
1575#[cfg(test)]
1576#[cfg(not(feature = "turmoil"))]
1577mod rust_tests {
1578 use futures_util::StreamExt;
1579 use rstest::rstest;
1580 use tokio::{
1581 net::TcpListener,
1582 task,
1583 time::{Duration, sleep},
1584 };
1585 use tokio_tungstenite::accept_async;
1586
1587 use super::*;
1588 use crate::websocket::types::channel_message_handler;
1589
1590 #[rstest]
1591 #[tokio::test]
1592 async fn test_reconnect_then_disconnect() {
1593 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1595 let port = listener.local_addr().unwrap().port();
1596
1597 let server = task::spawn(async move {
1599 let (stream, _) = listener.accept().await.unwrap();
1600 let ws = accept_async(stream).await.unwrap();
1601 drop(ws);
1602 sleep(Duration::from_secs(1)).await;
1604 });
1605
1606 let (handler, _rx) = channel_message_handler();
1608
1609 let config = WebSocketConfig {
1611 url: format!("ws://127.0.0.1:{port}"),
1612 headers: vec![],
1613 message_handler: Some(handler),
1614 heartbeat: None,
1615 heartbeat_msg: None,
1616 ping_handler: None,
1617 reconnect_timeout_ms: Some(1_000),
1618 reconnect_delay_initial_ms: Some(50),
1619 reconnect_delay_max_ms: Some(100),
1620 reconnect_backoff_factor: Some(1.0),
1621 reconnect_jitter_ms: Some(0),
1622 reconnect_max_attempts: None,
1623 };
1624
1625 let client = WebSocketClient::connect(config, None, vec![], None)
1627 .await
1628 .unwrap();
1629
1630 sleep(Duration::from_millis(100)).await;
1632 client.disconnect().await;
1634 assert!(client.is_disconnected());
1635 server.abort();
1636 }
1637
1638 #[rstest]
1639 #[tokio::test]
1640 async fn test_reconnect_state_flips_when_reader_stops() {
1641 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1643 let port = listener.local_addr().unwrap().port();
1644
1645 let server = task::spawn(async move {
1646 if let Ok((stream, _)) = listener.accept().await
1647 && let Ok(ws) = accept_async(stream).await
1648 {
1649 drop(ws);
1650 }
1651 sleep(Duration::from_millis(50)).await;
1652 });
1653
1654 let (handler, _rx) = channel_message_handler();
1655
1656 let config = WebSocketConfig {
1657 url: format!("ws://127.0.0.1:{port}"),
1658 headers: vec![],
1659 message_handler: Some(handler),
1660 heartbeat: None,
1661 heartbeat_msg: None,
1662 ping_handler: None,
1663 reconnect_timeout_ms: Some(1_000),
1664 reconnect_delay_initial_ms: Some(50),
1665 reconnect_delay_max_ms: Some(100),
1666 reconnect_backoff_factor: Some(1.0),
1667 reconnect_jitter_ms: Some(0),
1668 reconnect_max_attempts: None,
1669 };
1670
1671 let client = WebSocketClient::connect(config, None, vec![], None)
1672 .await
1673 .unwrap();
1674
1675 tokio::time::timeout(Duration::from_secs(2), async {
1676 loop {
1677 if client.is_reconnecting() {
1678 break;
1679 }
1680 tokio::time::sleep(Duration::from_millis(10)).await;
1681 }
1682 })
1683 .await
1684 .expect("client did not enter RECONNECT state");
1685
1686 client.disconnect().await;
1687 server.abort();
1688 }
1689
1690 #[rstest]
1691 #[tokio::test]
1692 async fn test_stream_mode_disables_auto_reconnect() {
1693 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1696 let port = listener.local_addr().unwrap().port();
1697
1698 let server = task::spawn(async move {
1699 if let Ok((stream, _)) = listener.accept().await
1700 && let Ok(_ws) = accept_async(stream).await
1701 {
1702 sleep(Duration::from_millis(100)).await;
1704 }
1705 });
1706
1707 let config = WebSocketConfig {
1708 url: format!("ws://127.0.0.1:{port}"),
1709 headers: vec![],
1710 message_handler: None, heartbeat: None,
1712 heartbeat_msg: None,
1713 ping_handler: None,
1714 reconnect_timeout_ms: Some(1_000),
1715 reconnect_delay_initial_ms: Some(50),
1716 reconnect_delay_max_ms: Some(100),
1717 reconnect_backoff_factor: Some(1.0),
1718 reconnect_jitter_ms: Some(0),
1719 reconnect_max_attempts: None,
1720 };
1721
1722 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1724 .await
1725 .unwrap();
1726
1727 server.abort();
1735 }
1736
1737 #[rstest]
1738 #[tokio::test]
1739 async fn test_message_handler_mode_allows_auto_reconnect() {
1740 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1742 let port = listener.local_addr().unwrap().port();
1743
1744 let server = task::spawn(async move {
1745 if let Ok((stream, _)) = listener.accept().await
1747 && let Ok(ws) = accept_async(stream).await
1748 {
1749 drop(ws);
1750 }
1751 sleep(Duration::from_millis(50)).await;
1752 });
1753
1754 let (handler, _rx) = channel_message_handler();
1755
1756 let config = WebSocketConfig {
1757 url: format!("ws://127.0.0.1:{port}"),
1758 headers: vec![],
1759 message_handler: Some(handler), heartbeat: None,
1761 heartbeat_msg: None,
1762 ping_handler: None,
1763 reconnect_timeout_ms: Some(1_000),
1764 reconnect_delay_initial_ms: Some(50),
1765 reconnect_delay_max_ms: Some(100),
1766 reconnect_backoff_factor: Some(1.0),
1767 reconnect_jitter_ms: Some(0),
1768 reconnect_max_attempts: None,
1769 };
1770
1771 let client = WebSocketClient::connect(config, None, vec![], None)
1772 .await
1773 .unwrap();
1774
1775 tokio::time::timeout(Duration::from_secs(2), async {
1777 loop {
1778 if client.is_reconnecting() || client.is_closed() {
1779 break;
1780 }
1781 tokio::time::sleep(Duration::from_millis(10)).await;
1782 }
1783 })
1784 .await
1785 .expect("client should attempt reconnection or close");
1786
1787 assert!(
1790 client.is_reconnecting() || client.is_closed(),
1791 "Client with message handler should attempt reconnection"
1792 );
1793
1794 client.disconnect().await;
1795 server.abort();
1796 }
1797
1798 #[rstest]
1799 #[tokio::test]
1800 async fn test_handler_mode_reconnect_with_new_connection() {
1801 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1803 let port = listener.local_addr().unwrap().port();
1804
1805 let server = task::spawn(async move {
1806 if let Ok((stream, _)) = listener.accept().await
1808 && let Ok(ws) = accept_async(stream).await
1809 {
1810 drop(ws);
1811 }
1812
1813 sleep(Duration::from_millis(100)).await;
1815
1816 if let Ok((stream, _)) = listener.accept().await
1818 && let Ok(mut ws) = accept_async(stream).await
1819 {
1820 use futures_util::SinkExt;
1821 let _ = ws
1822 .send(Message::Text("reconnected".to_string().into()))
1823 .await;
1824 sleep(Duration::from_secs(1)).await;
1825 }
1826 });
1827
1828 let (handler, mut rx) = channel_message_handler();
1829
1830 let config = WebSocketConfig {
1831 url: format!("ws://127.0.0.1:{port}"),
1832 headers: vec![],
1833 message_handler: Some(handler),
1834 heartbeat: None,
1835 heartbeat_msg: None,
1836 ping_handler: None,
1837 reconnect_timeout_ms: Some(2_000),
1838 reconnect_delay_initial_ms: Some(50),
1839 reconnect_delay_max_ms: Some(200),
1840 reconnect_backoff_factor: Some(1.5),
1841 reconnect_jitter_ms: Some(10),
1842 reconnect_max_attempts: None,
1843 };
1844
1845 let client = WebSocketClient::connect(config, None, vec![], None)
1846 .await
1847 .unwrap();
1848
1849 let result = tokio::time::timeout(Duration::from_secs(5), async {
1851 loop {
1852 if let Ok(msg) = rx.try_recv()
1853 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1854 {
1855 return true;
1856 }
1857 tokio::time::sleep(Duration::from_millis(10)).await;
1858 }
1859 })
1860 .await;
1861
1862 assert!(
1863 result.is_ok(),
1864 "Should receive message after reconnection within timeout"
1865 );
1866
1867 client.disconnect().await;
1868 server.abort();
1869 }
1870
1871 #[rstest]
1872 #[tokio::test]
1873 async fn test_stream_mode_no_auto_reconnect() {
1874 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1877 let port = listener.local_addr().unwrap().port();
1878
1879 let server = task::spawn(async move {
1880 if let Ok((stream, _)) = listener.accept().await
1882 && let Ok(mut ws) = accept_async(stream).await
1883 {
1884 use futures_util::SinkExt;
1885 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1886 sleep(Duration::from_millis(50)).await;
1887 }
1889 });
1890
1891 let config = WebSocketConfig {
1892 url: format!("ws://127.0.0.1:{port}"),
1893 headers: vec![],
1894 message_handler: None, heartbeat: None,
1896 heartbeat_msg: None,
1897 ping_handler: None,
1898 reconnect_timeout_ms: Some(1_000),
1899 reconnect_delay_initial_ms: Some(50),
1900 reconnect_delay_max_ms: Some(100),
1901 reconnect_backoff_factor: Some(1.0),
1902 reconnect_jitter_ms: Some(0),
1903 reconnect_max_attempts: None,
1904 };
1905
1906 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1907 .await
1908 .unwrap();
1909
1910 assert!(client.is_active(), "Client should start as active");
1912
1913 let msg = reader.next().await;
1915 assert!(
1916 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1917 "Should receive initial message"
1918 );
1919
1920 while let Some(msg) = reader.next().await {
1922 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1923 break;
1924 }
1925 }
1926
1927 sleep(Duration::from_millis(200)).await;
1930
1931 assert!(
1934 client.is_active() || client.is_closed(),
1935 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1936 );
1937 assert!(
1938 !client.is_reconnecting(),
1939 "Stream mode client should never attempt reconnection"
1940 );
1941
1942 client.disconnect().await;
1943 server.abort();
1944 }
1945
1946 #[rstest]
1947 #[tokio::test]
1948 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1949 use nautilus_common::testing::wait_until_async;
1952
1953 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1954 let port = listener.local_addr().unwrap().port();
1955
1956 let server = task::spawn(async move {
1957 if let Ok((stream, _)) = listener.accept().await
1959 && let Ok(ws) = accept_async(stream).await
1960 {
1961 drop(ws);
1962 }
1963 sleep(Duration::from_secs(60)).await;
1965 });
1966
1967 let (handler, _rx) = channel_message_handler();
1968
1969 let config = WebSocketConfig {
1971 url: format!("ws://127.0.0.1:{port}"),
1972 headers: vec![],
1973 message_handler: Some(handler),
1974 heartbeat: None,
1975 heartbeat_msg: None,
1976 ping_handler: None,
1977 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1979 reconnect_delay_max_ms: Some(100),
1980 reconnect_backoff_factor: Some(1.0),
1981 reconnect_jitter_ms: Some(0),
1982 reconnect_max_attempts: None,
1983 };
1984
1985 let client = WebSocketClient::connect(config, None, vec![], None)
1986 .await
1987 .unwrap();
1988
1989 wait_until_async(
1991 || async { client.is_reconnecting() },
1992 Duration::from_secs(3),
1993 )
1994 .await;
1995
1996 let start = std::time::Instant::now();
1998 let send_result = client.send_text("test".to_string(), None).await;
1999 let elapsed = start.elapsed();
2000
2001 assert!(
2002 send_result.is_err(),
2003 "Send should fail when client stuck in RECONNECT"
2004 );
2005 assert!(
2006 matches!(send_result, Err(crate::error::SendError::Timeout)),
2007 "Send should return Timeout error, was: {send_result:?}"
2008 );
2009 assert!(
2012 elapsed >= Duration::from_millis(1800),
2013 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2014 );
2015
2016 client.disconnect().await;
2017 server.abort();
2018 }
2019
2020 #[rstest]
2021 #[tokio::test]
2022 async fn test_send_waits_during_reconnection() {
2023 use nautilus_common::testing::wait_until_async;
2025
2026 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2027 let port = listener.local_addr().unwrap().port();
2028
2029 let server = task::spawn(async move {
2030 if let Ok((stream, _)) = listener.accept().await
2032 && let Ok(ws) = accept_async(stream).await
2033 {
2034 drop(ws);
2035 }
2036
2037 sleep(Duration::from_millis(500)).await;
2039
2040 if let Ok((stream, _)) = listener.accept().await
2042 && let Ok(mut ws) = accept_async(stream).await
2043 {
2044 while let Some(Ok(msg)) = ws.next().await {
2046 if ws.send(msg).await.is_err() {
2047 break;
2048 }
2049 }
2050 }
2051 });
2052
2053 let (handler, _rx) = channel_message_handler();
2054
2055 let config = WebSocketConfig {
2056 url: format!("ws://127.0.0.1:{port}"),
2057 headers: vec![],
2058 message_handler: Some(handler),
2059 heartbeat: None,
2060 heartbeat_msg: None,
2061 ping_handler: None,
2062 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2064 reconnect_delay_max_ms: Some(200),
2065 reconnect_backoff_factor: Some(1.0),
2066 reconnect_jitter_ms: Some(0),
2067 reconnect_max_attempts: None,
2068 };
2069
2070 let client = WebSocketClient::connect(config, None, vec![], None)
2071 .await
2072 .unwrap();
2073
2074 wait_until_async(
2076 || async { client.is_reconnecting() },
2077 Duration::from_secs(2),
2078 )
2079 .await;
2080
2081 let send_result = tokio::time::timeout(
2083 Duration::from_secs(3),
2084 client.send_text("test_message".to_string(), None),
2085 )
2086 .await;
2087
2088 assert!(
2089 send_result.is_ok() && send_result.unwrap().is_ok(),
2090 "Send should succeed after waiting for reconnection"
2091 );
2092
2093 client.disconnect().await;
2094 server.abort();
2095 }
2096
2097 #[rstest]
2098 #[tokio::test]
2099 async fn test_rate_limiter_before_active_wait() {
2100 use std::{num::NonZeroU32, sync::Arc};
2105
2106 use nautilus_common::testing::wait_until_async;
2107
2108 use crate::ratelimiter::quota::Quota;
2109
2110 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2111 let port = listener.local_addr().unwrap().port();
2112
2113 let server = task::spawn(async move {
2114 if let Ok((stream, _)) = listener.accept().await
2116 && let Ok(mut ws) = accept_async(stream).await
2117 {
2118 if let Some(Ok(_)) = ws.next().await {
2120 drop(ws);
2121 }
2122 }
2123
2124 sleep(Duration::from_millis(500)).await;
2126
2127 if let Ok((stream, _)) = listener.accept().await
2129 && let Ok(mut ws) = accept_async(stream).await
2130 {
2131 while let Some(Ok(msg)) = ws.next().await {
2132 if ws.send(msg).await.is_err() {
2133 break;
2134 }
2135 }
2136 }
2137 });
2138
2139 let (handler, _rx) = channel_message_handler();
2140
2141 let config = WebSocketConfig {
2142 url: format!("ws://127.0.0.1:{port}"),
2143 headers: vec![],
2144 message_handler: Some(handler),
2145 heartbeat: None,
2146 heartbeat_msg: None,
2147 ping_handler: None,
2148 reconnect_timeout_ms: Some(5_000),
2149 reconnect_delay_initial_ms: Some(50),
2150 reconnect_delay_max_ms: Some(100),
2151 reconnect_backoff_factor: Some(1.0),
2152 reconnect_jitter_ms: Some(0),
2153 reconnect_max_attempts: None,
2154 };
2155
2156 let quota =
2158 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2159
2160 let client = Arc::new(
2161 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2162 .await
2163 .unwrap(),
2164 );
2165
2166 client
2168 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2169 .await
2170 .unwrap();
2171
2172 wait_until_async(
2174 || async { client.is_reconnecting() },
2175 Duration::from_secs(2),
2176 )
2177 .await;
2178
2179 let start = std::time::Instant::now();
2181 let send_result = client
2182 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2183 .await;
2184 let elapsed = start.elapsed();
2185
2186 assert!(
2188 send_result.is_ok(),
2189 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2190 );
2191 assert!(
2195 elapsed >= Duration::from_millis(850),
2196 "Should wait for rate limit (~1s), waited {elapsed:?}"
2197 );
2198
2199 client.disconnect().await;
2200 server.abort();
2201 }
2202
2203 #[rstest]
2204 #[tokio::test]
2205 async fn test_disconnect_during_reconnect_exits_cleanly() {
2206 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2209 let port = listener.local_addr().unwrap().port();
2210
2211 let server = task::spawn(async move {
2212 if let Ok((stream, _)) = listener.accept().await
2214 && let Ok(ws) = accept_async(stream).await
2215 {
2216 drop(ws);
2217 }
2218 sleep(Duration::from_secs(60)).await;
2220 });
2221
2222 let (handler, _rx) = channel_message_handler();
2223
2224 let config = WebSocketConfig {
2225 url: format!("ws://127.0.0.1:{port}"),
2226 headers: vec![],
2227 message_handler: Some(handler),
2228 heartbeat: None,
2229 heartbeat_msg: None,
2230 ping_handler: None,
2231 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2233 reconnect_delay_max_ms: Some(200),
2234 reconnect_backoff_factor: Some(1.0),
2235 reconnect_jitter_ms: Some(0),
2236 reconnect_max_attempts: None,
2237 };
2238
2239 let client = WebSocketClient::connect(config, None, vec![], None)
2240 .await
2241 .unwrap();
2242
2243 tokio::time::timeout(Duration::from_secs(2), async {
2245 while !client.is_reconnecting() {
2246 sleep(Duration::from_millis(10)).await;
2247 }
2248 })
2249 .await
2250 .expect("Client should enter RECONNECT state");
2251
2252 client.disconnect().await;
2254
2255 assert!(
2257 client.is_disconnected(),
2258 "Client should be cleanly disconnected"
2259 );
2260
2261 server.abort();
2262 }
2263
2264 #[rstest]
2265 #[tokio::test]
2266 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2267 use std::{num::NonZeroU32, sync::Arc};
2270
2271 use nautilus_common::testing::wait_until_async;
2272
2273 use crate::ratelimiter::quota::Quota;
2274
2275 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2276 let port = listener.local_addr().unwrap().port();
2277
2278 let server = task::spawn(async move {
2279 if let Ok((stream, _)) = listener.accept().await
2281 && let Ok(ws) = accept_async(stream).await
2282 {
2283 drop(ws);
2284 }
2285 sleep(Duration::from_secs(60)).await;
2286 });
2287
2288 let (handler, _rx) = channel_message_handler();
2289
2290 let config = WebSocketConfig {
2291 url: format!("ws://127.0.0.1:{port}"),
2292 headers: vec![],
2293 message_handler: Some(handler),
2294 heartbeat: None,
2295 heartbeat_msg: None,
2296 ping_handler: None,
2297 reconnect_timeout_ms: Some(5_000),
2298 reconnect_delay_initial_ms: Some(50),
2299 reconnect_delay_max_ms: Some(100),
2300 reconnect_backoff_factor: Some(1.0),
2301 reconnect_jitter_ms: Some(0),
2302 reconnect_max_attempts: None,
2303 };
2304
2305 let quota = Quota::with_period(Duration::from_secs(10))
2308 .unwrap()
2309 .allow_burst(NonZeroU32::new(1).unwrap());
2310
2311 let client = Arc::new(
2312 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2313 .await
2314 .unwrap(),
2315 );
2316
2317 wait_until_async(
2319 || async { client.is_reconnecting() || client.is_closed() },
2320 Duration::from_secs(2),
2321 )
2322 .await;
2323
2324 client.disconnect().await;
2326 assert!(
2327 !client.is_active(),
2328 "Client should not be active after disconnect"
2329 );
2330
2331 let start = std::time::Instant::now();
2333 let result = client
2334 .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2335 .await;
2336 let elapsed = start.elapsed();
2337
2338 assert!(result.is_err(), "Send should fail when client is closed");
2340 assert!(
2341 matches!(result, Err(crate::error::SendError::Closed)),
2342 "Send should return Closed error, was: {result:?}"
2343 );
2344
2345 assert!(
2347 elapsed < Duration::from_millis(100),
2348 "Send should fail fast without rate limiting, took {elapsed:?}"
2349 );
2350
2351 server.abort();
2352 }
2353
2354 #[rstest]
2355 #[tokio::test]
2356 async fn test_connect_rejects_config_without_message_handler() {
2357 let config = WebSocketConfig {
2361 url: "ws://127.0.0.1:9999".to_string(),
2362 headers: vec![],
2363 message_handler: None, heartbeat: None,
2365 heartbeat_msg: None,
2366 ping_handler: None,
2367 reconnect_timeout_ms: Some(1_000),
2368 reconnect_delay_initial_ms: Some(100),
2369 reconnect_delay_max_ms: Some(500),
2370 reconnect_backoff_factor: Some(1.5),
2371 reconnect_jitter_ms: Some(0),
2372 reconnect_max_attempts: None,
2373 };
2374
2375 let result = WebSocketClient::connect(config, None, vec![], None).await;
2376
2377 assert!(
2378 result.is_err(),
2379 "connect() should reject configs without message_handler"
2380 );
2381
2382 let err = result.unwrap_err();
2383 let err_msg = err.to_string();
2384 assert!(
2385 err_msg.contains("Handler mode requires config.message_handler"),
2386 "Error should mention missing message_handler, was: {err_msg}"
2387 );
2388 }
2389
2390 #[rstest]
2391 #[tokio::test]
2392 async fn test_client_without_handler_sets_stream_mode() {
2393 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2397 let port = listener.local_addr().unwrap().port();
2398
2399 let server = task::spawn(async move {
2400 if let Ok((stream, _)) = listener.accept().await
2402 && let Ok(ws) = accept_async(stream).await
2403 {
2404 drop(ws); }
2406 });
2407
2408 let config = WebSocketConfig {
2409 url: format!("ws://127.0.0.1:{port}"),
2410 headers: vec![],
2411 message_handler: None, heartbeat: None,
2413 heartbeat_msg: None,
2414 ping_handler: None,
2415 reconnect_timeout_ms: Some(1_000),
2416 reconnect_delay_initial_ms: Some(100),
2417 reconnect_delay_max_ms: Some(500),
2418 reconnect_backoff_factor: Some(1.5),
2419 reconnect_jitter_ms: Some(0),
2420 reconnect_max_attempts: None,
2421 };
2422
2423 let inner = WebSocketClientInner::connect_url(config).await.unwrap();
2425
2426 assert!(
2428 inner.is_stream_mode,
2429 "Client without handler should have is_stream_mode=true"
2430 );
2431
2432 server.abort();
2436 }
2437}