1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48
49use super::{
50 config::WebSocketConfig,
51 consts::{
52 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
53 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
54 },
55 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
56};
57#[cfg(feature = "turmoil")]
58use crate::net::TcpConnector;
59use crate::{
60 RECONNECTED,
61 backoff::ExponentialBackoff,
62 error::SendError,
63 logging::{log_task_aborted, log_task_started, log_task_stopped},
64 mode::ConnectionMode,
65 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
66};
67
/// Internal state of a WebSocket connection: configuration, optional handlers,
/// and the handles of the spawned read/write/heartbeat tasks.
///
/// Instances are owned by the controller task spawned in
/// `WebSocketClient::spawn_controller_task`, which drives reconnection.
pub struct WebSocketClientInner {
    /// Connection configuration (URL, headers, heartbeat and reconnect settings).
    config: WebSocketConfig,
    /// Callback for each received text/binary message; `None` in stream mode.
    message_handler: Option<MessageHandler>,
    /// Callback invoked with the payload of each received ping frame.
    ping_handler: Option<PingHandler>,
    /// Read task handle; `None` when no message handler is set (stream mode).
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Write task handle; that task owns the active `MessageWriter`.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender for `WriterCommand`s consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task handle, spawned only when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state (a `ConnectionMode` encoded as `u8`) read and
    /// written by all tasks.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound for one reconnect attempt (`reconnect_timeout_ms`, default 10s).
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// `true` when the caller consumes the reader stream directly; in this mode
    /// `reconnect` closes the connection instead of re-establishing it.
    is_stream_mode: bool,
    /// Consecutive reconnect attempts allowed before moving to CLOSED;
    /// `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    /// Consecutive reconnect attempts made so far; reset after a success.
    reconnection_attempt_count: u32,
}
105
106impl WebSocketClientInner {
107 pub async fn new_with_writer(
115 config: WebSocketConfig,
116 writer: MessageWriter,
117 ) -> Result<Self, Error> {
118 install_cryptographic_provider();
119
120 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
121
122 let read_task = None;
124
125 let backoff = ExponentialBackoff::new(
126 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
127 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
128 config.reconnect_backoff_factor.unwrap_or(1.5),
129 config.reconnect_jitter_ms.unwrap_or(100),
130 true, )
132 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
133
134 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
135 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
136
137 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
138 Some(Self::spawn_heartbeat_task(
139 connection_mode.clone(),
140 heartbeat_interval,
141 config.heartbeat_msg.clone(),
142 writer_tx.clone(),
143 ))
144 } else {
145 None
146 };
147
148 let reconnect_max_attempts = config.reconnect_max_attempts;
149 let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
150
151 Ok(Self {
152 config,
153 message_handler: None, ping_handler: None,
155 writer_tx,
156 connection_mode,
157 reconnect_timeout,
158 heartbeat_task,
159 read_task,
160 write_task,
161 backoff,
162 is_stream_mode: true,
163 reconnect_max_attempts,
164 reconnection_attempt_count: 0,
165 })
166 }
167
168 pub async fn connect_url(
176 config: WebSocketConfig,
177 message_handler: Option<MessageHandler>,
178 ping_handler: Option<PingHandler>,
179 ) -> Result<Self, Error> {
180 install_cryptographic_provider();
181
182 let is_stream_mode = message_handler.is_none();
184 let reconnect_max_attempts = config.reconnect_max_attempts;
185
186 let (writer, reader) =
187 Self::connect_with_server(&config.url, config.headers.clone()).await?;
188
189 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
190
191 let read_task = if message_handler.is_some() {
192 Some(Self::spawn_message_handler_task(
193 connection_mode.clone(),
194 reader,
195 message_handler.as_ref(),
196 ping_handler.as_ref(),
197 ))
198 } else {
199 None
200 };
201
202 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
203 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
204
205 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
207 Self::spawn_heartbeat_task(
208 connection_mode.clone(),
209 heartbeat_secs,
210 config.heartbeat_msg.clone(),
211 writer_tx.clone(),
212 )
213 });
214
215 let reconnect_timeout =
216 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
217 let backoff = ExponentialBackoff::new(
218 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
219 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
220 config.reconnect_backoff_factor.unwrap_or(1.5),
221 config.reconnect_jitter_ms.unwrap_or(100),
222 true, )
224 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
225
226 Ok(Self {
227 config,
228 message_handler,
229 ping_handler,
230 read_task,
231 write_task,
232 writer_tx,
233 heartbeat_task,
234 connection_mode,
235 reconnect_timeout,
236 backoff,
237 is_stream_mode,
239 reconnect_max_attempts,
240 reconnection_attempt_count: 0,
241 })
242 }
243
244 #[inline]
254 #[cfg(not(feature = "turmoil"))]
255 pub async fn connect_with_server(
256 url: &str,
257 headers: Vec<(String, String)>,
258 ) -> Result<(MessageWriter, MessageReader), Error> {
259 let mut request = url.into_client_request()?;
260 let req_headers = request.headers_mut();
261
262 let mut header_names: Vec<HeaderName> = Vec::new();
263 for (key, val) in headers {
264 let header_value = HeaderValue::from_str(&val)?;
265 let header_name: HeaderName = key.parse()?;
266 header_names.push(header_name.clone());
267 req_headers.insert(header_name, header_value);
268 }
269
270 connect_async_with_config(request, None, true)
271 .await
272 .map(|resp| resp.0.split())
273 }
274
/// Opens a WebSocket connection to `url` (turmoil build), sending `headers` with
/// the handshake request, and returns the split writer/reader halves.
///
/// The TCP connection is made through `crate::net::RealTcpConnector` with
/// `TCP_NODELAY` requested (failure is only logged). For `wss` URLs it is
/// wrapped in rustls TLS trusting the `webpki_roots` bundle. The default port is
/// 443 for `wss` and 80 otherwise.
///
/// NOTE(review): this uses `RealTcpConnector` even under the `turmoil` feature;
/// confirm that this is intended.
///
/// # Errors
///
/// Returns an error if the request cannot be built, a header is invalid, the URL
/// has no host, the host is not a valid TLS server name, or the TCP/TLS/WebSocket
/// handshake fails.
#[inline]
#[cfg(feature = "turmoil")]
pub async fn connect_with_server(
    url: &str,
    headers: Vec<(String, String)>,
) -> Result<(MessageWriter, MessageReader), Error> {
    use rustls::ClientConfig;
    use tokio_rustls::TlsConnector;

    let mut request = url.into_client_request()?;
    let req_headers = request.headers_mut();

    // The original also pushed a clone of every name into a Vec that was never
    // read; that dead allocation is gone.
    for (key, val) in headers {
        let header_value = HeaderValue::from_str(&val)?;
        let header_name: HeaderName = key.parse()?;
        req_headers.insert(header_name, header_value);
    }

    let uri = request.uri();
    let scheme = uri.scheme_str().unwrap_or("ws");
    let host = uri.host().ok_or_else(|| {
        Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
    })?;

    let port = uri
        .port_u16()
        .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });

    let addr = format!("{host}:{port}");

    let connector = crate::net::RealTcpConnector;
    let tcp_stream = connector.connect(&addr).await?;
    if let Err(e) = tcp_stream.set_nodelay(true) {
        tracing::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
    }

    let maybe_tls_stream = if scheme == "wss" {
        let mut root_store = rustls::RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

        let config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();

        let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
        let domain =
            rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
                Error::Io(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid DNS name: {e}"),
                ))
            })?;

        let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
        MaybeTlsStream::Rustls(tls_stream)
    } else {
        MaybeTlsStream::Plain(tcp_stream)
    };

    client_async(request, maybe_tls_stream)
        .await
        .map(|(stream, _response)| stream.split())
}
357
/// Attempts to re-establish the connection and swap it into the running tasks.
///
/// - Stream mode: auto-reconnect is not supported. The state is set to CLOSED
///   and `Ok(())` is returned.
/// - DISCONNECT state: returns `Ok(())` without doing anything.
/// - Otherwise, within `reconnect_timeout`:
///   1. connect a new socket;
///   2. hand the new writer to the write task and wait for its drain report;
///   3. sleep `GRACEFUL_SHUTDOWN_DELAY_MS`;
///   4. abort the old read task;
///   5. move RECONNECT -> ACTIVE;
///   6. spawn a new read task if a message handler is set.
///
/// Several checkpoints return `Ok(())` early when a disconnect was requested
/// mid-flight, so `Ok(())` alone does not mean the connection is ACTIVE.
///
/// # Errors
///
/// Returns an error if:
/// - the connection fails;
/// - the write task cannot receive or acknowledge the new writer;
/// - replaying buffered messages fails;
/// - the whole attempt exceeds `reconnect_timeout` (`ErrorKind::TimedOut`).
pub async fn reconnect(&mut self) -> Result<(), Error> {
    tracing::debug!("Reconnecting");

    if self.is_stream_mode {
        tracing::warn!(
            "Auto-reconnect disabled for stream-based WebSocket client; \
            stream users must manually reconnect by creating a new connection"
        );
        self.connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
        return Ok(());
    }

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        tracing::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    // The whole attempt (connect, writer swap, delay, read-task restart) is
    // bounded by `reconnect_timeout`.
    tokio::time::timeout(self.reconnect_timeout, async {
        let (new_writer, reader) =
            Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }

        // Hand the new writer to the write task. It replies `false` only if
        // replaying the messages buffered during the outage failed.
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            tracing::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => tracing::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                tracing::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                tracing::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // Stop the read task that was bound to the old connection.
        if let Some(ref read_task) = self.read_task.take()
            && !read_task.is_finished()
        {
            read_task.abort();
            log_task_aborted("read");
        }

        // Only return to ACTIVE if nobody changed the state meanwhile. If the
        // exchange fails, no new read task is spawned and `reader` is dropped.
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            tracing::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        self.read_task = if self.message_handler.is_some() {
            Some(Self::spawn_message_handler_task(
                self.connection_mode.clone(),
                reader,
                self.message_handler.as_ref(),
                self.ping_handler.as_ref(),
            ))
        } else {
            None
        };

        tracing::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
488
489 #[inline]
497 #[must_use]
498 pub fn is_alive(&self) -> bool {
499 match &self.read_task {
500 Some(read_task) => !read_task.is_finished(),
501 None => true, }
503 }
504
505 fn spawn_message_handler_task(
506 connection_state: Arc<AtomicU8>,
507 mut reader: MessageReader,
508 message_handler: Option<&MessageHandler>,
509 ping_handler: Option<&PingHandler>,
510 ) -> tokio::task::JoinHandle<()> {
511 tracing::debug!("Started message handler task 'read'");
512
513 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
514
515 let message_handler = message_handler.cloned();
517 let ping_handler = ping_handler.cloned();
518
519 tokio::task::spawn(async move {
520 loop {
521 if !ConnectionMode::from_atomic(&connection_state).is_active() {
522 break;
523 }
524
525 match tokio::time::timeout(check_interval, reader.next()).await {
526 Ok(Some(Ok(Message::Binary(data)))) => {
527 tracing::trace!("Received message <binary> {} bytes", data.len());
528 if let Some(ref handler) = message_handler {
529 handler(Message::Binary(data));
530 }
531 }
532 Ok(Some(Ok(Message::Text(data)))) => {
533 tracing::trace!("Received message: {data}");
534 if let Some(ref handler) = message_handler {
535 handler(Message::Text(data));
536 }
537 }
538 Ok(Some(Ok(Message::Ping(ping_data)))) => {
539 tracing::trace!("Received ping: {ping_data:?}");
540 if let Some(ref handler) = ping_handler {
541 handler(ping_data.to_vec());
542 }
543 }
544 Ok(Some(Ok(Message::Pong(_)))) => {
545 tracing::trace!("Received pong");
546 }
547 Ok(Some(Ok(Message::Close(_)))) => {
548 tracing::debug!("Received close message - terminating");
549 break;
550 }
551 Ok(Some(Ok(_))) => (),
552 Ok(Some(Err(e))) => {
553 tracing::error!("Received error message - terminating: {e}");
554 break;
555 }
556 Ok(None) => {
557 tracing::debug!("No message received - terminating");
558 break;
559 }
560 Err(_) => {
561 continue;
563 }
564 }
565 }
566 })
567 }
568
569 async fn drain_reconnect_buffer(
574 buffer: &mut VecDeque<Message>,
575 writer: &mut MessageWriter,
576 ) -> bool {
577 if buffer.is_empty() {
578 return false;
579 }
580
581 let initial_buffer_len = buffer.len();
582 tracing::info!(
583 "Sending {} buffered messages after reconnection",
584 initial_buffer_len
585 );
586
587 let mut send_error_occurred = false;
588
589 while let Some(buffered_msg) = buffer.front() {
590 let msg_to_send = buffered_msg.clone();
592
593 if let Err(e) = writer.send(msg_to_send).await {
594 tracing::error!(
595 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
596 buffer.len()
597 );
598 send_error_occurred = true;
599 break; }
601
602 buffer.pop_front();
604 }
605
606 if buffer.is_empty() {
607 tracing::info!(
608 "Successfully sent all {} buffered messages",
609 initial_buffer_len
610 );
611 }
612
613 send_error_occurred
614 }
615
616 fn spawn_write_task(
617 connection_state: Arc<AtomicU8>,
618 writer: MessageWriter,
619 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
620 ) -> tokio::task::JoinHandle<()> {
621 log_task_started("write");
622
623 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
625
626 tokio::task::spawn(async move {
627 let mut active_writer = writer;
628 let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
631
632 loop {
633 match ConnectionMode::from_atomic(&connection_state) {
634 ConnectionMode::Disconnect => {
635 if !reconnect_buffer.is_empty() {
637 tracing::warn!(
638 "Discarding {} buffered messages due to disconnect",
639 reconnect_buffer.len()
640 );
641 reconnect_buffer.clear();
642 }
643
644 _ = tokio::time::timeout(
647 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
648 active_writer.close(),
649 )
650 .await;
651 break;
652 }
653 ConnectionMode::Closed => {
654 if !reconnect_buffer.is_empty() {
656 tracing::warn!(
657 "Discarding {} buffered messages due to closed connection",
658 reconnect_buffer.len()
659 );
660 reconnect_buffer.clear();
661 }
662 break;
663 }
664 _ => {}
665 }
666
667 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
668 Ok(Some(msg)) => {
669 let mode = ConnectionMode::from_atomic(&connection_state);
671 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
672 break;
673 }
674
675 match msg {
676 WriterCommand::Update(new_writer, tx) => {
677 tracing::debug!("Received new writer");
678
679 tokio::time::sleep(Duration::from_millis(100)).await;
681
682 _ = tokio::time::timeout(
685 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
686 active_writer.close(),
687 )
688 .await;
689
690 active_writer = new_writer;
691 tracing::debug!("Updated writer");
692
693 let send_error = Self::drain_reconnect_buffer(
694 &mut reconnect_buffer,
695 &mut active_writer,
696 )
697 .await;
698
699 if let Err(e) = tx.send(!send_error) {
700 tracing::error!(
701 "Failed to report drain status to controller: {e:?}"
702 );
703 }
704 }
705 WriterCommand::Send(msg) if mode.is_reconnect() => {
706 tracing::debug!(
708 "Buffering message during reconnection (buffer size: {})",
709 reconnect_buffer.len() + 1
710 );
711 reconnect_buffer.push_back(msg);
712 }
713 WriterCommand::Send(msg) => {
714 if let Err(e) = active_writer.send(msg.clone()).await {
715 tracing::error!("Failed to send message: {e}");
716 tracing::warn!("Writer triggering reconnect");
717 reconnect_buffer.push_back(msg);
718 connection_state
719 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
720 }
721 }
722 }
723 }
724 Ok(None) => {
725 tracing::debug!("Writer channel closed, terminating writer task");
727 break;
728 }
729 Err(_) => {
730 continue;
732 }
733 }
734 }
735
736 _ = tokio::time::timeout(
739 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
740 active_writer.close(),
741 )
742 .await;
743
744 log_task_stopped("write");
745 })
746 }
747
748 fn spawn_heartbeat_task(
749 connection_state: Arc<AtomicU8>,
750 heartbeat_secs: u64,
751 message: Option<String>,
752 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
753 ) -> tokio::task::JoinHandle<()> {
754 log_task_started("heartbeat");
755
756 tokio::task::spawn(async move {
757 let interval = Duration::from_secs(heartbeat_secs);
758
759 loop {
760 tokio::time::sleep(interval).await;
761
762 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
763 ConnectionMode::Active => {
764 let msg = match &message {
765 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
766 None => WriterCommand::Send(Message::Ping(vec![].into())),
767 };
768
769 match writer_tx.send(msg) {
770 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
771 Err(e) => {
772 tracing::error!("Failed to send heartbeat to writer task: {e}");
773 }
774 }
775 }
776 ConnectionMode::Reconnect => continue,
777 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
778 }
779 }
780
781 log_task_stopped("heartbeat");
782 })
783 }
784}
785
786impl Drop for WebSocketClientInner {
787 fn drop(&mut self) {
788 self.clean_drop();
790 }
791}
792
793impl CleanDrop for WebSocketClientInner {
795 fn clean_drop(&mut self) {
796 if let Some(ref read_task) = self.read_task.take()
797 && !read_task.is_finished()
798 {
799 read_task.abort();
800 log_task_aborted("read");
801 }
802
803 if !self.write_task.is_finished() {
804 self.write_task.abort();
805 log_task_aborted("write");
806 }
807
808 if let Some(ref handle) = self.heartbeat_task.take()
809 && !handle.is_finished()
810 {
811 handle.abort();
812 log_task_aborted("heartbeat");
813 }
814
815 self.message_handler = None;
817 self.ping_handler = None;
818 }
819}
820
821impl Debug for WebSocketClientInner {
822 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823 f.debug_struct("WebSocketClientInner")
824 .field("config", &self.config)
825 .field(
826 "connection_mode",
827 &ConnectionMode::from_atomic(&self.connection_mode),
828 )
829 .field("reconnect_timeout", &self.reconnect_timeout)
830 .field("is_stream_mode", &self.is_stream_mode)
831 .finish()
832 }
833}
834
/// Public WebSocket client handle.
///
/// Connection management (reconnects, shutdown) runs in a background controller
/// task. This handle exposes the shared connection state, a rate limiter and a
/// sender for outgoing messages.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Background task that monitors the connection and performs reconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection state (a `ConnectionMode` encoded as `u8`).
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Maximum time a send waits for the connection to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Keyed rate limiter awaited before text/binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Sender for commands consumed by the inner client's write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
850
851impl Debug for WebSocketClient {
852 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
853 f.debug_struct(stringify!(WebSocketClient)).finish()
854 }
855}
856
857impl WebSocketClient {
858 #[allow(clippy::too_many_arguments)]
874 pub async fn connect_stream(
875 config: WebSocketConfig,
876 keyed_quotas: Vec<(String, Quota)>,
877 default_quota: Option<Quota>,
878 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
879 ) -> Result<(MessageReader, Self), Error> {
880 install_cryptographic_provider();
881
882 let (writer, reader) =
884 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
885
886 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
888
889 let connection_mode = inner.connection_mode.clone();
890 let reconnect_timeout = inner.reconnect_timeout;
891 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
892 let writer_tx = inner.writer_tx.clone();
893
894 let controller_task =
895 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
896
897 Ok((
898 reader,
899 Self {
900 controller_task,
901 connection_mode,
902 reconnect_timeout,
903 rate_limiter,
904 writer_tx,
905 },
906 ))
907 }
908
909 pub async fn connect(
927 config: WebSocketConfig,
928 message_handler: Option<MessageHandler>,
929 ping_handler: Option<PingHandler>,
930 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
931 keyed_quotas: Vec<(String, Quota)>,
932 default_quota: Option<Quota>,
933 ) -> Result<Self, Error> {
934 if message_handler.is_none() {
936 return Err(Error::Io(std::io::Error::new(
937 std::io::ErrorKind::InvalidInput,
938 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
939 )));
940 }
941
942 tracing::debug!("Connecting");
943 let inner =
944 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
945 let connection_mode = inner.connection_mode.clone();
946 let writer_tx = inner.writer_tx.clone();
947 let reconnect_timeout = inner.reconnect_timeout;
948
949 let controller_task =
950 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
951
952 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
953
954 Ok(Self {
955 controller_task,
956 connection_mode,
957 reconnect_timeout,
958 rate_limiter,
959 writer_tx,
960 })
961 }
962
963 #[must_use]
965 pub fn connection_mode(&self) -> ConnectionMode {
966 ConnectionMode::from_atomic(&self.connection_mode)
967 }
968
969 #[must_use]
974 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
975 Arc::clone(&self.connection_mode)
976 }
977
978 #[inline]
983 #[must_use]
984 pub fn is_active(&self) -> bool {
985 self.connection_mode().is_active()
986 }
987
988 #[must_use]
990 pub fn is_disconnected(&self) -> bool {
991 self.controller_task.is_finished()
992 }
993
994 #[inline]
999 #[must_use]
1000 pub fn is_reconnecting(&self) -> bool {
1001 self.connection_mode().is_reconnect()
1002 }
1003
1004 #[inline]
1008 #[must_use]
1009 pub fn is_disconnecting(&self) -> bool {
1010 self.connection_mode().is_disconnect()
1011 }
1012
1013 #[inline]
1019 #[must_use]
1020 pub fn is_closed(&self) -> bool {
1021 self.connection_mode().is_closed()
1022 }
1023
1024 async fn wait_for_active(&self) -> Result<(), SendError> {
1028 if self.is_closed() {
1029 return Err(SendError::Closed);
1030 }
1031
1032 let timeout = self.reconnect_timeout;
1033 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1034
1035 if !self.is_active() {
1036 tracing::debug!("Waiting for client to become ACTIVE before sending...");
1037
1038 let inner = tokio::time::timeout(timeout, async {
1039 loop {
1040 if self.is_active() {
1041 return Ok(());
1042 }
1043 if matches!(
1044 self.connection_mode(),
1045 ConnectionMode::Disconnect | ConnectionMode::Closed
1046 ) {
1047 return Err(());
1048 }
1049 tokio::time::sleep(check_interval).await;
1050 }
1051 })
1052 .await
1053 .map_err(|_| SendError::Timeout)?;
1054 inner.map_err(|()| SendError::Closed)?;
1055 }
1056
1057 Ok(())
1058 }
1059
1060 pub async fn disconnect(&self) {
1065 tracing::debug!("Disconnecting");
1066 self.connection_mode
1067 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1068
1069 if let Ok(()) =
1070 tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1071 while !self.is_disconnected() {
1072 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
1073 .await;
1074 }
1075
1076 if !self.controller_task.is_finished() {
1077 self.controller_task.abort();
1078 log_task_aborted("controller");
1079 }
1080 })
1081 .await
1082 {
1083 tracing::debug!("Controller task finished");
1084 } else {
1085 tracing::error!("Timeout waiting for controller task to finish");
1086 if !self.controller_task.is_finished() {
1087 self.controller_task.abort();
1088 log_task_aborted("controller");
1089 }
1090 }
1091 }
1092
1093 #[allow(unused_variables)]
1099 pub async fn send_text(
1100 &self,
1101 data: String,
1102 keys: Option<Vec<String>>,
1103 ) -> Result<(), SendError> {
1104 if self.is_closed() || self.is_disconnecting() {
1106 return Err(SendError::Closed);
1107 }
1108
1109 self.rate_limiter.await_keys_ready(keys).await;
1110 self.wait_for_active().await?;
1111
1112 tracing::trace!("Sending text: {data:?}");
1113
1114 let msg = Message::Text(data.into());
1115 self.writer_tx
1116 .send(WriterCommand::Send(msg))
1117 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1118 }
1119
1120 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1126 self.wait_for_active().await?;
1127
1128 tracing::trace!("Sending pong frame ({} bytes)", data.len());
1129
1130 let msg = Message::Pong(data.into());
1131 self.writer_tx
1132 .send(WriterCommand::Send(msg))
1133 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1134 }
1135
1136 #[allow(unused_variables)]
1142 pub async fn send_bytes(
1143 &self,
1144 data: Vec<u8>,
1145 keys: Option<Vec<String>>,
1146 ) -> Result<(), SendError> {
1147 if self.is_closed() || self.is_disconnecting() {
1149 return Err(SendError::Closed);
1150 }
1151
1152 self.rate_limiter.await_keys_ready(keys).await;
1153 self.wait_for_active().await?;
1154
1155 tracing::trace!("Sending bytes: {data:?}");
1156
1157 let msg = Message::Binary(data.into());
1158 self.writer_tx
1159 .send(WriterCommand::Send(msg))
1160 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1161 }
1162
1163 pub async fn send_close_message(&self) -> Result<(), SendError> {
1169 self.wait_for_active().await?;
1170
1171 let msg = Message::Close(None);
1172 self.writer_tx
1173 .send(WriterCommand::Send(msg))
1174 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1175 }
1176
/// Spawns the controller task, which owns `inner` and manages its lifecycle.
///
/// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` it inspects the shared state:
/// - DISCONNECT: waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and
///   heartbeat tasks (all bounded by `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`), then stops.
/// - CLOSED: stops.
/// - ACTIVE with a finished read task: moves ACTIVE -> RECONNECT.
/// - RECONNECT: calls `inner.reconnect()`, giving up (-> CLOSED) once
///   `reconnect_max_attempts` consecutive attempts have been made.
///   - On success, the backoff and attempt counter are reset. If the state is
///     then ACTIVE, the message handler receives `RECONNECTED` and
///     `post_reconnection` is called.
///   - On failure, it sleeps for the next backoff duration.
///
/// NOTE: the backoff sleep is not interrupted by a disconnect request;
/// `disconnect()` aborts this task if it does not finish within
/// `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`.
///
/// On exit the state is stored as CLOSED. `inner` is then dropped, and its
/// `Drop` impl aborts any remaining tasks.
fn spawn_controller_task(
    mut inner: WebSocketClientInner,
    connection_mode: Arc<AtomicU8>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        log_task_started("controller");

        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        loop {
            tokio::time::sleep(check_interval).await;
            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                tracing::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    // Give the write task time to observe DISCONNECT first.
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if let Some(task) = &inner.read_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                tracing::debug!("Closed");
                break;
            }

            if mode.is_closed() {
                tracing::debug!("Connection closed");
                break;
            }

            // A finished read task means the connection dropped. The CAS avoids
            // clobbering a state another party changed concurrently.
            if mode.is_active() && !inner.is_alive() {
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnection_attempt_count >= max_attempts
                {
                    tracing::error!(
                        "Max reconnection attempts ({}) exceeded, transitioning to CLOSED",
                        max_attempts
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    break;
                }

                inner.reconnection_attempt_count += 1;
                tracing::debug!(
                    "Reconnection attempt {} of {}",
                    inner.reconnection_attempt_count,
                    inner
                        .reconnect_max_attempts
                        .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                );

                match inner.reconnect().await {
                    Ok(()) => {
                        inner.backoff.reset();
                        inner.reconnection_attempt_count = 0;
                        // `reconnect` may return Ok after aborting for a
                        // disconnect, so re-check before notifying anyone.
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = inner.message_handler {
                                let reconnected_msg =
                                    Message::Text(RECONNECTED.to_string().into());
                                handler(reconnected_msg);
                                tracing::debug!("Sent reconnected message to handler");
                            }

                            if let Some(ref callback) = post_reconnection {
                                callback();
                                tracing::debug!("Called `post_reconnection` handler");
                            }

                            tracing::debug!("Reconnected successfully");
                        } else {
                            tracing::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Err(e) => {
                        let duration = inner.backoff.next_duration();
                        tracing::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnection_attempt_count
                        );
                        if !duration.is_zero() {
                            tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                        }
                        tokio::time::sleep(duration).await;
                    }
                }
            }
        }
        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
1313}
1314
1315impl Drop for WebSocketClient {
1317 fn drop(&mut self) {
1318 if !self.controller_task.is_finished() {
1319 self.controller_task.abort();
1320 log_task_aborted("controller");
1321 }
1322 }
1323}
1324
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
// Loopback network tests; only compiled on Linux.
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Name of the header the test client sends and `TestCallback` checks.
    const TEST_HEADER_KEY: &str = "test";
    /// Value expected for [`TEST_HEADER_KEY`].
    const TEST_HEADER_VALUE: &str = "test";

    /// Echo server bound to an ephemeral loopback port.
    ///
    /// The accept-loop task is aborted when the server is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking is intentional: it tears down the server's accept task, so the
        // client's handshake does not complete and the connecting test fails.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .expect("client should send the configured test header");
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds `127.0.0.1:0` and spawns an accept loop that checks the test
        /// header on every handshake and serves each connection in its own task.
        ///
        /// Per connection: the text frame `close-now` makes the server close the
        /// socket, other text/binary frames are echoed back, a client close frame
        /// is answered with a close, and any other frame type is ignored.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key: TEST_HEADER_KEY.to_string(),
                value: HeaderValue::from_static(TEST_HEADER_VALUE),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Baseline config for a client of [`TestServer`] listening on `port`.
    ///
    /// It sends the test header, has the heartbeat disabled, and leaves every
    /// reconnect option unset (`None`).
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![(TEST_HEADER_KEY.into(), TEST_HEADER_VALUE.into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
        }
    }

    /// Connects a handler-mode client (no-op message handler, no rate limits).
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be established.
    async fn connect_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    /// Connects a client to the [`TestServer`] on `port` using [`test_config`].
    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        // Enable the heartbeat so the wait below actually exercises it; the
        // interval is assumed to be in seconds (confirm against WebSocketConfig).
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = connect_client(config).await;

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        assert!(
            !client.is_disconnected(),
            "Client should stay connected while heartbeats are being sent"
        );

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it right away, so nothing is
        // listening there; a fixed port could collide with a service on the host.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to close its side of the connection.
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        // The 2-per-second quota is registered under the key `default`; this test
        // only checks that sends under it succeed.
        let client = WebSocketClient::connect(
            test_config(server.port),
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1590
1591#[cfg(test)]
1592#[cfg(not(feature = "turmoil"))]
1593mod rust_tests {
1594 use futures_util::StreamExt;
1595 use rstest::rstest;
1596 use tokio::{
1597 net::TcpListener,
1598 task,
1599 time::{Duration, sleep},
1600 };
1601 use tokio_tungstenite::accept_async;
1602
1603 use super::*;
1604 use crate::websocket::types::channel_message_handler;
1605
1606 #[rstest]
1607 #[tokio::test]
1608 async fn test_reconnect_then_disconnect() {
1609 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1611 let port = listener.local_addr().unwrap().port();
1612
1613 let server = task::spawn(async move {
1615 let (stream, _) = listener.accept().await.unwrap();
1616 let ws = accept_async(stream).await.unwrap();
1617 drop(ws);
1618 sleep(Duration::from_secs(1)).await;
1620 });
1621
1622 let (handler, _rx) = channel_message_handler();
1624
1625 let config = WebSocketConfig {
1627 url: format!("ws://127.0.0.1:{port}"),
1628 headers: vec![],
1629 heartbeat: None,
1630 heartbeat_msg: None,
1631 reconnect_timeout_ms: Some(1_000),
1632 reconnect_delay_initial_ms: Some(50),
1633 reconnect_delay_max_ms: Some(100),
1634 reconnect_backoff_factor: Some(1.0),
1635 reconnect_jitter_ms: Some(0),
1636 reconnect_max_attempts: None,
1637 };
1638
1639 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1641 .await
1642 .unwrap();
1643
1644 sleep(Duration::from_millis(100)).await;
1646 client.disconnect().await;
1648 assert!(client.is_disconnected());
1649 server.abort();
1650 }
1651
1652 #[rstest]
1653 #[tokio::test]
1654 async fn test_reconnect_state_flips_when_reader_stops() {
1655 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1657 let port = listener.local_addr().unwrap().port();
1658
1659 let server = task::spawn(async move {
1660 if let Ok((stream, _)) = listener.accept().await
1661 && let Ok(ws) = accept_async(stream).await
1662 {
1663 drop(ws);
1664 }
1665 sleep(Duration::from_millis(50)).await;
1666 });
1667
1668 let (handler, _rx) = channel_message_handler();
1669
1670 let config = WebSocketConfig {
1671 url: format!("ws://127.0.0.1:{port}"),
1672 headers: vec![],
1673 heartbeat: None,
1674 heartbeat_msg: None,
1675 reconnect_timeout_ms: Some(1_000),
1676 reconnect_delay_initial_ms: Some(50),
1677 reconnect_delay_max_ms: Some(100),
1678 reconnect_backoff_factor: Some(1.0),
1679 reconnect_jitter_ms: Some(0),
1680 reconnect_max_attempts: None,
1681 };
1682
1683 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1684 .await
1685 .unwrap();
1686
1687 tokio::time::timeout(Duration::from_secs(2), async {
1688 loop {
1689 if client.is_reconnecting() {
1690 break;
1691 }
1692 tokio::time::sleep(Duration::from_millis(10)).await;
1693 }
1694 })
1695 .await
1696 .expect("client did not enter RECONNECT state");
1697
1698 client.disconnect().await;
1699 server.abort();
1700 }
1701
1702 #[rstest]
1703 #[tokio::test]
1704 async fn test_stream_mode_disables_auto_reconnect() {
1705 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1708 let port = listener.local_addr().unwrap().port();
1709
1710 let server = task::spawn(async move {
1711 if let Ok((stream, _)) = listener.accept().await
1712 && let Ok(_ws) = accept_async(stream).await
1713 {
1714 sleep(Duration::from_millis(100)).await;
1716 }
1717 });
1718
1719 let config = WebSocketConfig {
1720 url: format!("ws://127.0.0.1:{port}"),
1721 headers: vec![],
1722 heartbeat: None,
1723 heartbeat_msg: None,
1724 reconnect_timeout_ms: Some(1_000),
1725 reconnect_delay_initial_ms: Some(50),
1726 reconnect_delay_max_ms: Some(100),
1727 reconnect_backoff_factor: Some(1.0),
1728 reconnect_jitter_ms: Some(0),
1729 reconnect_max_attempts: None,
1730 };
1731
1732 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1733 .await
1734 .unwrap();
1735
1736 server.abort();
1744 }
1745
1746 #[rstest]
1747 #[tokio::test]
1748 async fn test_message_handler_mode_allows_auto_reconnect() {
1749 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1751 let port = listener.local_addr().unwrap().port();
1752
1753 let server = task::spawn(async move {
1754 if let Ok((stream, _)) = listener.accept().await
1756 && let Ok(ws) = accept_async(stream).await
1757 {
1758 drop(ws);
1759 }
1760 sleep(Duration::from_millis(50)).await;
1761 });
1762
1763 let (handler, _rx) = channel_message_handler();
1764
1765 let config = WebSocketConfig {
1766 url: format!("ws://127.0.0.1:{port}"),
1767 headers: vec![],
1768 heartbeat: None,
1769 heartbeat_msg: None,
1770 reconnect_timeout_ms: Some(1_000),
1771 reconnect_delay_initial_ms: Some(50),
1772 reconnect_delay_max_ms: Some(100),
1773 reconnect_backoff_factor: Some(1.0),
1774 reconnect_jitter_ms: Some(0),
1775 reconnect_max_attempts: None,
1776 };
1777
1778 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1779 .await
1780 .unwrap();
1781
1782 tokio::time::timeout(Duration::from_secs(2), async {
1784 loop {
1785 if client.is_reconnecting() || client.is_closed() {
1786 break;
1787 }
1788 tokio::time::sleep(Duration::from_millis(10)).await;
1789 }
1790 })
1791 .await
1792 .expect("client should attempt reconnection or close");
1793
1794 assert!(
1797 client.is_reconnecting() || client.is_closed(),
1798 "Client with message handler should attempt reconnection"
1799 );
1800
1801 client.disconnect().await;
1802 server.abort();
1803 }
1804
1805 #[rstest]
1806 #[tokio::test]
1807 async fn test_handler_mode_reconnect_with_new_connection() {
1808 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1810 let port = listener.local_addr().unwrap().port();
1811
1812 let server = task::spawn(async move {
1813 if let Ok((stream, _)) = listener.accept().await
1815 && let Ok(ws) = accept_async(stream).await
1816 {
1817 drop(ws);
1818 }
1819
1820 sleep(Duration::from_millis(100)).await;
1822
1823 if let Ok((stream, _)) = listener.accept().await
1825 && let Ok(mut ws) = accept_async(stream).await
1826 {
1827 use futures_util::SinkExt;
1828 let _ = ws
1829 .send(Message::Text("reconnected".to_string().into()))
1830 .await;
1831 sleep(Duration::from_secs(1)).await;
1832 }
1833 });
1834
1835 let (handler, mut rx) = channel_message_handler();
1836
1837 let config = WebSocketConfig {
1838 url: format!("ws://127.0.0.1:{port}"),
1839 headers: vec![],
1840 heartbeat: None,
1841 heartbeat_msg: None,
1842 reconnect_timeout_ms: Some(2_000),
1843 reconnect_delay_initial_ms: Some(50),
1844 reconnect_delay_max_ms: Some(200),
1845 reconnect_backoff_factor: Some(1.5),
1846 reconnect_jitter_ms: Some(10),
1847 reconnect_max_attempts: None,
1848 };
1849
1850 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1851 .await
1852 .unwrap();
1853
1854 let result = tokio::time::timeout(Duration::from_secs(5), async {
1856 loop {
1857 if let Ok(msg) = rx.try_recv()
1858 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1859 {
1860 return true;
1861 }
1862 tokio::time::sleep(Duration::from_millis(10)).await;
1863 }
1864 })
1865 .await;
1866
1867 assert!(
1868 result.is_ok(),
1869 "Should receive message after reconnection within timeout"
1870 );
1871
1872 client.disconnect().await;
1873 server.abort();
1874 }
1875
1876 #[rstest]
1877 #[tokio::test]
1878 async fn test_stream_mode_no_auto_reconnect() {
1879 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1882 let port = listener.local_addr().unwrap().port();
1883
1884 let server = task::spawn(async move {
1885 if let Ok((stream, _)) = listener.accept().await
1887 && let Ok(mut ws) = accept_async(stream).await
1888 {
1889 use futures_util::SinkExt;
1890 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1891 sleep(Duration::from_millis(50)).await;
1892 }
1894 });
1895
1896 let config = WebSocketConfig {
1897 url: format!("ws://127.0.0.1:{port}"),
1898 headers: vec![],
1899 heartbeat: None,
1900 heartbeat_msg: None,
1901 reconnect_timeout_ms: Some(1_000),
1902 reconnect_delay_initial_ms: Some(50),
1903 reconnect_delay_max_ms: Some(100),
1904 reconnect_backoff_factor: Some(1.0),
1905 reconnect_jitter_ms: Some(0),
1906 reconnect_max_attempts: None,
1907 };
1908
1909 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1910 .await
1911 .unwrap();
1912
1913 assert!(client.is_active(), "Client should start as active");
1915
1916 let msg = reader.next().await;
1918 assert!(
1919 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1920 "Should receive initial message"
1921 );
1922
1923 while let Some(msg) = reader.next().await {
1925 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1926 break;
1927 }
1928 }
1929
1930 sleep(Duration::from_millis(200)).await;
1933
1934 assert!(
1937 client.is_active() || client.is_closed(),
1938 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1939 );
1940 assert!(
1941 !client.is_reconnecting(),
1942 "Stream mode client should never attempt reconnection"
1943 );
1944
1945 client.disconnect().await;
1946 server.abort();
1947 }
1948
1949 #[rstest]
1950 #[tokio::test]
1951 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1952 use nautilus_common::testing::wait_until_async;
1955
1956 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1957 let port = listener.local_addr().unwrap().port();
1958
1959 let server = task::spawn(async move {
1960 if let Ok((stream, _)) = listener.accept().await
1962 && let Ok(ws) = accept_async(stream).await
1963 {
1964 drop(ws);
1965 }
1966 sleep(Duration::from_secs(60)).await;
1968 });
1969
1970 let (handler, _rx) = channel_message_handler();
1971
1972 let config = WebSocketConfig {
1974 url: format!("ws://127.0.0.1:{port}"),
1975 headers: vec![],
1976 heartbeat: None,
1977 heartbeat_msg: None,
1978 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1980 reconnect_delay_max_ms: Some(100),
1981 reconnect_backoff_factor: Some(1.0),
1982 reconnect_jitter_ms: Some(0),
1983 reconnect_max_attempts: None,
1984 };
1985
1986 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1987 .await
1988 .unwrap();
1989
1990 wait_until_async(
1992 || async { client.is_reconnecting() },
1993 Duration::from_secs(3),
1994 )
1995 .await;
1996
1997 let start = std::time::Instant::now();
1999 let send_result = client.send_text("test".to_string(), None).await;
2000 let elapsed = start.elapsed();
2001
2002 assert!(
2003 send_result.is_err(),
2004 "Send should fail when client stuck in RECONNECT"
2005 );
2006 assert!(
2007 matches!(send_result, Err(crate::error::SendError::Timeout)),
2008 "Send should return Timeout error, was: {send_result:?}"
2009 );
2010 assert!(
2013 elapsed >= Duration::from_millis(1800),
2014 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2015 );
2016
2017 client.disconnect().await;
2018 server.abort();
2019 }
2020
2021 #[rstest]
2022 #[tokio::test]
2023 async fn test_send_waits_during_reconnection() {
2024 use nautilus_common::testing::wait_until_async;
2026
2027 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2028 let port = listener.local_addr().unwrap().port();
2029
2030 let server = task::spawn(async move {
2031 if let Ok((stream, _)) = listener.accept().await
2033 && let Ok(ws) = accept_async(stream).await
2034 {
2035 drop(ws);
2036 }
2037
2038 sleep(Duration::from_millis(500)).await;
2040
2041 if let Ok((stream, _)) = listener.accept().await
2043 && let Ok(mut ws) = accept_async(stream).await
2044 {
2045 while let Some(Ok(msg)) = ws.next().await {
2047 if ws.send(msg).await.is_err() {
2048 break;
2049 }
2050 }
2051 }
2052 });
2053
2054 let (handler, _rx) = channel_message_handler();
2055
2056 let config = WebSocketConfig {
2057 url: format!("ws://127.0.0.1:{port}"),
2058 headers: vec![],
2059 heartbeat: None,
2060 heartbeat_msg: None,
2061 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2063 reconnect_delay_max_ms: Some(200),
2064 reconnect_backoff_factor: Some(1.0),
2065 reconnect_jitter_ms: Some(0),
2066 reconnect_max_attempts: None,
2067 };
2068
2069 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2070 .await
2071 .unwrap();
2072
2073 wait_until_async(
2075 || async { client.is_reconnecting() },
2076 Duration::from_secs(2),
2077 )
2078 .await;
2079
2080 let send_result = tokio::time::timeout(
2082 Duration::from_secs(3),
2083 client.send_text("test_message".to_string(), None),
2084 )
2085 .await;
2086
2087 assert!(
2088 send_result.is_ok() && send_result.unwrap().is_ok(),
2089 "Send should succeed after waiting for reconnection"
2090 );
2091
2092 client.disconnect().await;
2093 server.abort();
2094 }
2095
2096 #[rstest]
2097 #[tokio::test]
2098 async fn test_rate_limiter_before_active_wait() {
2099 use std::{num::NonZeroU32, sync::Arc};
2104
2105 use nautilus_common::testing::wait_until_async;
2106
2107 use crate::ratelimiter::quota::Quota;
2108
2109 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2110 let port = listener.local_addr().unwrap().port();
2111
2112 let server = task::spawn(async move {
2113 if let Ok((stream, _)) = listener.accept().await
2115 && let Ok(mut ws) = accept_async(stream).await
2116 {
2117 if let Some(Ok(_)) = ws.next().await {
2119 drop(ws);
2120 }
2121 }
2122
2123 sleep(Duration::from_millis(500)).await;
2125
2126 if let Ok((stream, _)) = listener.accept().await
2128 && let Ok(mut ws) = accept_async(stream).await
2129 {
2130 while let Some(Ok(msg)) = ws.next().await {
2131 if ws.send(msg).await.is_err() {
2132 break;
2133 }
2134 }
2135 }
2136 });
2137
2138 let (handler, _rx) = channel_message_handler();
2139
2140 let config = WebSocketConfig {
2141 url: format!("ws://127.0.0.1:{port}"),
2142 headers: vec![],
2143 heartbeat: None,
2144 heartbeat_msg: None,
2145 reconnect_timeout_ms: Some(5_000),
2146 reconnect_delay_initial_ms: Some(50),
2147 reconnect_delay_max_ms: Some(100),
2148 reconnect_backoff_factor: Some(1.0),
2149 reconnect_jitter_ms: Some(0),
2150 reconnect_max_attempts: None,
2151 };
2152
2153 let quota =
2155 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2156
2157 let client = Arc::new(
2158 WebSocketClient::connect(
2159 config,
2160 Some(handler),
2161 None,
2162 None,
2163 vec![("test_key".to_string(), quota)],
2164 None,
2165 )
2166 .await
2167 .unwrap(),
2168 );
2169
2170 client
2172 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2173 .await
2174 .unwrap();
2175
2176 wait_until_async(
2178 || async { client.is_reconnecting() },
2179 Duration::from_secs(2),
2180 )
2181 .await;
2182
2183 let start = std::time::Instant::now();
2185 let send_result = client
2186 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2187 .await;
2188 let elapsed = start.elapsed();
2189
2190 assert!(
2192 send_result.is_ok(),
2193 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2194 );
2195 assert!(
2199 elapsed >= Duration::from_millis(850),
2200 "Should wait for rate limit (~1s), waited {elapsed:?}"
2201 );
2202
2203 client.disconnect().await;
2204 server.abort();
2205 }
2206
2207 #[rstest]
2208 #[tokio::test]
2209 async fn test_disconnect_during_reconnect_exits_cleanly() {
2210 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2213 let port = listener.local_addr().unwrap().port();
2214
2215 let server = task::spawn(async move {
2216 if let Ok((stream, _)) = listener.accept().await
2218 && let Ok(ws) = accept_async(stream).await
2219 {
2220 drop(ws);
2221 }
2222 sleep(Duration::from_secs(60)).await;
2224 });
2225
2226 let (handler, _rx) = channel_message_handler();
2227
2228 let config = WebSocketConfig {
2229 url: format!("ws://127.0.0.1:{port}"),
2230 headers: vec![],
2231 heartbeat: None,
2232 heartbeat_msg: None,
2233 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2235 reconnect_delay_max_ms: Some(200),
2236 reconnect_backoff_factor: Some(1.0),
2237 reconnect_jitter_ms: Some(0),
2238 reconnect_max_attempts: None,
2239 };
2240
2241 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2242 .await
2243 .unwrap();
2244
2245 tokio::time::timeout(Duration::from_secs(2), async {
2247 while !client.is_reconnecting() {
2248 sleep(Duration::from_millis(10)).await;
2249 }
2250 })
2251 .await
2252 .expect("Client should enter RECONNECT state");
2253
2254 client.disconnect().await;
2256
2257 assert!(
2259 client.is_disconnected(),
2260 "Client should be cleanly disconnected"
2261 );
2262
2263 server.abort();
2264 }
2265
2266 #[rstest]
2267 #[tokio::test]
2268 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2269 use std::{num::NonZeroU32, sync::Arc};
2272
2273 use nautilus_common::testing::wait_until_async;
2274
2275 use crate::ratelimiter::quota::Quota;
2276
2277 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2278 let port = listener.local_addr().unwrap().port();
2279
2280 let server = task::spawn(async move {
2281 if let Ok((stream, _)) = listener.accept().await
2283 && let Ok(ws) = accept_async(stream).await
2284 {
2285 drop(ws);
2286 }
2287 sleep(Duration::from_secs(60)).await;
2288 });
2289
2290 let (handler, _rx) = channel_message_handler();
2291
2292 let config = WebSocketConfig {
2293 url: format!("ws://127.0.0.1:{port}"),
2294 headers: vec![],
2295 heartbeat: None,
2296 heartbeat_msg: None,
2297 reconnect_timeout_ms: Some(5_000),
2298 reconnect_delay_initial_ms: Some(50),
2299 reconnect_delay_max_ms: Some(100),
2300 reconnect_backoff_factor: Some(1.0),
2301 reconnect_jitter_ms: Some(0),
2302 reconnect_max_attempts: None,
2303 };
2304
2305 let quota = Quota::with_period(Duration::from_secs(10))
2308 .unwrap()
2309 .allow_burst(NonZeroU32::new(1).unwrap());
2310
2311 let client = Arc::new(
2312 WebSocketClient::connect(
2313 config,
2314 Some(handler),
2315 None,
2316 None,
2317 vec![("test_key".to_string(), quota)],
2318 None,
2319 )
2320 .await
2321 .unwrap(),
2322 );
2323
2324 wait_until_async(
2326 || async { client.is_reconnecting() || client.is_closed() },
2327 Duration::from_secs(2),
2328 )
2329 .await;
2330
2331 client.disconnect().await;
2333 assert!(
2334 !client.is_active(),
2335 "Client should not be active after disconnect"
2336 );
2337
2338 let start = std::time::Instant::now();
2340 let result = client
2341 .send_text("test".to_string(), Some(vec!["test_key".to_string()]))
2342 .await;
2343 let elapsed = start.elapsed();
2344
2345 assert!(result.is_err(), "Send should fail when client is closed");
2347 assert!(
2348 matches!(result, Err(crate::error::SendError::Closed)),
2349 "Send should return Closed error, was: {result:?}"
2350 );
2351
2352 assert!(
2354 elapsed < Duration::from_millis(100),
2355 "Send should fail fast without rate limiting, took {elapsed:?}"
2356 );
2357
2358 server.abort();
2359 }
2360
2361 #[rstest]
2362 #[tokio::test]
2363 async fn test_connect_rejects_none_message_handler() {
2364 let config = WebSocketConfig {
2368 url: "ws://127.0.0.1:9999".to_string(),
2369 headers: vec![],
2370 heartbeat: None,
2371 heartbeat_msg: None,
2372 reconnect_timeout_ms: Some(1_000),
2373 reconnect_delay_initial_ms: Some(100),
2374 reconnect_delay_max_ms: Some(500),
2375 reconnect_backoff_factor: Some(1.5),
2376 reconnect_jitter_ms: Some(0),
2377 reconnect_max_attempts: None,
2378 };
2379
2380 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2382
2383 assert!(
2384 result.is_err(),
2385 "connect() should reject None message_handler"
2386 );
2387
2388 let err = result.unwrap_err();
2389 let err_msg = err.to_string();
2390 assert!(
2391 err_msg.contains("Handler mode requires message_handler"),
2392 "Error should mention missing message_handler, was: {err_msg}"
2393 );
2394 }
2395
2396 #[rstest]
2397 #[tokio::test]
2398 async fn test_client_without_handler_sets_stream_mode() {
2399 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2403 let port = listener.local_addr().unwrap().port();
2404
2405 let server = task::spawn(async move {
2406 if let Ok((stream, _)) = listener.accept().await
2408 && let Ok(ws) = accept_async(stream).await
2409 {
2410 drop(ws); }
2412 });
2413
2414 let config = WebSocketConfig {
2415 url: format!("ws://127.0.0.1:{port}"),
2416 headers: vec![],
2417 heartbeat: None,
2418 heartbeat_msg: None,
2419 reconnect_timeout_ms: Some(1_000),
2420 reconnect_delay_initial_ms: Some(100),
2421 reconnect_delay_max_ms: Some(500),
2422 reconnect_backoff_factor: Some(1.5),
2423 reconnect_jitter_ms: Some(0),
2424 reconnect_max_attempts: None,
2425 };
2426
2427 let inner = WebSocketClientInner::connect_url(config, None, None)
2429 .await
2430 .unwrap();
2431
2432 assert!(
2434 inner.is_stream_mode,
2435 "Client without handler should have is_stream_mode=true"
2436 );
2437
2438 server.abort();
2442 }
2443}