1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50 config::WebSocketConfig,
51 consts::{
52 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54 },
55 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60 RECONNECTED,
61 backoff::ExponentialBackoff,
62 error::SendError,
63 logging::{log_task_aborted, log_task_started, log_task_stopped},
64 mode::ConnectionMode,
65 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
68pub struct WebSocketClientInner {
84 config: WebSocketConfig,
85 message_handler: Option<MessageHandler>,
87 ping_handler: Option<PingHandler>,
89 read_task: Option<tokio::task::JoinHandle<()>>,
90 write_task: tokio::task::JoinHandle<()>,
91 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
92 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
93 connection_mode: Arc<AtomicU8>,
94 reconnect_timeout: Duration,
95 backoff: ExponentialBackoff,
96 is_stream_mode: bool,
100 reconnect_max_attempts: Option<u32>,
102 reconnection_attempt_count: u32,
104}
105
106impl WebSocketClientInner {
107 pub async fn new_with_writer(
115 config: WebSocketConfig,
116 writer: MessageWriter,
117 ) -> Result<Self, Error> {
118 install_cryptographic_provider();
119
120 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
121
122 let read_task = None;
124
125 let backoff = ExponentialBackoff::new(
126 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
127 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
128 config.reconnect_backoff_factor.unwrap_or(1.5),
129 config.reconnect_jitter_ms.unwrap_or(100),
130 true, )
132 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
133
134 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
135 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
136
137 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
138 Some(Self::spawn_heartbeat_task(
139 connection_mode.clone(),
140 heartbeat_interval,
141 config.heartbeat_msg.clone(),
142 writer_tx.clone(),
143 ))
144 } else {
145 None
146 };
147
148 let reconnect_max_attempts = config.reconnect_max_attempts;
149 let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
150
151 Ok(Self {
152 config,
153 message_handler: None, ping_handler: None,
155 writer_tx,
156 connection_mode,
157 reconnect_timeout,
158 heartbeat_task,
159 read_task,
160 write_task,
161 backoff,
162 is_stream_mode: true,
163 reconnect_max_attempts,
164 reconnection_attempt_count: 0,
165 })
166 }
167
168 pub async fn connect_url(
176 config: WebSocketConfig,
177 message_handler: Option<MessageHandler>,
178 ping_handler: Option<PingHandler>,
179 ) -> Result<Self, Error> {
180 install_cryptographic_provider();
181
182 let is_stream_mode = message_handler.is_none();
184 let reconnect_max_attempts = config.reconnect_max_attempts;
185
186 let (writer, reader) =
187 Self::connect_with_server(&config.url, config.headers.clone()).await?;
188
189 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
190
191 let read_task = if message_handler.is_some() {
192 Some(Self::spawn_message_handler_task(
193 connection_mode.clone(),
194 reader,
195 message_handler.as_ref(),
196 ping_handler.as_ref(),
197 ))
198 } else {
199 None
200 };
201
202 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
203 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
204
205 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
207 Self::spawn_heartbeat_task(
208 connection_mode.clone(),
209 heartbeat_secs,
210 config.heartbeat_msg.clone(),
211 writer_tx.clone(),
212 )
213 });
214
215 let reconnect_timeout =
216 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
217 let backoff = ExponentialBackoff::new(
218 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
219 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
220 config.reconnect_backoff_factor.unwrap_or(1.5),
221 config.reconnect_jitter_ms.unwrap_or(100),
222 true, )
224 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
225
226 Ok(Self {
227 config,
228 message_handler,
229 ping_handler,
230 read_task,
231 write_task,
232 writer_tx,
233 heartbeat_task,
234 connection_mode,
235 reconnect_timeout,
236 backoff,
237 is_stream_mode,
239 reconnect_max_attempts,
240 reconnection_attempt_count: 0,
241 })
242 }
243
244 #[inline]
254 #[cfg(not(feature = "turmoil"))]
255 pub async fn connect_with_server(
256 url: &str,
257 headers: Vec<(String, String)>,
258 ) -> Result<(MessageWriter, MessageReader), Error> {
259 let mut request = url.into_client_request()?;
260 let req_headers = request.headers_mut();
261
262 let mut header_names: Vec<HeaderName> = Vec::new();
263 for (key, val) in headers {
264 let header_value = HeaderValue::from_str(&val)?;
265 let header_name: HeaderName = key.parse()?;
266 header_names.push(header_name.clone());
267 req_headers.insert(header_name, header_value);
268 }
269
270 connect_async_with_config(request, None, true)
271 .await
272 .map(|resp| resp.0.split())
273 }
274
275 #[inline]
288 #[cfg(feature = "turmoil")]
289 pub async fn connect_with_server(
290 url: &str,
291 headers: Vec<(String, String)>,
292 ) -> Result<(MessageWriter, MessageReader), Error> {
293 use rustls::ClientConfig;
294 use tokio_rustls::TlsConnector;
295
296 let mut request = url.into_client_request()?;
297 let req_headers = request.headers_mut();
298
299 let mut header_names: Vec<HeaderName> = Vec::new();
300 for (key, val) in headers {
301 let header_value = HeaderValue::from_str(&val)?;
302 let header_name: HeaderName = key.parse()?;
303 header_names.push(header_name.clone());
304 req_headers.insert(header_name, header_value);
305 }
306
307 let uri = request.uri();
308 let scheme = uri.scheme_str().unwrap_or("ws");
309 let host = uri.host().ok_or_else(|| {
310 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
311 })?;
312
313 let port = uri
315 .port_u16()
316 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
317
318 let addr = format!("{host}:{port}");
319
320 let connector = crate::net::RealTcpConnector;
322 let tcp_stream = connector.connect(&addr).await?;
323 if let Err(e) = tcp_stream.set_nodelay(true) {
324 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
325 }
326
327 let maybe_tls_stream = if scheme == "wss" {
329 let mut root_store = rustls::RootCertStore::empty();
331 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
332
333 let config = ClientConfig::builder()
334 .with_root_certificates(root_store)
335 .with_no_client_auth();
336
337 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
338 let domain =
339 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
340 Error::Io(std::io::Error::new(
341 std::io::ErrorKind::InvalidInput,
342 format!("Invalid DNS name: {e}"),
343 ))
344 })?;
345
346 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
347 MaybeTlsStream::Rustls(tls_stream)
348 } else {
349 MaybeTlsStream::Plain(tcp_stream)
350 };
351
352 client_async(request, maybe_tls_stream)
354 .await
355 .map(|resp| resp.0.split())
356 }
357
358 pub async fn reconnect(&mut self) -> Result<(), Error> {
373 log::debug!("Reconnecting");
374
375 if self.is_stream_mode {
376 log::warn!(
377 "Auto-reconnect disabled for stream-based WebSocket client; \
378 stream users must manually reconnect by creating a new connection"
379 );
380 self.connection_mode
382 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
383 return Ok(());
384 }
385
386 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
387 log::debug!("Reconnect aborted due to disconnect state");
388 return Ok(());
389 }
390
391 tokio::time::timeout(self.reconnect_timeout, async {
392 let (new_writer, reader) =
394 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
395
396 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
397 log::debug!("Reconnect aborted mid-flight (after connect)");
398 return Ok(());
399 }
400
401 let (tx, rx) = tokio::sync::oneshot::channel();
405 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
406 log::error!("{e}");
407 return Err(Error::Io(std::io::Error::new(
408 std::io::ErrorKind::BrokenPipe,
409 format!("Failed to send update command: {e}"),
410 )));
411 }
412
413 match rx.await {
415 Ok(true) => log::debug!("Writer confirmed buffer drain success"),
416 Ok(false) => {
417 log::warn!("Writer failed to drain buffer, aborting reconnect");
418 return Err(Error::Io(std::io::Error::other(
420 "Failed to drain reconnection buffer",
421 )));
422 }
423 Err(e) => {
424 log::error!("Writer dropped update channel: {e}");
425 return Err(Error::Io(std::io::Error::new(
426 std::io::ErrorKind::BrokenPipe,
427 "Writer task dropped response channel",
428 )));
429 }
430 }
431
432 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
434
435 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
436 log::debug!("Reconnect aborted mid-flight (after delay)");
437 return Ok(());
438 }
439
440 if let Some(ref read_task) = self.read_task.take()
441 && !read_task.is_finished()
442 {
443 read_task.abort();
444 log_task_aborted("read");
445 }
446
447 if self
450 .connection_mode
451 .compare_exchange(
452 ConnectionMode::Reconnect.as_u8(),
453 ConnectionMode::Active.as_u8(),
454 Ordering::SeqCst,
455 Ordering::SeqCst,
456 )
457 .is_err()
458 {
459 log::debug!("Reconnect aborted (state changed during reconnect)");
460 return Ok(());
461 }
462
463 self.read_task = if self.message_handler.is_some() {
464 Some(Self::spawn_message_handler_task(
465 self.connection_mode.clone(),
466 reader,
467 self.message_handler.as_ref(),
468 self.ping_handler.as_ref(),
469 ))
470 } else {
471 None
472 };
473
474 log::debug!("Reconnect succeeded");
475 Ok(())
476 })
477 .await
478 .map_err(|_| {
479 Error::Io(std::io::Error::new(
480 std::io::ErrorKind::TimedOut,
481 format!(
482 "reconnection timed out after {}s",
483 self.reconnect_timeout.as_secs_f64()
484 ),
485 ))
486 })?
487 }
488
489 #[inline]
497 #[must_use]
498 pub fn is_alive(&self) -> bool {
499 match &self.read_task {
500 Some(read_task) => !read_task.is_finished(),
501 None => true, }
503 }
504
505 fn spawn_message_handler_task(
506 connection_state: Arc<AtomicU8>,
507 mut reader: MessageReader,
508 message_handler: Option<&MessageHandler>,
509 ping_handler: Option<&PingHandler>,
510 ) -> tokio::task::JoinHandle<()> {
511 log::debug!("Started message handler task 'read'");
512
513 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
514
515 let message_handler = message_handler.cloned();
517 let ping_handler = ping_handler.cloned();
518
519 tokio::task::spawn(async move {
520 loop {
521 if !ConnectionMode::from_atomic(&connection_state).is_active() {
522 break;
523 }
524
525 match tokio::time::timeout(check_interval, reader.next()).await {
526 Ok(Some(Ok(Message::Binary(data)))) => {
527 log::trace!("Received message <binary> {} bytes", data.len());
528 if let Some(ref handler) = message_handler {
529 handler(Message::Binary(data));
530 }
531 }
532 Ok(Some(Ok(Message::Text(data)))) => {
533 log::trace!("Received message: {data}");
534 if let Some(ref handler) = message_handler {
535 handler(Message::Text(data));
536 }
537 }
538 Ok(Some(Ok(Message::Ping(ping_data)))) => {
539 log::trace!("Received ping: {ping_data:?}");
540 if let Some(ref handler) = ping_handler {
541 handler(ping_data.to_vec());
542 }
543 }
544 Ok(Some(Ok(Message::Pong(_)))) => {
545 log::trace!("Received pong");
546 }
547 Ok(Some(Ok(Message::Close(_)))) => {
548 log::debug!("Received close message - terminating");
549 break;
550 }
551 Ok(Some(Ok(_))) => (),
552 Ok(Some(Err(e))) => {
553 log::error!("Received error message - terminating: {e}");
554 break;
555 }
556 Ok(None) => {
557 log::debug!("No message received - terminating");
558 break;
559 }
560 Err(_) => {
561 continue;
563 }
564 }
565 }
566 })
567 }
568
569 async fn drain_reconnect_buffer(
574 buffer: &mut VecDeque<Message>,
575 writer: &mut MessageWriter,
576 ) -> bool {
577 if buffer.is_empty() {
578 return false;
579 }
580
581 let initial_buffer_len = buffer.len();
582 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
583
584 let mut send_error_occurred = false;
585
586 while let Some(buffered_msg) = buffer.front() {
587 let msg_to_send = buffered_msg.clone();
589
590 if let Err(e) = writer.send(msg_to_send).await {
591 log::error!(
592 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
593 buffer.len()
594 );
595 send_error_occurred = true;
596 break; }
598
599 buffer.pop_front();
601 }
602
603 if buffer.is_empty() {
604 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
605 }
606
607 send_error_occurred
608 }
609
610 fn spawn_write_task(
611 connection_state: Arc<AtomicU8>,
612 writer: MessageWriter,
613 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
614 ) -> tokio::task::JoinHandle<()> {
615 log_task_started("write");
616
617 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
619
620 tokio::task::spawn(async move {
621 let mut active_writer = writer;
622 let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
625
626 loop {
627 match ConnectionMode::from_atomic(&connection_state) {
628 ConnectionMode::Disconnect => {
629 if !reconnect_buffer.is_empty() {
631 log::warn!(
632 "Discarding {} buffered messages due to disconnect",
633 reconnect_buffer.len()
634 );
635 reconnect_buffer.clear();
636 }
637
638 _ = tokio::time::timeout(
641 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
642 active_writer.close(),
643 )
644 .await;
645 break;
646 }
647 ConnectionMode::Closed => {
648 if !reconnect_buffer.is_empty() {
650 log::warn!(
651 "Discarding {} buffered messages due to closed connection",
652 reconnect_buffer.len()
653 );
654 reconnect_buffer.clear();
655 }
656 break;
657 }
658 _ => {}
659 }
660
661 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
662 Ok(Some(msg)) => {
663 let mode = ConnectionMode::from_atomic(&connection_state);
665 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
666 break;
667 }
668
669 match msg {
670 WriterCommand::Update(new_writer, tx) => {
671 log::debug!("Received new writer");
672
673 tokio::time::sleep(Duration::from_millis(100)).await;
675
676 _ = tokio::time::timeout(
679 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
680 active_writer.close(),
681 )
682 .await;
683
684 active_writer = new_writer;
685 log::debug!("Updated writer");
686
687 let send_error = Self::drain_reconnect_buffer(
688 &mut reconnect_buffer,
689 &mut active_writer,
690 )
691 .await;
692
693 if let Err(e) = tx.send(!send_error) {
694 log::error!(
695 "Failed to report drain status to controller: {e:?}"
696 );
697 }
698 }
699 WriterCommand::Send(msg) if mode.is_reconnect() => {
700 log::debug!(
702 "Buffering message during reconnection (buffer size: {})",
703 reconnect_buffer.len() + 1
704 );
705 reconnect_buffer.push_back(msg);
706 }
707 WriterCommand::Send(msg) => {
708 if let Err(e) = active_writer.send(msg.clone()).await {
709 log::error!("Failed to send message: {e}");
710 log::warn!("Writer triggering reconnect");
711 reconnect_buffer.push_back(msg);
712 connection_state
713 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
714 }
715 }
716 }
717 }
718 Ok(None) => {
719 log::debug!("Writer channel closed, terminating writer task");
721 break;
722 }
723 Err(_) => {
724 continue;
726 }
727 }
728 }
729
730 _ = tokio::time::timeout(
733 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
734 active_writer.close(),
735 )
736 .await;
737
738 log_task_stopped("write");
739 })
740 }
741
742 fn spawn_heartbeat_task(
743 connection_state: Arc<AtomicU8>,
744 heartbeat_secs: u64,
745 message: Option<String>,
746 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
747 ) -> tokio::task::JoinHandle<()> {
748 log_task_started("heartbeat");
749
750 tokio::task::spawn(async move {
751 let interval = Duration::from_secs(heartbeat_secs);
752
753 loop {
754 tokio::time::sleep(interval).await;
755
756 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
757 ConnectionMode::Active => {
758 let msg = match &message {
759 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
760 None => WriterCommand::Send(Message::Ping(vec![].into())),
761 };
762
763 match writer_tx.send(msg) {
764 Ok(()) => log::trace!("Sent heartbeat to writer task"),
765 Err(e) => {
766 log::error!("Failed to send heartbeat to writer task: {e}");
767 }
768 }
769 }
770 ConnectionMode::Reconnect => continue,
771 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
772 }
773 }
774
775 log_task_stopped("heartbeat");
776 })
777 }
778}
779
780impl Drop for WebSocketClientInner {
781 fn drop(&mut self) {
782 self.clean_drop();
784 }
785}
786
787impl CleanDrop for WebSocketClientInner {
789 fn clean_drop(&mut self) {
790 if let Some(ref read_task) = self.read_task.take()
791 && !read_task.is_finished()
792 {
793 read_task.abort();
794 log_task_aborted("read");
795 }
796
797 if !self.write_task.is_finished() {
798 self.write_task.abort();
799 log_task_aborted("write");
800 }
801
802 if let Some(ref handle) = self.heartbeat_task.take()
803 && !handle.is_finished()
804 {
805 handle.abort();
806 log_task_aborted("heartbeat");
807 }
808
809 self.message_handler = None;
811 self.ping_handler = None;
812 }
813}
814
815impl Debug for WebSocketClientInner {
816 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817 f.debug_struct(stringify!(WebSocketClientInner))
818 .field("config", &self.config)
819 .field(
820 "connection_mode",
821 &ConnectionMode::from_atomic(&self.connection_mode),
822 )
823 .field("reconnect_timeout", &self.reconnect_timeout)
824 .field("is_stream_mode", &self.is_stream_mode)
825 .finish()
826 }
827}
828
829#[cfg_attr(
834 feature = "python",
835 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
836)]
837pub struct WebSocketClient {
838 pub(crate) controller_task: tokio::task::JoinHandle<()>,
839 pub(crate) connection_mode: Arc<AtomicU8>,
840 pub(crate) reconnect_timeout: Duration,
841 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
842 pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
843}
844
845impl Debug for WebSocketClient {
846 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847 f.debug_struct(stringify!(WebSocketClient)).finish()
848 }
849}
850
851impl WebSocketClient {
852 #[allow(clippy::too_many_arguments)]
868 pub async fn connect_stream(
869 config: WebSocketConfig,
870 keyed_quotas: Vec<(String, Quota)>,
871 default_quota: Option<Quota>,
872 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
873 ) -> Result<(MessageReader, Self), Error> {
874 install_cryptographic_provider();
875
876 let (writer, reader) =
878 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
879
880 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
882
883 let connection_mode = inner.connection_mode.clone();
884 let reconnect_timeout = inner.reconnect_timeout;
885 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
886 let writer_tx = inner.writer_tx.clone();
887
888 let controller_task =
889 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
890
891 Ok((
892 reader,
893 Self {
894 controller_task,
895 connection_mode,
896 reconnect_timeout,
897 rate_limiter,
898 writer_tx,
899 },
900 ))
901 }
902
903 pub async fn connect(
921 config: WebSocketConfig,
922 message_handler: Option<MessageHandler>,
923 ping_handler: Option<PingHandler>,
924 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
925 keyed_quotas: Vec<(String, Quota)>,
926 default_quota: Option<Quota>,
927 ) -> Result<Self, Error> {
928 if message_handler.is_none() {
930 return Err(Error::Io(std::io::Error::new(
931 std::io::ErrorKind::InvalidInput,
932 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
933 )));
934 }
935
936 log::debug!("Connecting");
937 let inner =
938 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
939 let connection_mode = inner.connection_mode.clone();
940 let writer_tx = inner.writer_tx.clone();
941 let reconnect_timeout = inner.reconnect_timeout;
942
943 let controller_task =
944 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
945
946 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
947
948 Ok(Self {
949 controller_task,
950 connection_mode,
951 reconnect_timeout,
952 rate_limiter,
953 writer_tx,
954 })
955 }
956
957 #[must_use]
959 pub fn connection_mode(&self) -> ConnectionMode {
960 ConnectionMode::from_atomic(&self.connection_mode)
961 }
962
963 #[must_use]
968 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
969 Arc::clone(&self.connection_mode)
970 }
971
972 #[inline]
977 #[must_use]
978 pub fn is_active(&self) -> bool {
979 self.connection_mode().is_active()
980 }
981
982 #[must_use]
984 pub fn is_disconnected(&self) -> bool {
985 self.controller_task.is_finished()
986 }
987
988 #[inline]
993 #[must_use]
994 pub fn is_reconnecting(&self) -> bool {
995 self.connection_mode().is_reconnect()
996 }
997
998 #[inline]
1002 #[must_use]
1003 pub fn is_disconnecting(&self) -> bool {
1004 self.connection_mode().is_disconnect()
1005 }
1006
1007 #[inline]
1013 #[must_use]
1014 pub fn is_closed(&self) -> bool {
1015 self.connection_mode().is_closed()
1016 }
1017
1018 async fn wait_for_active(&self) -> Result<(), SendError> {
1022 if self.is_closed() {
1023 return Err(SendError::Closed);
1024 }
1025
1026 let timeout = self.reconnect_timeout;
1027 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1028
1029 if !self.is_active() {
1030 log::debug!("Waiting for client to become ACTIVE before sending...");
1031
1032 let inner = tokio::time::timeout(timeout, async {
1033 loop {
1034 if self.is_active() {
1035 return Ok(());
1036 }
1037 if matches!(
1038 self.connection_mode(),
1039 ConnectionMode::Disconnect | ConnectionMode::Closed
1040 ) {
1041 return Err(());
1042 }
1043 tokio::time::sleep(check_interval).await;
1044 }
1045 })
1046 .await
1047 .map_err(|_| SendError::Timeout)?;
1048 inner.map_err(|()| SendError::Closed)?;
1049 }
1050
1051 Ok(())
1052 }
1053
1054 pub async fn disconnect(&self) {
1059 log::debug!("Disconnecting");
1060 self.connection_mode
1061 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1062
1063 if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1064 while !self.is_disconnected() {
1065 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1066 }
1067
1068 if !self.controller_task.is_finished() {
1069 self.controller_task.abort();
1070 log_task_aborted("controller");
1071 }
1072 })
1073 .await
1074 == Ok(())
1075 {
1076 log::debug!("Controller task finished");
1077 } else {
1078 log::error!("Timeout waiting for controller task to finish");
1079 if !self.controller_task.is_finished() {
1080 self.controller_task.abort();
1081 log_task_aborted("controller");
1082 }
1083 }
1084 }
1085
1086 #[allow(unused_variables)]
1092 pub async fn send_text(
1093 &self,
1094 data: String,
1095 keys: Option<Vec<String>>,
1096 ) -> Result<(), SendError> {
1097 if self.is_closed() || self.is_disconnecting() {
1099 return Err(SendError::Closed);
1100 }
1101
1102 self.rate_limiter.await_keys_ready(keys).await;
1103 self.wait_for_active().await?;
1104
1105 log::trace!("Sending text: {data:?}");
1106
1107 let msg = Message::Text(data.into());
1108 self.writer_tx
1109 .send(WriterCommand::Send(msg))
1110 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1111 }
1112
1113 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1119 self.wait_for_active().await?;
1120
1121 log::trace!("Sending pong frame ({} bytes)", data.len());
1122
1123 let msg = Message::Pong(data.into());
1124 self.writer_tx
1125 .send(WriterCommand::Send(msg))
1126 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1127 }
1128
1129 #[allow(unused_variables)]
1135 pub async fn send_bytes(
1136 &self,
1137 data: Vec<u8>,
1138 keys: Option<Vec<String>>,
1139 ) -> Result<(), SendError> {
1140 if self.is_closed() || self.is_disconnecting() {
1142 return Err(SendError::Closed);
1143 }
1144
1145 self.rate_limiter.await_keys_ready(keys).await;
1146 self.wait_for_active().await?;
1147
1148 log::trace!("Sending bytes: {data:?}");
1149
1150 let msg = Message::Binary(data.into());
1151 self.writer_tx
1152 .send(WriterCommand::Send(msg))
1153 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1154 }
1155
1156 pub async fn send_close_message(&self) -> Result<(), SendError> {
1162 self.wait_for_active().await?;
1163
1164 let msg = Message::Close(None);
1165 self.writer_tx
1166 .send(WriterCommand::Send(msg))
1167 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1168 }
1169
1170 fn spawn_controller_task(
1171 mut inner: WebSocketClientInner,
1172 connection_mode: Arc<AtomicU8>,
1173 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1174 ) -> tokio::task::JoinHandle<()> {
1175 tokio::task::spawn(async move {
1176 log_task_started("controller");
1177
1178 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1179
1180 loop {
1181 tokio::time::sleep(check_interval).await;
1182 let mut mode = ConnectionMode::from_atomic(&connection_mode);
1183
1184 if mode.is_disconnect() {
1185 log::debug!("Disconnecting");
1186
1187 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1188 if tokio::time::timeout(timeout, async {
1189 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1191
1192 if let Some(task) = &inner.read_task
1193 && !task.is_finished()
1194 {
1195 task.abort();
1196 log_task_aborted("read");
1197 }
1198
1199 if let Some(task) = &inner.heartbeat_task
1200 && !task.is_finished()
1201 {
1202 task.abort();
1203 log_task_aborted("heartbeat");
1204 }
1205 })
1206 .await
1207 .is_err()
1208 {
1209 log::error!("Shutdown timed out after {}s", timeout.as_secs());
1210 }
1211
1212 log::debug!("Closed");
1213 break; }
1215
1216 if mode.is_closed() {
1217 log::debug!("Connection closed");
1218 break;
1219 }
1220
1221 if mode.is_active() && !inner.is_alive() {
1222 if connection_mode
1223 .compare_exchange(
1224 ConnectionMode::Active.as_u8(),
1225 ConnectionMode::Reconnect.as_u8(),
1226 Ordering::SeqCst,
1227 Ordering::SeqCst,
1228 )
1229 .is_ok()
1230 {
1231 log::debug!("Detected dead read task, transitioning to RECONNECT");
1232 }
1233 mode = ConnectionMode::from_atomic(&connection_mode);
1234 }
1235
1236 if mode.is_reconnect() {
1237 if let Some(max_attempts) = inner.reconnect_max_attempts
1239 && inner.reconnection_attempt_count >= max_attempts
1240 {
1241 log::error!(
1242 "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1243 );
1244 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1245 break;
1246 }
1247
1248 inner.reconnection_attempt_count += 1;
1249 log::debug!(
1250 "Reconnection attempt {} of {}",
1251 inner.reconnection_attempt_count,
1252 inner
1253 .reconnect_max_attempts
1254 .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1255 );
1256
1257 match inner.reconnect().await {
1258 Ok(()) => {
1259 inner.backoff.reset();
1260 inner.reconnection_attempt_count = 0; if ConnectionMode::from_atomic(&connection_mode).is_active() {
1264 if let Some(ref handler) = inner.message_handler {
1265 let reconnected_msg =
1266 Message::Text(RECONNECTED.to_string().into());
1267 handler(reconnected_msg);
1268 log::debug!("Sent reconnected message to handler");
1269 }
1270
1271 if let Some(ref callback) = post_reconnection {
1273 callback();
1274 log::debug!("Called `post_reconnection` handler");
1275 }
1276
1277 log::debug!("Reconnected successfully");
1278 } else {
1279 log::debug!(
1280 "Skipping post_reconnection handlers due to disconnect state"
1281 );
1282 }
1283 }
1284 Err(e) => {
1285 let duration = inner.backoff.next_duration();
1286 log::warn!(
1287 "Reconnect attempt {} failed: {e}",
1288 inner.reconnection_attempt_count
1289 );
1290 if !duration.is_zero() {
1291 log::warn!("Backing off for {}s...", duration.as_secs_f64());
1292 }
1293 tokio::time::sleep(duration).await;
1294 }
1295 }
1296 }
1297 }
1298 inner
1299 .connection_mode
1300 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1301
1302 log_task_stopped("controller");
1303 })
1304 }
1305}
1306
1307impl Drop for WebSocketClient {
1309 fn drop(&mut self) {
1310 if !self.controller_task.is_finished() {
1311 self.controller_task.abort();
1312 log_task_aborted("controller");
1313 }
1314 }
1315}
1316
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc, time::Duration};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::{TcpListener, TcpStream},
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server used by the tests in this module.
    struct TestServer {
        /// Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Ephemeral port the listener is bound to on 127.0.0.1.
        port: u16,
    }

    /// Handshake callback that requires the client to send the header `key: value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Test-only: a missing or wrong header should fail the handshake loudly.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            assert_eq!(
                request.headers().get(&self.key),
                Some(&self.value),
                "missing or unexpected `{}` request header",
                self.key,
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an echo server to an ephemeral localhost port and starts accepting.
        ///
        /// Each accepted connection is upgraded (with [`TestCallback`] checking the
        /// `test: test` header) and served in its own task. A slow or failed handshake
        /// therefore cannot stall or kill the accept loop, which must stay alive so that
        /// reconnecting clients can connect again.
        async fn setup() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let callback = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let conn = match listener.accept().await {
                        Ok((conn, _)) => conn,
                        Err(e) => {
                            // Treat a listener error as fatal rather than spinning on it.
                            log::error!("Test server accept failed: {e}");
                            break;
                        }
                    };
                    task::spawn(Self::serve_connection(conn, callback.clone()));
                }
            });

            Self { task, port }
        }

        /// Upgrades `conn` to a WebSocket and echoes frames until either side closes.
        ///
        /// Text and binary frames are sent back unchanged. A client `Close`, or the text
        /// `"close-now"`, makes the server close the socket from its side.
        async fn serve_connection(conn: TcpStream, callback: TestCallback) {
            let mut websocket = match accept_hdr_async(conn, callback).await {
                Ok(websocket) => websocket,
                Err(e) => {
                    log::warn!("Test server handshake failed: {e}");
                    return;
                }
            };

            while let Some(Ok(msg)) = websocket.next().await {
                match msg {
                    Message::Text(txt) if txt == "close-now" => {
                        log::debug!("Forcibly closing from server side");
                        let _ = websocket.close(None).await;
                        break;
                    }
                    Message::Text(_) | Message::Binary(_) => {
                        if websocket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Close(_frame) => {
                        let _ = websocket.close(None).await;
                        break;
                    }
                    _ => {}
                }
            }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds the default client config for a [`TestServer`] listening on `port`.
    ///
    /// The config sends the `test: test` header that the server requires. The heartbeat
    /// and all reconnection settings are left unset (`None`).
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        }
    }

    /// Connects a handler-mode client with a no-op message handler to `port`.
    async fn setup_test_client(port: u16) -> WebSocketClient {
        let config = test_config(port);
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    /// A fresh client is connected, and `disconnect()` leaves it disconnected.
    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Keeps a connection open across several heartbeat intervals.
    ///
    /// The server echoes each heartbeat text back. The client must still be connected
    /// afterwards.
    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let config = WebSocketConfig {
            // NOTE(review): assumes `heartbeat` is an interval in seconds; confirm.
            heartbeat: Some(1),
            heartbeat_msg: Some("ping".into()),
            ..test_config(server.port)
        };
        let client =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await
                .expect("Failed to connect");

        // Long enough for several heartbeats to be sent and echoed.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(
            !client.is_disconnected(),
            "Heartbeats should not drop the connection"
        );

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Connecting to a port with no listener must return an error.
    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it so nothing is listening there. A fixed
        // port could collide with an unrelated service on the test host.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    /// The server closes the socket on `"close-now"`.
    ///
    /// The client must not end up disconnected, and it must be able to send again once
    /// it has reconnected.
    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();
        client.send_text("close-now".into(), None).await.unwrap();

        // Give the client time to observe the server-side close and reconnect.
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert!(!client.is_disconnected());

        // A successful send proves that the new connection is actually usable.
        client
            .send_text("after-reconnect".into(), None)
            .await
            .unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Sends more messages than a 2-per-second quota on the `"default"` key allows.
    ///
    /// Every send must still succeed. The limiter is expected to delay sends, not
    /// reject them.
    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();
        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Ten tasks sharing one client send concurrently, and every send must succeed.
    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let client = Arc::clone(&client);
                task::spawn(async move {
                    client.send_text(format!("test{i}"), None).await.unwrap();
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1582
1583#[cfg(test)]
1584#[cfg(not(feature = "turmoil"))]
1585mod rust_tests {
1586 use futures_util::StreamExt;
1587 use rstest::rstest;
1588 use tokio::{
1589 net::TcpListener,
1590 task,
1591 time::{Duration, sleep},
1592 };
1593 use tokio_tungstenite::accept_async;
1594
1595 use super::*;
1596 use crate::websocket::types::channel_message_handler;
1597
1598 #[rstest]
1599 #[tokio::test]
1600 async fn test_reconnect_then_disconnect() {
1601 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1603 let port = listener.local_addr().unwrap().port();
1604
1605 let server = task::spawn(async move {
1607 let (stream, _) = listener.accept().await.unwrap();
1608 let ws = accept_async(stream).await.unwrap();
1609 drop(ws);
1610 sleep(Duration::from_secs(1)).await;
1612 });
1613
1614 let (handler, _rx) = channel_message_handler();
1616
1617 let config = WebSocketConfig {
1619 url: format!("ws://127.0.0.1:{port}"),
1620 headers: vec![],
1621 heartbeat: None,
1622 heartbeat_msg: None,
1623 reconnect_timeout_ms: Some(1_000),
1624 reconnect_delay_initial_ms: Some(50),
1625 reconnect_delay_max_ms: Some(100),
1626 reconnect_backoff_factor: Some(1.0),
1627 reconnect_jitter_ms: Some(0),
1628 reconnect_max_attempts: None,
1629 };
1630
1631 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1633 .await
1634 .unwrap();
1635
1636 sleep(Duration::from_millis(100)).await;
1638 client.disconnect().await;
1640 assert!(client.is_disconnected());
1641 server.abort();
1642 }
1643
1644 #[rstest]
1645 #[tokio::test]
1646 async fn test_reconnect_state_flips_when_reader_stops() {
1647 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1649 let port = listener.local_addr().unwrap().port();
1650
1651 let server = task::spawn(async move {
1652 if let Ok((stream, _)) = listener.accept().await
1653 && let Ok(ws) = accept_async(stream).await
1654 {
1655 drop(ws);
1656 }
1657 sleep(Duration::from_millis(50)).await;
1658 });
1659
1660 let (handler, _rx) = channel_message_handler();
1661
1662 let config = WebSocketConfig {
1663 url: format!("ws://127.0.0.1:{port}"),
1664 headers: vec![],
1665 heartbeat: None,
1666 heartbeat_msg: None,
1667 reconnect_timeout_ms: Some(1_000),
1668 reconnect_delay_initial_ms: Some(50),
1669 reconnect_delay_max_ms: Some(100),
1670 reconnect_backoff_factor: Some(1.0),
1671 reconnect_jitter_ms: Some(0),
1672 reconnect_max_attempts: None,
1673 };
1674
1675 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1676 .await
1677 .unwrap();
1678
1679 tokio::time::timeout(Duration::from_secs(2), async {
1680 loop {
1681 if client.is_reconnecting() {
1682 break;
1683 }
1684 tokio::time::sleep(Duration::from_millis(10)).await;
1685 }
1686 })
1687 .await
1688 .expect("client did not enter RECONNECT state");
1689
1690 client.disconnect().await;
1691 server.abort();
1692 }
1693
1694 #[rstest]
1695 #[tokio::test]
1696 async fn test_stream_mode_disables_auto_reconnect() {
1697 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1700 let port = listener.local_addr().unwrap().port();
1701
1702 let server = task::spawn(async move {
1703 if let Ok((stream, _)) = listener.accept().await
1704 && let Ok(_ws) = accept_async(stream).await
1705 {
1706 sleep(Duration::from_millis(100)).await;
1708 }
1709 });
1710
1711 let config = WebSocketConfig {
1712 url: format!("ws://127.0.0.1:{port}"),
1713 headers: vec![],
1714 heartbeat: None,
1715 heartbeat_msg: None,
1716 reconnect_timeout_ms: Some(1_000),
1717 reconnect_delay_initial_ms: Some(50),
1718 reconnect_delay_max_ms: Some(100),
1719 reconnect_backoff_factor: Some(1.0),
1720 reconnect_jitter_ms: Some(0),
1721 reconnect_max_attempts: None,
1722 };
1723
1724 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1725 .await
1726 .unwrap();
1727
1728 server.abort();
1736 }
1737
1738 #[rstest]
1739 #[tokio::test]
1740 async fn test_message_handler_mode_allows_auto_reconnect() {
1741 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1743 let port = listener.local_addr().unwrap().port();
1744
1745 let server = task::spawn(async move {
1746 if let Ok((stream, _)) = listener.accept().await
1748 && let Ok(ws) = accept_async(stream).await
1749 {
1750 drop(ws);
1751 }
1752 sleep(Duration::from_millis(50)).await;
1753 });
1754
1755 let (handler, _rx) = channel_message_handler();
1756
1757 let config = WebSocketConfig {
1758 url: format!("ws://127.0.0.1:{port}"),
1759 headers: vec![],
1760 heartbeat: None,
1761 heartbeat_msg: None,
1762 reconnect_timeout_ms: Some(1_000),
1763 reconnect_delay_initial_ms: Some(50),
1764 reconnect_delay_max_ms: Some(100),
1765 reconnect_backoff_factor: Some(1.0),
1766 reconnect_jitter_ms: Some(0),
1767 reconnect_max_attempts: None,
1768 };
1769
1770 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1771 .await
1772 .unwrap();
1773
1774 tokio::time::timeout(Duration::from_secs(2), async {
1776 loop {
1777 if client.is_reconnecting() || client.is_closed() {
1778 break;
1779 }
1780 tokio::time::sleep(Duration::from_millis(10)).await;
1781 }
1782 })
1783 .await
1784 .expect("client should attempt reconnection or close");
1785
1786 assert!(
1789 client.is_reconnecting() || client.is_closed(),
1790 "Client with message handler should attempt reconnection"
1791 );
1792
1793 client.disconnect().await;
1794 server.abort();
1795 }
1796
1797 #[rstest]
1798 #[tokio::test]
1799 async fn test_handler_mode_reconnect_with_new_connection() {
1800 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1802 let port = listener.local_addr().unwrap().port();
1803
1804 let server = task::spawn(async move {
1805 if let Ok((stream, _)) = listener.accept().await
1807 && let Ok(ws) = accept_async(stream).await
1808 {
1809 drop(ws);
1810 }
1811
1812 sleep(Duration::from_millis(100)).await;
1814
1815 if let Ok((stream, _)) = listener.accept().await
1817 && let Ok(mut ws) = accept_async(stream).await
1818 {
1819 use futures_util::SinkExt;
1820 let _ = ws
1821 .send(Message::Text("reconnected".to_string().into()))
1822 .await;
1823 sleep(Duration::from_secs(1)).await;
1824 }
1825 });
1826
1827 let (handler, mut rx) = channel_message_handler();
1828
1829 let config = WebSocketConfig {
1830 url: format!("ws://127.0.0.1:{port}"),
1831 headers: vec![],
1832 heartbeat: None,
1833 heartbeat_msg: None,
1834 reconnect_timeout_ms: Some(2_000),
1835 reconnect_delay_initial_ms: Some(50),
1836 reconnect_delay_max_ms: Some(200),
1837 reconnect_backoff_factor: Some(1.5),
1838 reconnect_jitter_ms: Some(10),
1839 reconnect_max_attempts: None,
1840 };
1841
1842 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1843 .await
1844 .unwrap();
1845
1846 let result = tokio::time::timeout(Duration::from_secs(5), async {
1848 loop {
1849 if let Ok(msg) = rx.try_recv()
1850 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1851 {
1852 return true;
1853 }
1854 tokio::time::sleep(Duration::from_millis(10)).await;
1855 }
1856 })
1857 .await;
1858
1859 assert!(
1860 result.is_ok(),
1861 "Should receive message after reconnection within timeout"
1862 );
1863
1864 client.disconnect().await;
1865 server.abort();
1866 }
1867
1868 #[rstest]
1869 #[tokio::test]
1870 async fn test_stream_mode_no_auto_reconnect() {
1871 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1874 let port = listener.local_addr().unwrap().port();
1875
1876 let server = task::spawn(async move {
1877 if let Ok((stream, _)) = listener.accept().await
1879 && let Ok(mut ws) = accept_async(stream).await
1880 {
1881 use futures_util::SinkExt;
1882 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1883 sleep(Duration::from_millis(50)).await;
1884 }
1886 });
1887
1888 let config = WebSocketConfig {
1889 url: format!("ws://127.0.0.1:{port}"),
1890 headers: vec![],
1891 heartbeat: None,
1892 heartbeat_msg: None,
1893 reconnect_timeout_ms: Some(1_000),
1894 reconnect_delay_initial_ms: Some(50),
1895 reconnect_delay_max_ms: Some(100),
1896 reconnect_backoff_factor: Some(1.0),
1897 reconnect_jitter_ms: Some(0),
1898 reconnect_max_attempts: None,
1899 };
1900
1901 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1902 .await
1903 .unwrap();
1904
1905 assert!(client.is_active(), "Client should start as active");
1907
1908 let msg = reader.next().await;
1910 assert!(
1911 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1912 "Should receive initial message"
1913 );
1914
1915 while let Some(msg) = reader.next().await {
1917 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1918 break;
1919 }
1920 }
1921
1922 sleep(Duration::from_millis(200)).await;
1925
1926 assert!(
1929 client.is_active() || client.is_closed(),
1930 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1931 );
1932 assert!(
1933 !client.is_reconnecting(),
1934 "Stream mode client should never attempt reconnection"
1935 );
1936
1937 client.disconnect().await;
1938 server.abort();
1939 }
1940
1941 #[rstest]
1942 #[tokio::test]
1943 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1944 use nautilus_common::testing::wait_until_async;
1947
1948 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1949 let port = listener.local_addr().unwrap().port();
1950
1951 let server = task::spawn(async move {
1952 if let Ok((stream, _)) = listener.accept().await
1954 && let Ok(ws) = accept_async(stream).await
1955 {
1956 drop(ws);
1957 }
1958 sleep(Duration::from_secs(60)).await;
1960 });
1961
1962 let (handler, _rx) = channel_message_handler();
1963
1964 let config = WebSocketConfig {
1966 url: format!("ws://127.0.0.1:{port}"),
1967 headers: vec![],
1968 heartbeat: None,
1969 heartbeat_msg: None,
1970 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1972 reconnect_delay_max_ms: Some(100),
1973 reconnect_backoff_factor: Some(1.0),
1974 reconnect_jitter_ms: Some(0),
1975 reconnect_max_attempts: None,
1976 };
1977
1978 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1979 .await
1980 .unwrap();
1981
1982 wait_until_async(
1984 || async { client.is_reconnecting() },
1985 Duration::from_secs(3),
1986 )
1987 .await;
1988
1989 let start = std::time::Instant::now();
1991 let send_result = client.send_text("test".to_string(), None).await;
1992 let elapsed = start.elapsed();
1993
1994 assert!(
1995 send_result.is_err(),
1996 "Send should fail when client stuck in RECONNECT"
1997 );
1998 assert!(
1999 matches!(send_result, Err(crate::error::SendError::Timeout)),
2000 "Send should return Timeout error, was: {send_result:?}"
2001 );
2002 assert!(
2005 elapsed >= Duration::from_millis(1800),
2006 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2007 );
2008
2009 client.disconnect().await;
2010 server.abort();
2011 }
2012
2013 #[rstest]
2014 #[tokio::test]
2015 async fn test_send_waits_during_reconnection() {
2016 use nautilus_common::testing::wait_until_async;
2018
2019 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2020 let port = listener.local_addr().unwrap().port();
2021
2022 let server = task::spawn(async move {
2023 if let Ok((stream, _)) = listener.accept().await
2025 && let Ok(ws) = accept_async(stream).await
2026 {
2027 drop(ws);
2028 }
2029
2030 sleep(Duration::from_millis(500)).await;
2032
2033 if let Ok((stream, _)) = listener.accept().await
2035 && let Ok(mut ws) = accept_async(stream).await
2036 {
2037 while let Some(Ok(msg)) = ws.next().await {
2039 if ws.send(msg).await.is_err() {
2040 break;
2041 }
2042 }
2043 }
2044 });
2045
2046 let (handler, _rx) = channel_message_handler();
2047
2048 let config = WebSocketConfig {
2049 url: format!("ws://127.0.0.1:{port}"),
2050 headers: vec![],
2051 heartbeat: None,
2052 heartbeat_msg: None,
2053 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2055 reconnect_delay_max_ms: Some(200),
2056 reconnect_backoff_factor: Some(1.0),
2057 reconnect_jitter_ms: Some(0),
2058 reconnect_max_attempts: None,
2059 };
2060
2061 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2062 .await
2063 .unwrap();
2064
2065 wait_until_async(
2067 || async { client.is_reconnecting() },
2068 Duration::from_secs(2),
2069 )
2070 .await;
2071
2072 let send_result = tokio::time::timeout(
2074 Duration::from_secs(3),
2075 client.send_text("test_message".to_string(), None),
2076 )
2077 .await;
2078
2079 assert!(
2080 send_result.is_ok() && send_result.unwrap().is_ok(),
2081 "Send should succeed after waiting for reconnection"
2082 );
2083
2084 client.disconnect().await;
2085 server.abort();
2086 }
2087
2088 #[rstest]
2089 #[tokio::test]
2090 async fn test_rate_limiter_before_active_wait() {
2091 use std::{num::NonZeroU32, sync::Arc};
2096
2097 use nautilus_common::testing::wait_until_async;
2098
2099 use crate::ratelimiter::quota::Quota;
2100
2101 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2102 let port = listener.local_addr().unwrap().port();
2103
2104 let server = task::spawn(async move {
2105 if let Ok((stream, _)) = listener.accept().await
2107 && let Ok(mut ws) = accept_async(stream).await
2108 {
2109 if let Some(Ok(_)) = ws.next().await {
2111 drop(ws);
2112 }
2113 }
2114
2115 sleep(Duration::from_millis(500)).await;
2117
2118 if let Ok((stream, _)) = listener.accept().await
2120 && let Ok(mut ws) = accept_async(stream).await
2121 {
2122 while let Some(Ok(msg)) = ws.next().await {
2123 if ws.send(msg).await.is_err() {
2124 break;
2125 }
2126 }
2127 }
2128 });
2129
2130 let (handler, _rx) = channel_message_handler();
2131
2132 let config = WebSocketConfig {
2133 url: format!("ws://127.0.0.1:{port}"),
2134 headers: vec![],
2135 heartbeat: None,
2136 heartbeat_msg: None,
2137 reconnect_timeout_ms: Some(5_000),
2138 reconnect_delay_initial_ms: Some(50),
2139 reconnect_delay_max_ms: Some(100),
2140 reconnect_backoff_factor: Some(1.0),
2141 reconnect_jitter_ms: Some(0),
2142 reconnect_max_attempts: None,
2143 };
2144
2145 let quota =
2147 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2148
2149 let client = Arc::new(
2150 WebSocketClient::connect(
2151 config,
2152 Some(handler),
2153 None,
2154 None,
2155 vec![("test_key".to_string(), quota)],
2156 None,
2157 )
2158 .await
2159 .unwrap(),
2160 );
2161
2162 client
2164 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2165 .await
2166 .unwrap();
2167
2168 wait_until_async(
2170 || async { client.is_reconnecting() },
2171 Duration::from_secs(2),
2172 )
2173 .await;
2174
2175 let start = std::time::Instant::now();
2177 let send_result = client
2178 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2179 .await;
2180 let elapsed = start.elapsed();
2181
2182 assert!(
2184 send_result.is_ok(),
2185 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2186 );
2187 assert!(
2191 elapsed >= Duration::from_millis(850),
2192 "Should wait for rate limit (~1s), waited {elapsed:?}"
2193 );
2194
2195 client.disconnect().await;
2196 server.abort();
2197 }
2198
2199 #[rstest]
2200 #[tokio::test]
2201 async fn test_disconnect_during_reconnect_exits_cleanly() {
2202 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2205 let port = listener.local_addr().unwrap().port();
2206
2207 let server = task::spawn(async move {
2208 if let Ok((stream, _)) = listener.accept().await
2210 && let Ok(ws) = accept_async(stream).await
2211 {
2212 drop(ws);
2213 }
2214 sleep(Duration::from_secs(60)).await;
2216 });
2217
2218 let (handler, _rx) = channel_message_handler();
2219
2220 let config = WebSocketConfig {
2221 url: format!("ws://127.0.0.1:{port}"),
2222 headers: vec![],
2223 heartbeat: None,
2224 heartbeat_msg: None,
2225 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2227 reconnect_delay_max_ms: Some(200),
2228 reconnect_backoff_factor: Some(1.0),
2229 reconnect_jitter_ms: Some(0),
2230 reconnect_max_attempts: None,
2231 };
2232
2233 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2234 .await
2235 .unwrap();
2236
2237 tokio::time::timeout(Duration::from_secs(2), async {
2239 while !client.is_reconnecting() {
2240 sleep(Duration::from_millis(10)).await;
2241 }
2242 })
2243 .await
2244 .expect("Client should enter RECONNECT state");
2245
2246 client.disconnect().await;
2248
2249 assert!(
2251 client.is_disconnected(),
2252 "Client should be cleanly disconnected"
2253 );
2254
2255 server.abort();
2256 }
2257
2258 #[rstest]
2259 #[tokio::test]
2260 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2261 use std::{num::NonZeroU32, sync::Arc};
2264
2265 use nautilus_common::testing::wait_until_async;
2266
2267 use crate::ratelimiter::quota::Quota;
2268
2269 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2270 let port = listener.local_addr().unwrap().port();
2271
2272 let server = task::spawn(async move {
2273 if let Ok((stream, _)) = listener.accept().await
2275 && let Ok(ws) = accept_async(stream).await
2276 {
2277 drop(ws);
2278 }
2279 sleep(Duration::from_secs(60)).await;
2280 });
2281
2282 let (handler, _rx) = channel_message_handler();
2283
2284 let config = WebSocketConfig {
2285 url: format!("ws://127.0.0.1:{port}"),
2286 headers: vec![],
2287 heartbeat: None,
2288 heartbeat_msg: None,
2289 reconnect_timeout_ms: Some(5_000),
2290 reconnect_delay_initial_ms: Some(50),
2291 reconnect_delay_max_ms: Some(100),
2292 reconnect_backoff_factor: Some(1.0),
2293 reconnect_jitter_ms: Some(0),
2294 reconnect_max_attempts: None,
2295 };
2296
2297 let quota = Quota::with_period(Duration::from_secs(10))
2300 .unwrap()
2301 .allow_burst(NonZeroU32::new(1).unwrap());
2302
2303 let client = Arc::new(
2304 WebSocketClient::connect(
2305 config,
2306 Some(handler),
2307 None,
2308 None,
2309 vec![("test_key".to_string(), quota)],
2310 None,
2311 )
2312 .await
2313 .unwrap(),
2314 );
2315
2316 wait_until_async(
2318 || async { client.is_reconnecting() || client.is_closed() },
2319 Duration::from_secs(2),
2320 )
2321 .await;
2322
2323 client.disconnect().await;
2325 assert!(
2326 !client.is_active(),
2327 "Client should not be active after disconnect"
2328 );
2329
2330 let start = std::time::Instant::now();
2332 let result = client
2333 .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2334 .await;
2335 let elapsed = start.elapsed();
2336
2337 assert!(result.is_err(), "Send should fail when client is closed");
2339 assert!(
2340 matches!(result, Err(crate::error::SendError::Closed)),
2341 "Send should return Closed error, was: {result:?}"
2342 );
2343
2344 assert!(
2346 elapsed < Duration::from_millis(100),
2347 "Send should fail fast without rate limiting, took {elapsed:?}"
2348 );
2349
2350 server.abort();
2351 }
2352
2353 #[rstest]
2354 #[tokio::test]
2355 async fn test_connect_rejects_none_message_handler() {
2356 let config = WebSocketConfig {
2360 url: "ws://127.0.0.1:9999".to_string(),
2361 headers: vec![],
2362 heartbeat: None,
2363 heartbeat_msg: None,
2364 reconnect_timeout_ms: Some(1_000),
2365 reconnect_delay_initial_ms: Some(100),
2366 reconnect_delay_max_ms: Some(500),
2367 reconnect_backoff_factor: Some(1.5),
2368 reconnect_jitter_ms: Some(0),
2369 reconnect_max_attempts: None,
2370 };
2371
2372 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2374
2375 assert!(
2376 result.is_err(),
2377 "connect() should reject None message_handler"
2378 );
2379
2380 let err = result.unwrap_err();
2381 let err_msg = err.to_string();
2382 assert!(
2383 err_msg.contains("Handler mode requires message_handler"),
2384 "Error should mention missing message_handler, was: {err_msg}"
2385 );
2386 }
2387
2388 #[rstest]
2389 #[tokio::test]
2390 async fn test_client_without_handler_sets_stream_mode() {
2391 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2395 let port = listener.local_addr().unwrap().port();
2396
2397 let server = task::spawn(async move {
2398 if let Ok((stream, _)) = listener.accept().await
2400 && let Ok(ws) = accept_async(stream).await
2401 {
2402 drop(ws); }
2404 });
2405
2406 let config = WebSocketConfig {
2407 url: format!("ws://127.0.0.1:{port}"),
2408 headers: vec![],
2409 heartbeat: None,
2410 heartbeat_msg: None,
2411 reconnect_timeout_ms: Some(1_000),
2412 reconnect_delay_initial_ms: Some(100),
2413 reconnect_delay_max_ms: Some(500),
2414 reconnect_backoff_factor: Some(1.5),
2415 reconnect_jitter_ms: Some(0),
2416 reconnect_max_attempts: None,
2417 };
2418
2419 let inner = WebSocketClientInner::connect_url(config, None, None)
2421 .await
2422 .unwrap();
2423
2424 assert!(
2426 inner.is_stream_mode,
2427 "Client without handler should have is_stream_mode=true"
2428 );
2429
2430 server.abort();
2434 }
2435}