1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48use ustr::Ustr;
49
50use super::{
51 config::WebSocketConfig,
52 consts::{
53 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
54 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
55 },
56 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
57};
58#[cfg(feature = "turmoil")]
59use crate::net::TcpConnector;
60use crate::{
61 RECONNECTED,
62 backoff::ExponentialBackoff,
63 error::SendError,
64 logging::{log_task_aborted, log_task_started, log_task_stopped},
65 mode::ConnectionMode,
66 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
67};
68
/// Connection state and background tasks behind a [`WebSocketClient`].
///
/// Owns the write task (always spawned), the read task (only spawned when a message handler
/// is installed), the optional heartbeat task, and the reconnect settings consulted by the
/// controller task.
pub struct WebSocketClientInner {
    /// Connection settings (URL, headers, heartbeat and reconnect tuning).
    config: WebSocketConfig,
    /// Callback for text/binary frames; `None` in stream mode.
    message_handler: Option<MessageHandler>,
    /// Callback for ping payloads, if any.
    ping_handler: Option<PingHandler>,
    /// Task feeding frames into `message_handler`; `None` when no handler is installed.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task that owns the socket writer and services commands sent over `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Command channel into `write_task`.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Periodic heartbeat sender, present when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode` (stored as its `u8` form), also held by the spawned tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// `true` when the caller owns the reader half; automatic reconnect is disabled.
    is_stream_mode: bool,
    /// Reconnect attempts allowed before moving to `Closed`; `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    /// Reconnect attempts made since the last successful reconnect.
    reconnection_attempt_count: u32,
}
106
107impl WebSocketClientInner {
108 pub async fn new_with_writer(
116 config: WebSocketConfig,
117 writer: MessageWriter,
118 ) -> Result<Self, Error> {
119 install_cryptographic_provider();
120
121 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
122
123 let read_task = None;
125
126 let backoff = ExponentialBackoff::new(
127 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
128 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
129 config.reconnect_backoff_factor.unwrap_or(1.5),
130 config.reconnect_jitter_ms.unwrap_or(100),
131 true, )
133 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
134
135 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
136 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
137
138 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
139 Some(Self::spawn_heartbeat_task(
140 connection_mode.clone(),
141 heartbeat_interval,
142 config.heartbeat_msg.clone(),
143 writer_tx.clone(),
144 ))
145 } else {
146 None
147 };
148
149 let reconnect_max_attempts = config.reconnect_max_attempts;
150 let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
151
152 Ok(Self {
153 config,
154 message_handler: None, ping_handler: None,
156 writer_tx,
157 connection_mode,
158 reconnect_timeout,
159 heartbeat_task,
160 read_task,
161 write_task,
162 backoff,
163 is_stream_mode: true,
164 reconnect_max_attempts,
165 reconnection_attempt_count: 0,
166 })
167 }
168
169 pub async fn connect_url(
177 config: WebSocketConfig,
178 message_handler: Option<MessageHandler>,
179 ping_handler: Option<PingHandler>,
180 ) -> Result<Self, Error> {
181 install_cryptographic_provider();
182
183 let is_stream_mode = message_handler.is_none();
185 let reconnect_max_attempts = config.reconnect_max_attempts;
186
187 let (writer, reader) =
188 Self::connect_with_server(&config.url, config.headers.clone()).await?;
189
190 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
191
192 let read_task = if message_handler.is_some() {
193 Some(Self::spawn_message_handler_task(
194 connection_mode.clone(),
195 reader,
196 message_handler.as_ref(),
197 ping_handler.as_ref(),
198 ))
199 } else {
200 None
201 };
202
203 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
204 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
205
206 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
208 Self::spawn_heartbeat_task(
209 connection_mode.clone(),
210 heartbeat_secs,
211 config.heartbeat_msg.clone(),
212 writer_tx.clone(),
213 )
214 });
215
216 let reconnect_timeout =
217 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
218 let backoff = ExponentialBackoff::new(
219 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
220 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
221 config.reconnect_backoff_factor.unwrap_or(1.5),
222 config.reconnect_jitter_ms.unwrap_or(100),
223 true, )
225 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
226
227 Ok(Self {
228 config,
229 message_handler,
230 ping_handler,
231 read_task,
232 write_task,
233 writer_tx,
234 heartbeat_task,
235 connection_mode,
236 reconnect_timeout,
237 backoff,
238 is_stream_mode,
240 reconnect_max_attempts,
241 reconnection_attempt_count: 0,
242 })
243 }
244
245 #[inline]
255 #[cfg(not(feature = "turmoil"))]
256 pub async fn connect_with_server(
257 url: &str,
258 headers: Vec<(String, String)>,
259 ) -> Result<(MessageWriter, MessageReader), Error> {
260 let mut request = url.into_client_request()?;
261 let req_headers = request.headers_mut();
262
263 let mut header_names: Vec<HeaderName> = Vec::new();
264 for (key, val) in headers {
265 let header_value = HeaderValue::from_str(&val)?;
266 let header_name: HeaderName = key.parse()?;
267 header_names.push(header_name.clone());
268 req_headers.insert(header_name, header_value);
269 }
270
271 connect_async_with_config(request, None, true)
272 .await
273 .map(|resp| resp.0.split())
274 }
275
276 #[inline]
289 #[cfg(feature = "turmoil")]
290 pub async fn connect_with_server(
291 url: &str,
292 headers: Vec<(String, String)>,
293 ) -> Result<(MessageWriter, MessageReader), Error> {
294 use rustls::ClientConfig;
295 use tokio_rustls::TlsConnector;
296
297 let mut request = url.into_client_request()?;
298 let req_headers = request.headers_mut();
299
300 let mut header_names: Vec<HeaderName> = Vec::new();
301 for (key, val) in headers {
302 let header_value = HeaderValue::from_str(&val)?;
303 let header_name: HeaderName = key.parse()?;
304 header_names.push(header_name.clone());
305 req_headers.insert(header_name, header_value);
306 }
307
308 let uri = request.uri();
309 let scheme = uri.scheme_str().unwrap_or("ws");
310 let host = uri.host().ok_or_else(|| {
311 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
312 })?;
313
314 let port = uri
316 .port_u16()
317 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
318
319 let addr = format!("{host}:{port}");
320
321 let connector = crate::net::RealTcpConnector;
323 let tcp_stream = connector.connect(&addr).await?;
324 if let Err(e) = tcp_stream.set_nodelay(true) {
325 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
326 }
327
328 let maybe_tls_stream = if scheme == "wss" {
330 let mut root_store = rustls::RootCertStore::empty();
332 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
333
334 let config = ClientConfig::builder()
335 .with_root_certificates(root_store)
336 .with_no_client_auth();
337
338 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
339 let domain =
340 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
341 Error::Io(std::io::Error::new(
342 std::io::ErrorKind::InvalidInput,
343 format!("Invalid DNS name: {e}"),
344 ))
345 })?;
346
347 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
348 MaybeTlsStream::Rustls(tls_stream)
349 } else {
350 MaybeTlsStream::Plain(tcp_stream)
351 };
352
353 client_async(request, maybe_tls_stream)
355 .await
356 .map(|resp| resp.0.split())
357 }
358
    /// Re-establishes the connection and swaps the new socket halves into the running tasks.
    ///
    /// Stream-mode clients are never reconnected here: the state is set to `Closed` and
    /// `Ok(())` is returned. Otherwise the following runs within `reconnect_timeout`:
    /// 1. Connect to `config.url` with `config.headers`.
    /// 2. Hand the new writer to the write task via `WriterCommand::Update`, then wait for it
    ///    to report whether the reconnect buffer was drained.
    /// 3. Wait `GRACEFUL_SHUTDOWN_DELAY_MS`, abort the old read task, and move the state from
    ///    `Reconnect` to `Active`. Spawn a new read task if a message handler is installed.
    ///
    /// The attempt is abandoned with `Ok(())` in two cases: the state becomes `Disconnect` at
    /// any checkpoint, or it is no longer `Reconnect` at step 3. Callers must therefore
    /// re-check the state rather than treat `Ok` as "connected".
    ///
    /// # Errors
    ///
    /// - Errors from `connect_with_server`.
    /// - `BrokenPipe` if the write task is gone.
    /// - `Other` if draining the buffer failed.
    /// - `TimedOut` if the whole attempt exceeded `reconnect_timeout`.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        log::debug!("Reconnecting");

        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            // A disconnect may have been requested while connecting; the new halves are
            // simply dropped in that case.
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // The write task swaps writers and replays buffered messages, then answers on `tx`.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            match rx.await {
                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    log::warn!("Writer failed to drain buffer, aborting reconnect");
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    log::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The old read task belongs to the dead connection; it is always taken out here.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Only promote to Active if nobody changed the state while we were reconnecting.
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
489
490 #[inline]
498 #[must_use]
499 pub fn is_alive(&self) -> bool {
500 match &self.read_task {
501 Some(read_task) => !read_task.is_finished(),
502 None => true, }
504 }
505
506 fn spawn_message_handler_task(
507 connection_state: Arc<AtomicU8>,
508 mut reader: MessageReader,
509 message_handler: Option<&MessageHandler>,
510 ping_handler: Option<&PingHandler>,
511 ) -> tokio::task::JoinHandle<()> {
512 log::debug!("Started message handler task 'read'");
513
514 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
515
516 let message_handler = message_handler.cloned();
518 let ping_handler = ping_handler.cloned();
519
520 tokio::task::spawn(async move {
521 loop {
522 if !ConnectionMode::from_atomic(&connection_state).is_active() {
523 break;
524 }
525
526 match tokio::time::timeout(check_interval, reader.next()).await {
527 Ok(Some(Ok(Message::Binary(data)))) => {
528 log::trace!("Received message <binary> {} bytes", data.len());
529 if let Some(ref handler) = message_handler {
530 handler(Message::Binary(data));
531 }
532 }
533 Ok(Some(Ok(Message::Text(data)))) => {
534 log::trace!("Received message: {data}");
535 if let Some(ref handler) = message_handler {
536 handler(Message::Text(data));
537 }
538 }
539 Ok(Some(Ok(Message::Ping(ping_data)))) => {
540 log::trace!("Received ping: {ping_data:?}");
541 if let Some(ref handler) = ping_handler {
542 handler(ping_data.to_vec());
543 }
544 }
545 Ok(Some(Ok(Message::Pong(_)))) => {
546 log::trace!("Received pong");
547 }
548 Ok(Some(Ok(Message::Close(_)))) => {
549 log::debug!("Received close message - terminating");
550 break;
551 }
552 Ok(Some(Ok(_))) => (),
553 Ok(Some(Err(e))) => {
554 log::error!("Received error message - terminating: {e}");
555 break;
556 }
557 Ok(None) => {
558 log::debug!("No message received - terminating");
559 break;
560 }
561 Err(_) => {
562 continue;
564 }
565 }
566 }
567 })
568 }
569
570 async fn drain_reconnect_buffer(
575 buffer: &mut VecDeque<Message>,
576 writer: &mut MessageWriter,
577 ) -> bool {
578 if buffer.is_empty() {
579 return false;
580 }
581
582 let initial_buffer_len = buffer.len();
583 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
584
585 let mut send_error_occurred = false;
586
587 while let Some(buffered_msg) = buffer.front() {
588 let msg_to_send = buffered_msg.clone();
590
591 if let Err(e) = writer.send(msg_to_send).await {
592 log::error!(
593 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
594 buffer.len()
595 );
596 send_error_occurred = true;
597 break; }
599
600 buffer.pop_front();
602 }
603
604 if buffer.is_empty() {
605 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
606 }
607
608 send_error_occurred
609 }
610
611 fn spawn_write_task(
612 connection_state: Arc<AtomicU8>,
613 writer: MessageWriter,
614 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
615 ) -> tokio::task::JoinHandle<()> {
616 log_task_started("write");
617
618 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
620
621 tokio::task::spawn(async move {
622 let mut active_writer = writer;
623 let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
626
627 loop {
628 match ConnectionMode::from_atomic(&connection_state) {
629 ConnectionMode::Disconnect => {
630 if !reconnect_buffer.is_empty() {
632 log::warn!(
633 "Discarding {} buffered messages due to disconnect",
634 reconnect_buffer.len()
635 );
636 reconnect_buffer.clear();
637 }
638
639 _ = tokio::time::timeout(
642 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
643 active_writer.close(),
644 )
645 .await;
646 break;
647 }
648 ConnectionMode::Closed => {
649 if !reconnect_buffer.is_empty() {
651 log::warn!(
652 "Discarding {} buffered messages due to closed connection",
653 reconnect_buffer.len()
654 );
655 reconnect_buffer.clear();
656 }
657 break;
658 }
659 _ => {}
660 }
661
662 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
663 Ok(Some(msg)) => {
664 let mode = ConnectionMode::from_atomic(&connection_state);
666 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
667 break;
668 }
669
670 match msg {
671 WriterCommand::Update(new_writer, tx) => {
672 log::debug!("Received new writer");
673
674 tokio::time::sleep(Duration::from_millis(100)).await;
676
677 _ = tokio::time::timeout(
680 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
681 active_writer.close(),
682 )
683 .await;
684
685 active_writer = new_writer;
686 log::debug!("Updated writer");
687
688 let send_error = Self::drain_reconnect_buffer(
689 &mut reconnect_buffer,
690 &mut active_writer,
691 )
692 .await;
693
694 if let Err(e) = tx.send(!send_error) {
695 log::error!(
696 "Failed to report drain status to controller: {e:?}"
697 );
698 }
699 }
700 WriterCommand::Send(msg) if mode.is_reconnect() => {
701 log::debug!(
703 "Buffering message during reconnection (buffer size: {})",
704 reconnect_buffer.len() + 1
705 );
706 reconnect_buffer.push_back(msg);
707 }
708 WriterCommand::Send(msg) => {
709 if let Err(e) = active_writer.send(msg.clone()).await {
710 log::error!("Failed to send message: {e}");
711 log::warn!("Writer triggering reconnect");
712 reconnect_buffer.push_back(msg);
713 connection_state
714 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
715 }
716 }
717 }
718 }
719 Ok(None) => {
720 log::debug!("Writer channel closed, terminating writer task");
722 break;
723 }
724 Err(_) => {
725 continue;
727 }
728 }
729 }
730
731 _ = tokio::time::timeout(
734 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
735 active_writer.close(),
736 )
737 .await;
738
739 log_task_stopped("write");
740 })
741 }
742
743 fn spawn_heartbeat_task(
744 connection_state: Arc<AtomicU8>,
745 heartbeat_secs: u64,
746 message: Option<String>,
747 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
748 ) -> tokio::task::JoinHandle<()> {
749 log_task_started("heartbeat");
750
751 tokio::task::spawn(async move {
752 let interval = Duration::from_secs(heartbeat_secs);
753
754 loop {
755 tokio::time::sleep(interval).await;
756
757 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
758 ConnectionMode::Active => {
759 let msg = match &message {
760 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
761 None => WriterCommand::Send(Message::Ping(vec![].into())),
762 };
763
764 match writer_tx.send(msg) {
765 Ok(()) => log::trace!("Sent heartbeat to writer task"),
766 Err(e) => {
767 log::error!("Failed to send heartbeat to writer task: {e}");
768 }
769 }
770 }
771 ConnectionMode::Reconnect => continue,
772 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
773 }
774 }
775
776 log_task_stopped("heartbeat");
777 })
778 }
779}
780
781impl Drop for WebSocketClientInner {
782 fn drop(&mut self) {
783 self.clean_drop();
785 }
786}
787
788impl CleanDrop for WebSocketClientInner {
790 fn clean_drop(&mut self) {
791 if let Some(ref read_task) = self.read_task.take()
792 && !read_task.is_finished()
793 {
794 read_task.abort();
795 log_task_aborted("read");
796 }
797
798 if !self.write_task.is_finished() {
799 self.write_task.abort();
800 log_task_aborted("write");
801 }
802
803 if let Some(ref handle) = self.heartbeat_task.take()
804 && !handle.is_finished()
805 {
806 handle.abort();
807 log_task_aborted("heartbeat");
808 }
809
810 self.message_handler = None;
812 self.ping_handler = None;
813 }
814}
815
816impl Debug for WebSocketClientInner {
817 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
818 f.debug_struct(stringify!(WebSocketClientInner))
819 .field("config", &self.config)
820 .field(
821 "connection_mode",
822 &ConnectionMode::from_atomic(&self.connection_mode),
823 )
824 .field("reconnect_timeout", &self.reconnect_timeout)
825 .field("is_stream_mode", &self.is_stream_mode)
826 .finish()
827 }
828}
829
/// Handle to a managed WebSocket connection.
///
/// The connection is driven by a background controller task. This handle exposes the
/// shared connection state, rate-limited send methods and a disconnect entry point.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task supervising connection state and reconnects; finished once the client is down.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode` (stored as its `u8` form).
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// How long a send waits for the connection to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Per-key rate limiter consulted before text/binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Command channel into the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
845
846impl Debug for WebSocketClient {
847 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848 f.debug_struct(stringify!(WebSocketClient)).finish()
849 }
850}
851
852impl WebSocketClient {
853 #[allow(clippy::too_many_arguments)]
869 pub async fn connect_stream(
870 config: WebSocketConfig,
871 keyed_quotas: Vec<(String, Quota)>,
872 default_quota: Option<Quota>,
873 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
874 ) -> Result<(MessageReader, Self), Error> {
875 install_cryptographic_provider();
876
877 let (writer, reader) =
879 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
880
881 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
883
884 let connection_mode = inner.connection_mode.clone();
885 let reconnect_timeout = inner.reconnect_timeout;
886 let keyed_quotas = keyed_quotas
887 .into_iter()
888 .map(|(key, quota)| (Ustr::from(&key), quota))
889 .collect();
890 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
891 let writer_tx = inner.writer_tx.clone();
892
893 let controller_task =
894 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
895
896 Ok((
897 reader,
898 Self {
899 controller_task,
900 connection_mode,
901 reconnect_timeout,
902 rate_limiter,
903 writer_tx,
904 },
905 ))
906 }
907
908 pub async fn connect(
926 config: WebSocketConfig,
927 message_handler: Option<MessageHandler>,
928 ping_handler: Option<PingHandler>,
929 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
930 keyed_quotas: Vec<(String, Quota)>,
931 default_quota: Option<Quota>,
932 ) -> Result<Self, Error> {
933 if message_handler.is_none() {
935 return Err(Error::Io(std::io::Error::new(
936 std::io::ErrorKind::InvalidInput,
937 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
938 )));
939 }
940
941 log::debug!("Connecting");
942 let inner =
943 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
944 let connection_mode = inner.connection_mode.clone();
945 let writer_tx = inner.writer_tx.clone();
946 let reconnect_timeout = inner.reconnect_timeout;
947
948 let controller_task =
949 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
950
951 let keyed_quotas = keyed_quotas
952 .into_iter()
953 .map(|(key, quota)| (Ustr::from(&key), quota))
954 .collect();
955 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
956
957 Ok(Self {
958 controller_task,
959 connection_mode,
960 reconnect_timeout,
961 rate_limiter,
962 writer_tx,
963 })
964 }
965
966 #[must_use]
968 pub fn connection_mode(&self) -> ConnectionMode {
969 ConnectionMode::from_atomic(&self.connection_mode)
970 }
971
972 #[must_use]
977 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
978 Arc::clone(&self.connection_mode)
979 }
980
981 #[inline]
986 #[must_use]
987 pub fn is_active(&self) -> bool {
988 self.connection_mode().is_active()
989 }
990
991 #[must_use]
993 pub fn is_disconnected(&self) -> bool {
994 self.controller_task.is_finished()
995 }
996
997 #[inline]
1002 #[must_use]
1003 pub fn is_reconnecting(&self) -> bool {
1004 self.connection_mode().is_reconnect()
1005 }
1006
1007 #[inline]
1011 #[must_use]
1012 pub fn is_disconnecting(&self) -> bool {
1013 self.connection_mode().is_disconnect()
1014 }
1015
1016 #[inline]
1022 #[must_use]
1023 pub fn is_closed(&self) -> bool {
1024 self.connection_mode().is_closed()
1025 }
1026
1027 async fn wait_for_active(&self) -> Result<(), SendError> {
1031 if self.is_closed() {
1032 return Err(SendError::Closed);
1033 }
1034
1035 let timeout = self.reconnect_timeout;
1036 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1037
1038 if !self.is_active() {
1039 log::debug!("Waiting for client to become ACTIVE before sending...");
1040
1041 let inner = tokio::time::timeout(timeout, async {
1042 loop {
1043 if self.is_active() {
1044 return Ok(());
1045 }
1046 if matches!(
1047 self.connection_mode(),
1048 ConnectionMode::Disconnect | ConnectionMode::Closed
1049 ) {
1050 return Err(());
1051 }
1052 tokio::time::sleep(check_interval).await;
1053 }
1054 })
1055 .await
1056 .map_err(|_| SendError::Timeout)?;
1057 inner.map_err(|()| SendError::Closed)?;
1058 }
1059
1060 Ok(())
1061 }
1062
1063 pub async fn disconnect(&self) {
1068 log::debug!("Disconnecting");
1069 self.connection_mode
1070 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1071
1072 if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1073 while !self.is_disconnected() {
1074 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1075 }
1076
1077 if !self.controller_task.is_finished() {
1078 self.controller_task.abort();
1079 log_task_aborted("controller");
1080 }
1081 })
1082 .await
1083 == Ok(())
1084 {
1085 log::debug!("Controller task finished");
1086 } else {
1087 log::error!("Timeout waiting for controller task to finish");
1088 if !self.controller_task.is_finished() {
1089 self.controller_task.abort();
1090 log_task_aborted("controller");
1091 }
1092 }
1093 }
1094
1095 #[allow(unused_variables)]
1101 pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1102 if self.is_closed() || self.is_disconnecting() {
1104 return Err(SendError::Closed);
1105 }
1106
1107 self.rate_limiter.await_keys_ready(keys).await;
1108 self.wait_for_active().await?;
1109
1110 log::trace!("Sending text: {data:?}");
1111
1112 let msg = Message::Text(data.into());
1113 self.writer_tx
1114 .send(WriterCommand::Send(msg))
1115 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1116 }
1117
1118 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1124 self.wait_for_active().await?;
1125
1126 log::trace!("Sending pong frame ({} bytes)", data.len());
1127
1128 let msg = Message::Pong(data.into());
1129 self.writer_tx
1130 .send(WriterCommand::Send(msg))
1131 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1132 }
1133
1134 #[allow(unused_variables)]
1140 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1141 if self.is_closed() || self.is_disconnecting() {
1143 return Err(SendError::Closed);
1144 }
1145
1146 self.rate_limiter.await_keys_ready(keys).await;
1147 self.wait_for_active().await?;
1148
1149 log::trace!("Sending bytes: {data:?}");
1150
1151 let msg = Message::Binary(data.into());
1152 self.writer_tx
1153 .send(WriterCommand::Send(msg))
1154 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1155 }
1156
1157 pub async fn send_close_message(&self) -> Result<(), SendError> {
1163 self.wait_for_active().await?;
1164
1165 let msg = Message::Close(None);
1166 self.writer_tx
1167 .send(WriterCommand::Send(msg))
1168 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1169 }
1170
    /// Spawns the controller task that supervises the connection state machine.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task reads the shared mode and acts:
    /// - `Disconnect`: waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat tasks,
    ///   then stops. The write task is left to observe the mode and close on its own.
    /// - `Closed`: stops.
    /// - `Active` with a finished read task: moves the mode to `Reconnect`.
    /// - `Reconnect`: calls `WebSocketClientInner::reconnect`.
    ///   - It enforces `reconnect_max_attempts`, moving to `Closed` once exhausted.
    ///   - It applies exponential backoff after failures.
    ///   - After a successful reconnect that leaves the mode `Active`, it delivers the
    ///     `RECONNECTED` text message to the message handler and then calls
    ///     `post_reconnection`.
    ///
    /// On exit the mode is set to `Closed`. `inner` is dropped when the task ends. Its `Drop`
    /// runs `clean_drop`, which aborts any task still running, including the write task.
    ///
    /// NOTE(review): the backoff sleep after a failed attempt is not interruptible. A disconnect
    /// requested during it is only noticed afterwards, and `disconnect()` gives up after
    /// `GRACEFUL_SHUTDOWN_TIMEOUT_SECS` and aborts this task.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Give the write task a chance to see Disconnect and close gracefully.
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");
                    break;
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                // A finished read task means the socket is gone (close frame, error or EOF).
                // In stream mode there is no read task and `is_alive()` is always true.
                if mode.is_active() && !inner.is_alive() {
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        log::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read: another party may have changed the mode concurrently.
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnection_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    inner.reconnection_attempt_count += 1;
                    log::debug!(
                        "Reconnection attempt {} of {}",
                        inner.reconnection_attempt_count,
                        inner
                            .reconnect_max_attempts
                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                    );

                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();
                            inner.reconnection_attempt_count = 0;
                            // `reconnect` also returns Ok when it abandoned the attempt, so only
                            // notify listeners if the connection really is Active again.
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    log::debug!("Sent reconnected message to handler");
                                }

                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    log::debug!("Called `post_reconnection` handler");
                                }

                                log::debug!("Reconnected successfully");
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnection_attempt_count
                            );
                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1306}
1307
1308impl Drop for WebSocketClient {
1310 fn drop(&mut self) {
1311 if !self.controller_task.is_finished() {
1312 self.controller_task.abort();
1313 log_task_aborted("controller");
1314 }
1315 }
1316}
1317
1318#[cfg(test)]
1319#[cfg(not(feature = "turmoil"))]
1320#[cfg(target_os = "linux")] mod tests {
1322 use std::{num::NonZeroU32, sync::Arc};
1323
1324 use futures_util::{SinkExt, StreamExt};
1325 use tokio::{
1326 net::TcpListener,
1327 task::{self, JoinHandle},
1328 };
1329 use tokio_tungstenite::{
1330 accept_hdr_async,
1331 tungstenite::{
1332 handshake::server::{self, Callback},
1333 http::HeaderValue,
1334 },
1335 };
1336
1337 use crate::{
1338 ratelimiter::quota::Quota,
1339 websocket::{WebSocketClient, WebSocketConfig},
1340 };
1341
    /// Local echo server used by the tests.
    struct TestServer {
        /// Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Ephemeral localhost port the server listens on.
        port: u16,
    }
1346
    /// Handshake callback asserting that a specific request header is present with a given value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Header name that must be present on the upgrade request.
        key: String,
        /// Value the header must carry.
        value: HeaderValue,
    }
1352
1353 impl Callback for TestCallback {
1354 #[allow(clippy::panic_in_result_fn)]
1355 fn on_request(
1356 self,
1357 request: &server::Request,
1358 response: server::Response,
1359 ) -> Result<server::Response, server::ErrorResponse> {
1360 let _ = response;
1361 let value = request.headers().get(&self.key);
1362 assert!(value.is_some());
1363
1364 if let Some(value) = request.headers().get(&self.key) {
1365 assert_eq!(value, self.value);
1366 }
1367
1368 Ok(response)
1369 }
1370 }
1371
1372 impl TestServer {
1373 async fn setup() -> Self {
1374 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1375 let port = TcpListener::local_addr(&server).unwrap().port();
1376
1377 let header_key = "test".to_string();
1378 let header_value = "test".to_string();
1379
1380 let test_call_back = TestCallback {
1381 key: header_key,
1382 value: HeaderValue::from_str(&header_value).unwrap(),
1383 };
1384
1385 let task = task::spawn(async move {
1386 loop {
1388 let (conn, _) = server.accept().await.unwrap();
1389 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1390 .await
1391 .unwrap();
1392
1393 task::spawn(async move {
1394 while let Some(Ok(msg)) = websocket.next().await {
1395 match msg {
1396 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1397 if txt == "close-now" =>
1398 {
1399 log::debug!("Forcibly closing from server side");
1400 let _ = websocket.close(None).await;
1402 break;
1403 }
1404 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1406 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1407 if websocket.send(msg).await.is_err() {
1408 break;
1409 }
1410 }
1411 tokio_tungstenite::tungstenite::protocol::Message::Close(
1413 _frame,
1414 ) => {
1415 let _ = websocket.close(None).await;
1416 break;
1417 }
1418 _ => {}
1420 }
1421 }
1422 });
1423 }
1424 });
1425
1426 Self { task, port }
1427 }
1428 }
1429
1430 impl Drop for TestServer {
1431 fn drop(&mut self) {
1432 self.task.abort();
1433 }
1434 }
1435
1436 async fn setup_test_client(port: u16) -> WebSocketClient {
1437 let config = WebSocketConfig {
1438 url: format!("ws://127.0.0.1:{port}"),
1439 headers: vec![("test".into(), "test".into())],
1440 heartbeat: None,
1441 heartbeat_msg: None,
1442 reconnect_timeout_ms: None,
1443 reconnect_delay_initial_ms: None,
1444 reconnect_backoff_factor: None,
1445 reconnect_delay_max_ms: None,
1446 reconnect_jitter_ms: None,
1447 reconnect_max_attempts: None,
1448 };
1449 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1450 .await
1451 .expect("Failed to connect")
1452 }
1453
1454 #[tokio::test]
1455 async fn test_websocket_basic() {
1456 let server = TestServer::setup().await;
1457 let client = setup_test_client(server.port).await;
1458
1459 assert!(!client.is_disconnected());
1460
1461 client.disconnect().await;
1462 assert!(client.is_disconnected());
1463 }
1464
1465 #[tokio::test]
1466 async fn test_websocket_heartbeat() {
1467 let server = TestServer::setup().await;
1468 let client = setup_test_client(server.port).await;
1469
1470 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1472
1473 client.disconnect().await;
1475 assert!(client.is_disconnected());
1476 }
1477
1478 #[tokio::test]
1479 async fn test_websocket_reconnect_exhausted() {
1480 let config = WebSocketConfig {
1481 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1483 heartbeat: None,
1484 heartbeat_msg: None,
1485 reconnect_timeout_ms: None,
1486 reconnect_delay_initial_ms: None,
1487 reconnect_backoff_factor: None,
1488 reconnect_delay_max_ms: None,
1489 reconnect_jitter_ms: None,
1490 reconnect_max_attempts: None,
1491 };
1492 let res =
1493 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1494 .await;
1495 assert!(res.is_err(), "Should fail quickly with no server");
1496 }
1497
1498 #[tokio::test]
1499 async fn test_websocket_forced_close_reconnect() {
1500 let server = TestServer::setup().await;
1501 let client = setup_test_client(server.port).await;
1502
1503 client.send_text("Hello".into(), None).await.unwrap();
1505
1506 client.send_text("close-now".into(), None).await.unwrap();
1508
1509 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1511
1512 assert!(!client.is_disconnected());
1514
1515 client.disconnect().await;
1517 assert!(client.is_disconnected());
1518 }
1519
1520 #[tokio::test]
1521 async fn test_rate_limiter() {
1522 let server = TestServer::setup().await;
1523 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1524
1525 let config = WebSocketConfig {
1526 url: format!("ws://127.0.0.1:{}", server.port),
1527 headers: vec![("test".into(), "test".into())],
1528 heartbeat: None,
1529 heartbeat_msg: None,
1530 reconnect_timeout_ms: None,
1531 reconnect_delay_initial_ms: None,
1532 reconnect_backoff_factor: None,
1533 reconnect_delay_max_ms: None,
1534 reconnect_jitter_ms: None,
1535 reconnect_max_attempts: None,
1536 };
1537
1538 let client = WebSocketClient::connect(
1539 config,
1540 Some(Arc::new(|_| {})),
1541 None,
1542 None,
1543 vec![("default".into(), quota)],
1544 None,
1545 )
1546 .await
1547 .unwrap();
1548
1549 client.send_text("test1".into(), None).await.unwrap();
1551 client.send_text("test2".into(), None).await.unwrap();
1552
1553 client.send_text("test3".into(), None).await.unwrap();
1555
1556 client.disconnect().await;
1558 assert!(client.is_disconnected());
1559 }
1560
1561 #[tokio::test]
1562 async fn test_concurrent_writers() {
1563 let server = TestServer::setup().await;
1564 let client = Arc::new(setup_test_client(server.port).await);
1565
1566 let mut handles = vec![];
1567 for i in 0..10 {
1568 let client = client.clone();
1569 handles.push(task::spawn(async move {
1570 client.send_text(format!("test{i}"), None).await.unwrap();
1571 }));
1572 }
1573
1574 for handle in handles {
1575 handle.await.unwrap();
1576 }
1577
1578 client.disconnect().await;
1580 assert!(client.is_disconnected());
1581 }
1582}
1583
1584#[cfg(test)]
1585#[cfg(not(feature = "turmoil"))]
1586mod rust_tests {
1587 use futures_util::StreamExt;
1588 use rstest::rstest;
1589 use tokio::{
1590 net::TcpListener,
1591 task,
1592 time::{Duration, sleep},
1593 };
1594 use tokio_tungstenite::accept_async;
1595
1596 use super::*;
1597 use crate::websocket::types::channel_message_handler;
1598
1599 #[rstest]
1600 #[tokio::test]
1601 async fn test_reconnect_then_disconnect() {
1602 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1604 let port = listener.local_addr().unwrap().port();
1605
1606 let server = task::spawn(async move {
1608 let (stream, _) = listener.accept().await.unwrap();
1609 let ws = accept_async(stream).await.unwrap();
1610 drop(ws);
1611 sleep(Duration::from_secs(1)).await;
1613 });
1614
1615 let (handler, _rx) = channel_message_handler();
1617
1618 let config = WebSocketConfig {
1620 url: format!("ws://127.0.0.1:{port}"),
1621 headers: vec![],
1622 heartbeat: None,
1623 heartbeat_msg: None,
1624 reconnect_timeout_ms: Some(1_000),
1625 reconnect_delay_initial_ms: Some(50),
1626 reconnect_delay_max_ms: Some(100),
1627 reconnect_backoff_factor: Some(1.0),
1628 reconnect_jitter_ms: Some(0),
1629 reconnect_max_attempts: None,
1630 };
1631
1632 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1634 .await
1635 .unwrap();
1636
1637 sleep(Duration::from_millis(100)).await;
1639 client.disconnect().await;
1641 assert!(client.is_disconnected());
1642 server.abort();
1643 }
1644
1645 #[rstest]
1646 #[tokio::test]
1647 async fn test_reconnect_state_flips_when_reader_stops() {
1648 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1650 let port = listener.local_addr().unwrap().port();
1651
1652 let server = task::spawn(async move {
1653 if let Ok((stream, _)) = listener.accept().await
1654 && let Ok(ws) = accept_async(stream).await
1655 {
1656 drop(ws);
1657 }
1658 sleep(Duration::from_millis(50)).await;
1659 });
1660
1661 let (handler, _rx) = channel_message_handler();
1662
1663 let config = WebSocketConfig {
1664 url: format!("ws://127.0.0.1:{port}"),
1665 headers: vec![],
1666 heartbeat: None,
1667 heartbeat_msg: None,
1668 reconnect_timeout_ms: Some(1_000),
1669 reconnect_delay_initial_ms: Some(50),
1670 reconnect_delay_max_ms: Some(100),
1671 reconnect_backoff_factor: Some(1.0),
1672 reconnect_jitter_ms: Some(0),
1673 reconnect_max_attempts: None,
1674 };
1675
1676 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1677 .await
1678 .unwrap();
1679
1680 tokio::time::timeout(Duration::from_secs(2), async {
1681 loop {
1682 if client.is_reconnecting() {
1683 break;
1684 }
1685 tokio::time::sleep(Duration::from_millis(10)).await;
1686 }
1687 })
1688 .await
1689 .expect("client did not enter RECONNECT state");
1690
1691 client.disconnect().await;
1692 server.abort();
1693 }
1694
1695 #[rstest]
1696 #[tokio::test]
1697 async fn test_stream_mode_disables_auto_reconnect() {
1698 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1701 let port = listener.local_addr().unwrap().port();
1702
1703 let server = task::spawn(async move {
1704 if let Ok((stream, _)) = listener.accept().await
1705 && let Ok(_ws) = accept_async(stream).await
1706 {
1707 sleep(Duration::from_millis(100)).await;
1709 }
1710 });
1711
1712 let config = WebSocketConfig {
1713 url: format!("ws://127.0.0.1:{port}"),
1714 headers: vec![],
1715 heartbeat: None,
1716 heartbeat_msg: None,
1717 reconnect_timeout_ms: Some(1_000),
1718 reconnect_delay_initial_ms: Some(50),
1719 reconnect_delay_max_ms: Some(100),
1720 reconnect_backoff_factor: Some(1.0),
1721 reconnect_jitter_ms: Some(0),
1722 reconnect_max_attempts: None,
1723 };
1724
1725 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1726 .await
1727 .unwrap();
1728
1729 server.abort();
1737 }
1738
1739 #[rstest]
1740 #[tokio::test]
1741 async fn test_message_handler_mode_allows_auto_reconnect() {
1742 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1744 let port = listener.local_addr().unwrap().port();
1745
1746 let server = task::spawn(async move {
1747 if let Ok((stream, _)) = listener.accept().await
1749 && let Ok(ws) = accept_async(stream).await
1750 {
1751 drop(ws);
1752 }
1753 sleep(Duration::from_millis(50)).await;
1754 });
1755
1756 let (handler, _rx) = channel_message_handler();
1757
1758 let config = WebSocketConfig {
1759 url: format!("ws://127.0.0.1:{port}"),
1760 headers: vec![],
1761 heartbeat: None,
1762 heartbeat_msg: None,
1763 reconnect_timeout_ms: Some(1_000),
1764 reconnect_delay_initial_ms: Some(50),
1765 reconnect_delay_max_ms: Some(100),
1766 reconnect_backoff_factor: Some(1.0),
1767 reconnect_jitter_ms: Some(0),
1768 reconnect_max_attempts: None,
1769 };
1770
1771 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1772 .await
1773 .unwrap();
1774
1775 tokio::time::timeout(Duration::from_secs(2), async {
1777 loop {
1778 if client.is_reconnecting() || client.is_closed() {
1779 break;
1780 }
1781 tokio::time::sleep(Duration::from_millis(10)).await;
1782 }
1783 })
1784 .await
1785 .expect("client should attempt reconnection or close");
1786
1787 assert!(
1790 client.is_reconnecting() || client.is_closed(),
1791 "Client with message handler should attempt reconnection"
1792 );
1793
1794 client.disconnect().await;
1795 server.abort();
1796 }
1797
1798 #[rstest]
1799 #[tokio::test]
1800 async fn test_handler_mode_reconnect_with_new_connection() {
1801 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1803 let port = listener.local_addr().unwrap().port();
1804
1805 let server = task::spawn(async move {
1806 if let Ok((stream, _)) = listener.accept().await
1808 && let Ok(ws) = accept_async(stream).await
1809 {
1810 drop(ws);
1811 }
1812
1813 sleep(Duration::from_millis(100)).await;
1815
1816 if let Ok((stream, _)) = listener.accept().await
1818 && let Ok(mut ws) = accept_async(stream).await
1819 {
1820 use futures_util::SinkExt;
1821 let _ = ws
1822 .send(Message::Text("reconnected".to_string().into()))
1823 .await;
1824 sleep(Duration::from_secs(1)).await;
1825 }
1826 });
1827
1828 let (handler, mut rx) = channel_message_handler();
1829
1830 let config = WebSocketConfig {
1831 url: format!("ws://127.0.0.1:{port}"),
1832 headers: vec![],
1833 heartbeat: None,
1834 heartbeat_msg: None,
1835 reconnect_timeout_ms: Some(2_000),
1836 reconnect_delay_initial_ms: Some(50),
1837 reconnect_delay_max_ms: Some(200),
1838 reconnect_backoff_factor: Some(1.5),
1839 reconnect_jitter_ms: Some(10),
1840 reconnect_max_attempts: None,
1841 };
1842
1843 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1844 .await
1845 .unwrap();
1846
1847 let result = tokio::time::timeout(Duration::from_secs(5), async {
1849 loop {
1850 if let Ok(msg) = rx.try_recv()
1851 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1852 {
1853 return true;
1854 }
1855 tokio::time::sleep(Duration::from_millis(10)).await;
1856 }
1857 })
1858 .await;
1859
1860 assert!(
1861 result.is_ok(),
1862 "Should receive message after reconnection within timeout"
1863 );
1864
1865 client.disconnect().await;
1866 server.abort();
1867 }
1868
1869 #[rstest]
1870 #[tokio::test]
1871 async fn test_stream_mode_no_auto_reconnect() {
1872 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1875 let port = listener.local_addr().unwrap().port();
1876
1877 let server = task::spawn(async move {
1878 if let Ok((stream, _)) = listener.accept().await
1880 && let Ok(mut ws) = accept_async(stream).await
1881 {
1882 use futures_util::SinkExt;
1883 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1884 sleep(Duration::from_millis(50)).await;
1885 }
1887 });
1888
1889 let config = WebSocketConfig {
1890 url: format!("ws://127.0.0.1:{port}"),
1891 headers: vec![],
1892 heartbeat: None,
1893 heartbeat_msg: None,
1894 reconnect_timeout_ms: Some(1_000),
1895 reconnect_delay_initial_ms: Some(50),
1896 reconnect_delay_max_ms: Some(100),
1897 reconnect_backoff_factor: Some(1.0),
1898 reconnect_jitter_ms: Some(0),
1899 reconnect_max_attempts: None,
1900 };
1901
1902 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1903 .await
1904 .unwrap();
1905
1906 assert!(client.is_active(), "Client should start as active");
1908
1909 let msg = reader.next().await;
1911 assert!(
1912 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1913 "Should receive initial message"
1914 );
1915
1916 while let Some(msg) = reader.next().await {
1918 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1919 break;
1920 }
1921 }
1922
1923 sleep(Duration::from_millis(200)).await;
1926
1927 assert!(
1930 client.is_active() || client.is_closed(),
1931 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1932 );
1933 assert!(
1934 !client.is_reconnecting(),
1935 "Stream mode client should never attempt reconnection"
1936 );
1937
1938 client.disconnect().await;
1939 server.abort();
1940 }
1941
1942 #[rstest]
1943 #[tokio::test]
1944 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1945 use nautilus_common::testing::wait_until_async;
1948
1949 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1950 let port = listener.local_addr().unwrap().port();
1951
1952 let server = task::spawn(async move {
1953 if let Ok((stream, _)) = listener.accept().await
1955 && let Ok(ws) = accept_async(stream).await
1956 {
1957 drop(ws);
1958 }
1959 sleep(Duration::from_secs(60)).await;
1961 });
1962
1963 let (handler, _rx) = channel_message_handler();
1964
1965 let config = WebSocketConfig {
1967 url: format!("ws://127.0.0.1:{port}"),
1968 headers: vec![],
1969 heartbeat: None,
1970 heartbeat_msg: None,
1971 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1973 reconnect_delay_max_ms: Some(100),
1974 reconnect_backoff_factor: Some(1.0),
1975 reconnect_jitter_ms: Some(0),
1976 reconnect_max_attempts: None,
1977 };
1978
1979 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1980 .await
1981 .unwrap();
1982
1983 wait_until_async(
1985 || async { client.is_reconnecting() },
1986 Duration::from_secs(3),
1987 )
1988 .await;
1989
1990 let start = std::time::Instant::now();
1992 let send_result = client.send_text("test".to_string(), None).await;
1993 let elapsed = start.elapsed();
1994
1995 assert!(
1996 send_result.is_err(),
1997 "Send should fail when client stuck in RECONNECT"
1998 );
1999 assert!(
2000 matches!(send_result, Err(crate::error::SendError::Timeout)),
2001 "Send should return Timeout error, was: {send_result:?}"
2002 );
2003 assert!(
2006 elapsed >= Duration::from_millis(1800),
2007 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2008 );
2009
2010 client.disconnect().await;
2011 server.abort();
2012 }
2013
2014 #[rstest]
2015 #[tokio::test]
2016 async fn test_send_waits_during_reconnection() {
2017 use nautilus_common::testing::wait_until_async;
2019
2020 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2021 let port = listener.local_addr().unwrap().port();
2022
2023 let server = task::spawn(async move {
2024 if let Ok((stream, _)) = listener.accept().await
2026 && let Ok(ws) = accept_async(stream).await
2027 {
2028 drop(ws);
2029 }
2030
2031 sleep(Duration::from_millis(500)).await;
2033
2034 if let Ok((stream, _)) = listener.accept().await
2036 && let Ok(mut ws) = accept_async(stream).await
2037 {
2038 while let Some(Ok(msg)) = ws.next().await {
2040 if ws.send(msg).await.is_err() {
2041 break;
2042 }
2043 }
2044 }
2045 });
2046
2047 let (handler, _rx) = channel_message_handler();
2048
2049 let config = WebSocketConfig {
2050 url: format!("ws://127.0.0.1:{port}"),
2051 headers: vec![],
2052 heartbeat: None,
2053 heartbeat_msg: None,
2054 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2056 reconnect_delay_max_ms: Some(200),
2057 reconnect_backoff_factor: Some(1.0),
2058 reconnect_jitter_ms: Some(0),
2059 reconnect_max_attempts: None,
2060 };
2061
2062 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2063 .await
2064 .unwrap();
2065
2066 wait_until_async(
2068 || async { client.is_reconnecting() },
2069 Duration::from_secs(2),
2070 )
2071 .await;
2072
2073 let send_result = tokio::time::timeout(
2075 Duration::from_secs(3),
2076 client.send_text("test_message".to_string(), None),
2077 )
2078 .await;
2079
2080 assert!(
2081 send_result.is_ok() && send_result.unwrap().is_ok(),
2082 "Send should succeed after waiting for reconnection"
2083 );
2084
2085 client.disconnect().await;
2086 server.abort();
2087 }
2088
2089 #[rstest]
2090 #[tokio::test]
2091 async fn test_rate_limiter_before_active_wait() {
2092 use std::{num::NonZeroU32, sync::Arc};
2097
2098 use nautilus_common::testing::wait_until_async;
2099
2100 use crate::ratelimiter::quota::Quota;
2101
2102 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2103 let port = listener.local_addr().unwrap().port();
2104
2105 let server = task::spawn(async move {
2106 if let Ok((stream, _)) = listener.accept().await
2108 && let Ok(mut ws) = accept_async(stream).await
2109 {
2110 if let Some(Ok(_)) = ws.next().await {
2112 drop(ws);
2113 }
2114 }
2115
2116 sleep(Duration::from_millis(500)).await;
2118
2119 if let Ok((stream, _)) = listener.accept().await
2121 && let Ok(mut ws) = accept_async(stream).await
2122 {
2123 while let Some(Ok(msg)) = ws.next().await {
2124 if ws.send(msg).await.is_err() {
2125 break;
2126 }
2127 }
2128 }
2129 });
2130
2131 let (handler, _rx) = channel_message_handler();
2132
2133 let config = WebSocketConfig {
2134 url: format!("ws://127.0.0.1:{port}"),
2135 headers: vec![],
2136 heartbeat: None,
2137 heartbeat_msg: None,
2138 reconnect_timeout_ms: Some(5_000),
2139 reconnect_delay_initial_ms: Some(50),
2140 reconnect_delay_max_ms: Some(100),
2141 reconnect_backoff_factor: Some(1.0),
2142 reconnect_jitter_ms: Some(0),
2143 reconnect_max_attempts: None,
2144 };
2145
2146 let quota =
2148 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2149
2150 let client = Arc::new(
2151 WebSocketClient::connect(
2152 config,
2153 Some(handler),
2154 None,
2155 None,
2156 vec![("test_key".to_string(), quota)],
2157 None,
2158 )
2159 .await
2160 .unwrap(),
2161 );
2162
2163 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2165 client
2166 .send_text("msg1".to_string(), Some(test_key.as_slice()))
2167 .await
2168 .unwrap();
2169
2170 wait_until_async(
2172 || async { client.is_reconnecting() },
2173 Duration::from_secs(2),
2174 )
2175 .await;
2176
2177 let start = std::time::Instant::now();
2179 let send_result = client
2180 .send_text("msg2".to_string(), Some(test_key.as_slice()))
2181 .await;
2182 let elapsed = start.elapsed();
2183
2184 assert!(
2186 send_result.is_ok(),
2187 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2188 );
2189 assert!(
2193 elapsed >= Duration::from_millis(850),
2194 "Should wait for rate limit (~1s), waited {elapsed:?}"
2195 );
2196
2197 client.disconnect().await;
2198 server.abort();
2199 }
2200
2201 #[rstest]
2202 #[tokio::test]
2203 async fn test_disconnect_during_reconnect_exits_cleanly() {
2204 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2207 let port = listener.local_addr().unwrap().port();
2208
2209 let server = task::spawn(async move {
2210 if let Ok((stream, _)) = listener.accept().await
2212 && let Ok(ws) = accept_async(stream).await
2213 {
2214 drop(ws);
2215 }
2216 sleep(Duration::from_secs(60)).await;
2218 });
2219
2220 let (handler, _rx) = channel_message_handler();
2221
2222 let config = WebSocketConfig {
2223 url: format!("ws://127.0.0.1:{port}"),
2224 headers: vec![],
2225 heartbeat: None,
2226 heartbeat_msg: None,
2227 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2229 reconnect_delay_max_ms: Some(200),
2230 reconnect_backoff_factor: Some(1.0),
2231 reconnect_jitter_ms: Some(0),
2232 reconnect_max_attempts: None,
2233 };
2234
2235 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2236 .await
2237 .unwrap();
2238
2239 tokio::time::timeout(Duration::from_secs(2), async {
2241 while !client.is_reconnecting() {
2242 sleep(Duration::from_millis(10)).await;
2243 }
2244 })
2245 .await
2246 .expect("Client should enter RECONNECT state");
2247
2248 client.disconnect().await;
2250
2251 assert!(
2253 client.is_disconnected(),
2254 "Client should be cleanly disconnected"
2255 );
2256
2257 server.abort();
2258 }
2259
2260 #[rstest]
2261 #[tokio::test]
2262 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2263 use std::{num::NonZeroU32, sync::Arc};
2266
2267 use nautilus_common::testing::wait_until_async;
2268
2269 use crate::ratelimiter::quota::Quota;
2270
2271 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2272 let port = listener.local_addr().unwrap().port();
2273
2274 let server = task::spawn(async move {
2275 if let Ok((stream, _)) = listener.accept().await
2277 && let Ok(ws) = accept_async(stream).await
2278 {
2279 drop(ws);
2280 }
2281 sleep(Duration::from_secs(60)).await;
2282 });
2283
2284 let (handler, _rx) = channel_message_handler();
2285
2286 let config = WebSocketConfig {
2287 url: format!("ws://127.0.0.1:{port}"),
2288 headers: vec![],
2289 heartbeat: None,
2290 heartbeat_msg: None,
2291 reconnect_timeout_ms: Some(5_000),
2292 reconnect_delay_initial_ms: Some(50),
2293 reconnect_delay_max_ms: Some(100),
2294 reconnect_backoff_factor: Some(1.0),
2295 reconnect_jitter_ms: Some(0),
2296 reconnect_max_attempts: None,
2297 };
2298
2299 let quota = Quota::with_period(Duration::from_secs(10))
2302 .unwrap()
2303 .allow_burst(NonZeroU32::new(1).unwrap());
2304
2305 let client = Arc::new(
2306 WebSocketClient::connect(
2307 config,
2308 Some(handler),
2309 None,
2310 None,
2311 vec![("test_key".to_string(), quota)],
2312 None,
2313 )
2314 .await
2315 .unwrap(),
2316 );
2317
2318 wait_until_async(
2320 || async { client.is_reconnecting() || client.is_closed() },
2321 Duration::from_secs(2),
2322 )
2323 .await;
2324
2325 client.disconnect().await;
2327 assert!(
2328 !client.is_active(),
2329 "Client should not be active after disconnect"
2330 );
2331
2332 let start = std::time::Instant::now();
2334 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2335 let result = client
2336 .send_text("test".to_string(), Some(test_key.as_slice()))
2337 .await;
2338 let elapsed = start.elapsed();
2339
2340 assert!(result.is_err(), "Send should fail when client is closed");
2342 assert!(
2343 matches!(result, Err(crate::error::SendError::Closed)),
2344 "Send should return Closed error, was: {result:?}"
2345 );
2346
2347 assert!(
2349 elapsed < Duration::from_millis(100),
2350 "Send should fail fast without rate limiting, took {elapsed:?}"
2351 );
2352
2353 server.abort();
2354 }
2355
2356 #[rstest]
2357 #[tokio::test]
2358 async fn test_connect_rejects_none_message_handler() {
2359 let config = WebSocketConfig {
2363 url: "ws://127.0.0.1:9999".to_string(),
2364 headers: vec![],
2365 heartbeat: None,
2366 heartbeat_msg: None,
2367 reconnect_timeout_ms: Some(1_000),
2368 reconnect_delay_initial_ms: Some(100),
2369 reconnect_delay_max_ms: Some(500),
2370 reconnect_backoff_factor: Some(1.5),
2371 reconnect_jitter_ms: Some(0),
2372 reconnect_max_attempts: None,
2373 };
2374
2375 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2377
2378 assert!(
2379 result.is_err(),
2380 "connect() should reject None message_handler"
2381 );
2382
2383 let err = result.unwrap_err();
2384 let err_msg = err.to_string();
2385 assert!(
2386 err_msg.contains("Handler mode requires message_handler"),
2387 "Error should mention missing message_handler, was: {err_msg}"
2388 );
2389 }
2390
2391 #[rstest]
2392 #[tokio::test]
2393 async fn test_client_without_handler_sets_stream_mode() {
2394 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2398 let port = listener.local_addr().unwrap().port();
2399
2400 let server = task::spawn(async move {
2401 if let Ok((stream, _)) = listener.accept().await
2403 && let Ok(ws) = accept_async(stream).await
2404 {
2405 drop(ws); }
2407 });
2408
2409 let config = WebSocketConfig {
2410 url: format!("ws://127.0.0.1:{port}"),
2411 headers: vec![],
2412 heartbeat: None,
2413 heartbeat_msg: None,
2414 reconnect_timeout_ms: Some(1_000),
2415 reconnect_delay_initial_ms: Some(100),
2416 reconnect_delay_max_ms: Some(500),
2417 reconnect_backoff_factor: Some(1.5),
2418 reconnect_jitter_ms: Some(0),
2419 reconnect_max_attempts: None,
2420 };
2421
2422 let inner = WebSocketClientInner::connect_url(config, None, None)
2424 .await
2425 .unwrap();
2426
2427 assert!(
2429 inner.is_stream_mode,
2430 "Client without handler should have is_stream_mode=true"
2431 );
2432
2433 server.abort();
2437 }
2438}