1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_core::CleanDrop;
46use nautilus_cryptography::providers::install_cryptographic_provider;
47use tokio::net::TcpStream;
48use tokio_tungstenite::{
49 MaybeTlsStream, WebSocketStream, connect_async,
50 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
51};
52
53use crate::{
54 RECONNECTED,
55 backoff::ExponentialBackoff,
56 error::SendError,
57 logging::{log_task_aborted, log_task_started, log_task_stopped},
58 mode::ConnectionMode,
59 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
60};
61
/// Text payload `"ping"`.
// NOTE(review): not referenced in the visible part of this file; presumably used by callers
// that implement text-based keepalives — confirm.
pub const TEXT_PING: &str = "ping";
/// Text payload `"pong"`.
// NOTE(review): not referenced in the visible part of this file — confirm usage with callers.
pub const TEXT_PONG: &str = "pong";

/// Interval (milliseconds) at which the background tasks re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Delay (milliseconds) applied before tearing down tasks or swapping writers.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) for closing a writer and for waiting on the controller task.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Polling interval (milliseconds) while a send operation waits for the client to become active.
const SEND_OPERATION_CHECK_INTERVAL_MS: u64 = 1;

/// Write half of a split WebSocket stream.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split WebSocket stream.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// Shared callback invoked with each message delivered by the read task.
pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;

/// Shared callback invoked with the payload of each received ping frame.
pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
87
88#[must_use]
92pub fn channel_message_handler() -> (
93 MessageHandler,
94 tokio::sync::mpsc::UnboundedReceiver<Message>,
95) {
96 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
97 let handler = Arc::new(move |msg: Message| {
98 if let Err(e) = tx.send(msg) {
99 tracing::debug!("Failed to send message to channel: {e}");
100 }
101 });
102 (handler, rx)
103}
104
/// Configuration for a WebSocket client connection.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// Header `(name, value)` pairs inserted into the handshake request.
    pub headers: Vec<(String, String)>,
    /// Callback for received text/binary messages. When `None`, no read task is spawned
    /// by `connect_url` or on reconnect.
    pub message_handler: Option<MessageHandler>,
    /// Heartbeat interval in seconds; `None` disables the heartbeat task.
    pub heartbeat: Option<u64>,
    /// Text sent as the heartbeat; when `None` an empty ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Callback invoked with the payload of each received ping frame.
    pub ping_handler: Option<PingHandler>,
    /// Timeout in milliseconds for a single reconnection attempt (defaults to 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// First duration argument passed to the reconnect backoff, in milliseconds
    /// (defaults to 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Second duration argument passed to the reconnect backoff, in milliseconds
    /// (defaults to 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Factor passed to the reconnect backoff (defaults to 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter value passed to the reconnect backoff (defaults to 100).
    pub reconnect_jitter_ms: Option<u64>,
}
167
168impl Debug for WebSocketConfig {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct(stringify!(WebSocketConfig))
171 .field("url", &self.url)
172 .field("headers", &self.headers)
173 .field(
174 "message_handler",
175 &self.message_handler.as_ref().map(|_| "<function>"),
176 )
177 .field("heartbeat", &self.heartbeat)
178 .field("heartbeat_msg", &self.heartbeat_msg)
179 .field(
180 "ping_handler",
181 &self.ping_handler.as_ref().map(|_| "<function>"),
182 )
183 .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
184 .field(
185 "reconnect_delay_initial_ms",
186 &self.reconnect_delay_initial_ms,
187 )
188 .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
189 .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
190 .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
191 .finish()
192 }
193}
194
195impl Clone for WebSocketConfig {
196 fn clone(&self) -> Self {
197 Self {
198 url: self.url.clone(),
199 headers: self.headers.clone(),
200 message_handler: self.message_handler.clone(),
201 heartbeat: self.heartbeat,
202 heartbeat_msg: self.heartbeat_msg.clone(),
203 ping_handler: self.ping_handler.clone(),
204 reconnect_timeout_ms: self.reconnect_timeout_ms,
205 reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
206 reconnect_delay_max_ms: self.reconnect_delay_max_ms,
207 reconnect_backoff_factor: self.reconnect_backoff_factor,
208 reconnect_jitter_ms: self.reconnect_jitter_ms,
209 }
210 }
211}
212
/// Commands accepted by the write task over its channel.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the task's active writer with a new one (sent after a reconnect).
    Update(MessageWriter),
    /// Send a message through the active writer.
    Send(Message),
}
221
/// Connection state owned by the controller task: the configuration, the handles of the
/// read/write/heartbeat tasks and the reconnection settings.
struct WebSocketClientInner {
    /// Configuration used for the initial connection and for reconnects.
    config: WebSocketConfig,
    /// Handle of the read task; `None` when no message handler is configured or in stream mode.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the active writer.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender used to pass `WriterCommand`s to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Connection mode shared with the spawned tasks and the outer client.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff consulted by the controller after a failed reconnection attempt.
    backoff: ExponentialBackoff,
    /// `true` when created via `new_with_writer`; `reconnect` then refuses to auto-reconnect.
    is_stream_mode: bool,
}
251
252impl WebSocketClientInner {
253 pub async fn new_with_writer(
255 config: WebSocketConfig,
256 writer: MessageWriter,
257 ) -> Result<Self, Error> {
258 install_cryptographic_provider();
259
260 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
261
262 let read_task = None;
264
265 let backoff = ExponentialBackoff::new(
266 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
267 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
268 config.reconnect_backoff_factor.unwrap_or(1.5),
269 config.reconnect_jitter_ms.unwrap_or(100),
270 true, )
272 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
273
274 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
275 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
276
277 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
278 Some(Self::spawn_heartbeat_task(
279 connection_mode.clone(),
280 heartbeat_interval,
281 config.heartbeat_msg.clone(),
282 writer_tx.clone(),
283 ))
284 } else {
285 None
286 };
287
288 Ok(Self {
289 config: config.clone(),
290 writer_tx,
291 connection_mode,
292 reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
293 heartbeat_task,
294 read_task,
295 write_task,
296 backoff,
297 is_stream_mode: true,
298 })
299 }
300
301 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
303 install_cryptographic_provider();
304
305 let WebSocketConfig {
306 url,
307 message_handler,
308 heartbeat,
309 headers,
310 heartbeat_msg,
311 ping_handler,
312 reconnect_timeout_ms,
313 reconnect_delay_initial_ms,
314 reconnect_delay_max_ms,
315 reconnect_backoff_factor,
316 reconnect_jitter_ms,
317 } = &config;
318 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
319
320 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
321
322 let read_task = if message_handler.is_some() {
323 Some(Self::spawn_message_handler_task(
324 connection_mode.clone(),
325 reader,
326 message_handler.as_ref(),
327 ping_handler.as_ref(),
328 ))
329 } else {
330 None
331 };
332
333 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
334 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
335
336 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
338 Self::spawn_heartbeat_task(
339 connection_mode.clone(),
340 *heartbeat_secs,
341 heartbeat_msg.clone(),
342 writer_tx.clone(),
343 )
344 });
345
346 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
347 let backoff = ExponentialBackoff::new(
348 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
349 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
350 reconnect_backoff_factor.unwrap_or(1.5),
351 reconnect_jitter_ms.unwrap_or(100),
352 true, )
354 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
355
356 Ok(Self {
357 config,
358 read_task,
359 write_task,
360 writer_tx,
361 heartbeat_task,
362 connection_mode,
363 reconnect_timeout,
364 backoff,
365 is_stream_mode: false,
366 })
367 }
368
369 #[inline]
371 pub async fn connect_with_server(
372 url: &str,
373 headers: Vec<(String, String)>,
374 ) -> Result<(MessageWriter, MessageReader), Error> {
375 let mut request = url.into_client_request()?;
376 let req_headers = request.headers_mut();
377
378 let mut header_names: Vec<HeaderName> = Vec::new();
379 for (key, val) in headers {
380 let header_value = HeaderValue::from_str(&val)?;
381 let header_name: HeaderName = key.parse()?;
382 header_names.push(header_name.clone());
383 req_headers.insert(header_name, header_value);
384 }
385
386 connect_async(request).await.map(|resp| resp.0.split())
387 }
388
    /// Attempts to re-establish the connection, bounded by `reconnect_timeout`.
    ///
    /// In stream mode no reconnection is attempted: the connection mode is set to
    /// `Closed` and `Ok(())` is returned. Otherwise a new connection is opened, the new
    /// writer is handed to the write task, the old read task is aborted and, if the mode
    /// can be switched from `Reconnect` to `Active`, a new read task is spawned.
    ///
    /// `Ok(())` is also returned when the attempt is abandoned because the connection
    /// mode changed; callers must check the mode to tell the two cases apart.
    ///
    /// # Errors
    ///
    /// Returns the connection error from `connect_with_server`, or an `Error::Io` of
    /// kind `TimedOut` if the attempt exceeds `reconnect_timeout`.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if self.is_stream_mode {
            tracing::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            // Mark the client as closed so it is not left in the reconnect mode
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // The new reader is only used at the end, once the mode switch has succeeded
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // A failed send means the write task's receiver is gone; this is only logged
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
                tracing::error!("{e}");
            }

            // Pause before replacing the read task
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the previous read task (if it is still running) before spawning a new one
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Only switch to active if the mode is still `Reconnect`; any other mode means
            // it was changed concurrently and this attempt is abandoned
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                tracing::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            self.read_task = if self.config.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    reader,
                    self.config.message_handler.as_ref(),
                    self.config.ping_handler.as_ref(),
                ))
            } else {
                None
            };

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // The outer `Err` is the timeout elapsing; the trailing `?` unwraps it and the
        // inner `Result` of the attempt becomes the return value
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
486
487 #[inline]
495 #[must_use]
496 pub fn is_alive(&self) -> bool {
497 match &self.read_task {
498 Some(read_task) => !read_task.is_finished(),
499 None => true, }
501 }
502
503 fn spawn_message_handler_task(
504 connection_state: Arc<AtomicU8>,
505 mut reader: MessageReader,
506 message_handler: Option<&MessageHandler>,
507 ping_handler: Option<&PingHandler>,
508 ) -> tokio::task::JoinHandle<()> {
509 tracing::debug!("Started message handler task 'read'");
510
511 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
512
513 let message_handler = message_handler.cloned();
515 let ping_handler = ping_handler.cloned();
516
517 tokio::task::spawn(async move {
518 loop {
519 if !ConnectionMode::from_atomic(&connection_state).is_active() {
520 break;
521 }
522
523 match tokio::time::timeout(check_interval, reader.next()).await {
524 Ok(Some(Ok(Message::Binary(data)))) => {
525 tracing::trace!("Received message <binary> {} bytes", data.len());
526 if let Some(ref handler) = message_handler {
527 handler(Message::Binary(data));
528 }
529 }
530 Ok(Some(Ok(Message::Text(data)))) => {
531 tracing::trace!("Received message: {data}");
532 if let Some(ref handler) = message_handler {
533 handler(Message::Text(data));
534 }
535 }
536 Ok(Some(Ok(Message::Ping(ping_data)))) => {
537 tracing::trace!("Received ping: {ping_data:?}");
538 if let Some(ref handler) = ping_handler {
539 handler(ping_data.to_vec());
540 }
541 }
542 Ok(Some(Ok(Message::Pong(_)))) => {
543 tracing::trace!("Received pong");
544 }
545 Ok(Some(Ok(Message::Close(_)))) => {
546 tracing::debug!("Received close message - terminating");
547 break;
548 }
549 Ok(Some(Ok(_))) => (),
550 Ok(Some(Err(e))) => {
551 tracing::error!("Received error message - terminating: {e}");
552 break;
553 }
554 Ok(None) => {
555 tracing::debug!("No message received - terminating");
556 break;
557 }
558 Err(_) => {
559 continue;
561 }
562 }
563 }
564 })
565 }
566
567 fn spawn_write_task(
568 connection_state: Arc<AtomicU8>,
569 writer: MessageWriter,
570 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
571 ) -> tokio::task::JoinHandle<()> {
572 log_task_started("write");
573
574 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
576
577 tokio::task::spawn(async move {
578 let mut active_writer = writer;
579
580 loop {
581 match ConnectionMode::from_atomic(&connection_state) {
582 ConnectionMode::Disconnect => {
583 _ = tokio::time::timeout(
586 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
587 active_writer.close(),
588 )
589 .await;
590 break;
591 }
592 ConnectionMode::Closed => break,
593 _ => {}
594 }
595
596 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
597 Ok(Some(msg)) => {
598 let mode = ConnectionMode::from_atomic(&connection_state);
600 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
601 break;
602 }
603
604 match msg {
605 WriterCommand::Update(new_writer) => {
606 tracing::debug!("Received new writer");
607
608 tokio::time::sleep(Duration::from_millis(100)).await;
610
611 _ = tokio::time::timeout(
614 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
615 active_writer.close(),
616 )
617 .await;
618
619 active_writer = new_writer;
620 tracing::debug!("Updated writer");
621 }
622 _ if mode.is_reconnect() => {
623 tracing::warn!("Skipping message while reconnecting, {msg:?}");
624 continue;
625 }
626 WriterCommand::Send(msg) => {
627 if let Err(e) = active_writer.send(msg).await {
628 tracing::error!("Failed to send message: {e}");
629 tracing::warn!("Writer triggering reconnect");
631 connection_state
632 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
633 }
634 }
635 }
636 }
637 Ok(None) => {
638 tracing::debug!("Writer channel closed, terminating writer task");
640 break;
641 }
642 Err(_) => {
643 continue;
645 }
646 }
647 }
648
649 _ = tokio::time::timeout(
652 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
653 active_writer.close(),
654 )
655 .await;
656
657 log_task_stopped("write");
658 })
659 }
660
661 fn spawn_heartbeat_task(
662 connection_state: Arc<AtomicU8>,
663 heartbeat_secs: u64,
664 message: Option<String>,
665 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
666 ) -> tokio::task::JoinHandle<()> {
667 log_task_started("heartbeat");
668
669 tokio::task::spawn(async move {
670 let interval = Duration::from_secs(heartbeat_secs);
671
672 loop {
673 tokio::time::sleep(interval).await;
674
675 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
676 ConnectionMode::Active => {
677 let msg = match &message {
678 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
679 None => WriterCommand::Send(Message::Ping(vec![].into())),
680 };
681
682 match writer_tx.send(msg) {
683 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
684 Err(e) => {
685 tracing::error!("Failed to send heartbeat to writer task: {e}");
686 }
687 }
688 }
689 ConnectionMode::Reconnect => continue,
690 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
691 }
692 }
693
694 log_task_stopped("heartbeat");
695 })
696 }
697}
698
699impl Drop for WebSocketClientInner {
700 fn drop(&mut self) {
701 self.clean_drop();
703 }
704}
705
706impl CleanDrop for WebSocketClientInner {
708 fn clean_drop(&mut self) {
709 if let Some(ref read_task) = self.read_task.take()
710 && !read_task.is_finished()
711 {
712 read_task.abort();
713 log_task_aborted("read");
714 }
715
716 if !self.write_task.is_finished() {
717 self.write_task.abort();
718 log_task_aborted("write");
719 }
720
721 if let Some(ref handle) = self.heartbeat_task.take()
722 && !handle.is_finished()
723 {
724 handle.abort();
725 log_task_aborted("heartbeat");
726 }
727
728 self.config.message_handler = None;
730 self.config.ping_handler = None;
731 }
732}
733
/// WebSocket client handle: sends go through a channel to the write task, and a
/// controller task performs reconnection and shutdown.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task that owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the inner client and its tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited by `send_text` and `send_bytes` before sending.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
    /// Sender used to queue `WriterCommand`s for the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
749
750impl Debug for WebSocketClient {
751 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
752 f.debug_struct(stringify!(WebSocketClient)).finish()
753 }
754}
755
756impl WebSocketClient {
757 #[allow(clippy::too_many_arguments)]
773 pub async fn connect_stream(
774 config: WebSocketConfig,
775 keyed_quotas: Vec<(String, Quota)>,
776 default_quota: Option<Quota>,
777 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
778 ) -> Result<(MessageReader, Self), Error> {
779 install_cryptographic_provider();
780
781 let (writer, reader) =
783 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
784
785 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
787
788 let connection_mode = inner.connection_mode.clone();
789 let reconnect_timeout = inner.reconnect_timeout;
790 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
791 let writer_tx = inner.writer_tx.clone();
792
793 let controller_task =
794 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
795
796 Ok((
797 reader,
798 Self {
799 controller_task,
800 connection_mode,
801 reconnect_timeout,
802 rate_limiter,
803 writer_tx,
804 },
805 ))
806 }
807
808 pub async fn connect(
825 config: WebSocketConfig,
826 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
827 keyed_quotas: Vec<(String, Quota)>,
828 default_quota: Option<Quota>,
829 ) -> Result<Self, Error> {
830 tracing::debug!("Connecting");
831 let inner = WebSocketClientInner::connect_url(config).await?;
832 let connection_mode = inner.connection_mode.clone();
833 let writer_tx = inner.writer_tx.clone();
834 let reconnect_timeout = inner.reconnect_timeout;
835
836 let controller_task =
837 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
838
839 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
840
841 Ok(Self {
842 controller_task,
843 connection_mode,
844 reconnect_timeout,
845 rate_limiter,
846 writer_tx,
847 })
848 }
849
850 #[must_use]
852 pub fn connection_mode(&self) -> ConnectionMode {
853 ConnectionMode::from_atomic(&self.connection_mode)
854 }
855
856 #[inline]
861 #[must_use]
862 pub fn is_active(&self) -> bool {
863 self.connection_mode().is_active()
864 }
865
866 #[must_use]
868 pub fn is_disconnected(&self) -> bool {
869 self.controller_task.is_finished()
870 }
871
872 #[inline]
877 #[must_use]
878 pub fn is_reconnecting(&self) -> bool {
879 self.connection_mode().is_reconnect()
880 }
881
882 #[inline]
886 #[must_use]
887 pub fn is_disconnecting(&self) -> bool {
888 self.connection_mode().is_disconnect()
889 }
890
891 #[inline]
897 #[must_use]
898 pub fn is_closed(&self) -> bool {
899 self.connection_mode().is_closed()
900 }
901
902 async fn wait_for_active(&self) -> Result<(), SendError> {
906 if self.is_closed() {
907 return Err(SendError::Closed);
908 }
909
910 let timeout = self.reconnect_timeout;
911 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
912
913 if !self.is_active() {
914 tracing::debug!("Waiting for client to become ACTIVE before sending...");
915
916 let inner = tokio::time::timeout(timeout, async {
917 loop {
918 if self.is_active() {
919 return Ok(());
920 }
921 if matches!(
922 self.connection_mode(),
923 ConnectionMode::Disconnect | ConnectionMode::Closed
924 ) {
925 return Err(());
926 }
927 tokio::time::sleep(check_interval).await;
928 }
929 })
930 .await
931 .map_err(|_| SendError::Timeout)?;
932 inner.map_err(|()| SendError::Closed)?;
933 }
934
935 Ok(())
936 }
937
938 pub async fn disconnect(&self) {
943 tracing::debug!("Disconnecting");
944 self.connection_mode
945 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
946
947 match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
948 while !self.is_disconnected() {
949 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
950 }
951
952 if !self.controller_task.is_finished() {
953 self.controller_task.abort();
954 log_task_aborted("controller");
955 }
956 })
957 .await
958 {
959 Ok(()) => {
960 tracing::debug!("Controller task finished");
961 }
962 Err(_) => {
963 tracing::error!("Timeout waiting for controller task to finish");
964 if !self.controller_task.is_finished() {
965 self.controller_task.abort();
966 log_task_aborted("controller");
967 }
968 }
969 }
970 }
971
972 #[allow(unused_variables)]
978 pub async fn send_text(
979 &self,
980 data: String,
981 keys: Option<Vec<String>>,
982 ) -> Result<(), SendError> {
983 self.rate_limiter.await_keys_ready(keys).await;
984 self.wait_for_active().await?;
985
986 tracing::trace!("Sending text: {data:?}");
987
988 let msg = Message::Text(data.into());
989 self.writer_tx
990 .send(WriterCommand::Send(msg))
991 .map_err(|e| SendError::BrokenPipe(e.to_string()))
992 }
993
994 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1000 self.wait_for_active().await?;
1001
1002 tracing::trace!("Sending pong frame ({} bytes)", data.len());
1003
1004 let msg = Message::Pong(data.into());
1005 self.writer_tx
1006 .send(WriterCommand::Send(msg))
1007 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1008 }
1009
1010 #[allow(unused_variables)]
1016 pub async fn send_bytes(
1017 &self,
1018 data: Vec<u8>,
1019 keys: Option<Vec<String>>,
1020 ) -> Result<(), SendError> {
1021 self.rate_limiter.await_keys_ready(keys).await;
1022 self.wait_for_active().await?;
1023
1024 tracing::trace!("Sending bytes: {data:?}");
1025
1026 let msg = Message::Binary(data.into());
1027 self.writer_tx
1028 .send(WriterCommand::Send(msg))
1029 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1030 }
1031
1032 pub async fn send_close_message(&self) -> Result<(), SendError> {
1038 self.wait_for_active().await?;
1039
1040 let msg = Message::Close(None);
1041 self.writer_tx
1042 .send(WriterCommand::Send(msg))
1043 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1044 }
1045
    /// Spawns the controller task, which owns `inner` and drives the connection lifecycle.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task inspects the connection mode:
    ///
    /// - disconnect: waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat
    ///   tasks and leaves the loop;
    /// - active with a finished read task: switches the mode to reconnect;
    /// - reconnect: calls `inner.reconnect()`; on success the backoff is reset and, if the
    ///   client is active, the message handler receives a `RECONNECTED` text message and
    ///   `post_reconnection` is called; on failure the task sleeps for the next backoff
    ///   duration.
    ///
    /// After the loop the connection mode is set to `Closed`.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Pause before aborting the tasks
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");
                    // Leave the loop; the mode is set to closed below
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Only switch if the mode is still active; a concurrent change wins
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        tracing::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read the mode so the reconnect check below sees the current value
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // `reconnect` also returns `Ok` when it was abandoned, so only
                            // notify if the client actually became active
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.config.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    tracing::debug!("Sent reconnected message to handler");
                                }

                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1152}
1153
1154impl Drop for WebSocketClient {
1156 fn drop(&mut self) {
1157 if !self.controller_task.is_finished() {
1158 self.controller_task.abort();
1159 log_task_aborted("controller");
1160 }
1161 }
1162}
1163
1164#[cfg(test)]
1169#[cfg(target_os = "linux")] mod tests {
1171 use std::{num::NonZeroU32, sync::Arc};
1172
1173 use futures_util::{SinkExt, StreamExt};
1174 use tokio::{
1175 net::TcpListener,
1176 task::{self, JoinHandle},
1177 };
1178 use tokio_tungstenite::{
1179 accept_hdr_async,
1180 tungstenite::{
1181 handshake::server::{self, Callback},
1182 http::HeaderValue,
1183 },
1184 };
1185
1186 use crate::{
1187 ratelimiter::quota::Quota,
1188 websocket::{WebSocketClient, WebSocketConfig},
1189 };
1190
    /// Local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Handle of the accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to on 127.0.0.1.
        port: u16,
    }
1195
    /// Handshake callback asserting that the request carries header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the handshake request.
        key: String,
        /// Value the header must have.
        value: HeaderValue,
    }
1201
1202 impl Callback for TestCallback {
1203 fn on_request(
1204 self,
1205 request: &server::Request,
1206 response: server::Response,
1207 ) -> Result<server::Response, server::ErrorResponse> {
1208 let _ = response;
1209 let value = request.headers().get(&self.key);
1210 assert!(value.is_some());
1211
1212 if let Some(value) = request.headers().get(&self.key) {
1213 assert_eq!(value, self.value);
1214 }
1215
1216 Ok(response)
1217 }
1218 }
1219
1220 impl TestServer {
1221 async fn setup() -> Self {
1222 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1223 let port = TcpListener::local_addr(&server).unwrap().port();
1224
1225 let header_key = "test".to_string();
1226 let header_value = "test".to_string();
1227
1228 let test_call_back = TestCallback {
1229 key: header_key,
1230 value: HeaderValue::from_str(&header_value).unwrap(),
1231 };
1232
1233 let task = task::spawn(async move {
1234 loop {
1236 let (conn, _) = server.accept().await.unwrap();
1237 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1238 .await
1239 .unwrap();
1240
1241 task::spawn(async move {
1242 while let Some(Ok(msg)) = websocket.next().await {
1243 match msg {
1244 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1245 if txt == "close-now" =>
1246 {
1247 tracing::debug!("Forcibly closing from server side");
1248 let _ = websocket.close(None).await;
1250 break;
1251 }
1252 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1254 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1255 if websocket.send(msg).await.is_err() {
1256 break;
1257 }
1258 }
1259 tokio_tungstenite::tungstenite::protocol::Message::Close(
1261 _frame,
1262 ) => {
1263 let _ = websocket.close(None).await;
1264 break;
1265 }
1266 _ => {}
1268 }
1269 }
1270 });
1271 }
1272 });
1273
1274 Self { task, port }
1275 }
1276 }
1277
1278 impl Drop for TestServer {
1279 fn drop(&mut self) {
1280 self.task.abort();
1281 }
1282 }
1283
1284 async fn setup_test_client(port: u16) -> WebSocketClient {
1285 let config = WebSocketConfig {
1286 url: format!("ws://127.0.0.1:{port}"),
1287 headers: vec![("test".into(), "test".into())],
1288 message_handler: None,
1289 heartbeat: None,
1290 heartbeat_msg: None,
1291 ping_handler: None,
1292 reconnect_timeout_ms: None,
1293 reconnect_delay_initial_ms: None,
1294 reconnect_backoff_factor: None,
1295 reconnect_delay_max_ms: None,
1296 reconnect_jitter_ms: None,
1297 };
1298 WebSocketClient::connect(config, None, vec![], None)
1299 .await
1300 .expect("Failed to connect")
1301 }
1302
1303 #[tokio::test]
1304 async fn test_websocket_basic() {
1305 let server = TestServer::setup().await;
1306 let client = setup_test_client(server.port).await;
1307
1308 assert!(!client.is_disconnected());
1309
1310 client.disconnect().await;
1311 assert!(client.is_disconnected());
1312 }
1313
1314 #[tokio::test]
1315 async fn test_websocket_heartbeat() {
1316 let server = TestServer::setup().await;
1317 let client = setup_test_client(server.port).await;
1318
1319 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1321
1322 client.disconnect().await;
1324 assert!(client.is_disconnected());
1325 }
1326
1327 #[tokio::test]
1328 async fn test_websocket_reconnect_exhausted() {
1329 let config = WebSocketConfig {
1330 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1332 message_handler: None,
1333 heartbeat: None,
1334 heartbeat_msg: None,
1335 ping_handler: None,
1336 reconnect_timeout_ms: None,
1337 reconnect_delay_initial_ms: None,
1338 reconnect_backoff_factor: None,
1339 reconnect_delay_max_ms: None,
1340 reconnect_jitter_ms: None,
1341 };
1342 let res = WebSocketClient::connect(config, None, vec![], None).await;
1343 assert!(res.is_err(), "Should fail quickly with no server");
1344 }
1345
1346 #[tokio::test]
1347 async fn test_websocket_forced_close_reconnect() {
1348 let server = TestServer::setup().await;
1349 let client = setup_test_client(server.port).await;
1350
1351 client.send_text("Hello".into(), None).await.unwrap();
1353
1354 client.send_text("close-now".into(), None).await.unwrap();
1356
1357 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1359
1360 assert!(!client.is_disconnected());
1362
1363 client.disconnect().await;
1365 assert!(client.is_disconnected());
1366 }
1367
1368 #[tokio::test]
1369 async fn test_rate_limiter() {
1370 let server = TestServer::setup().await;
1371 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1372
1373 let config = WebSocketConfig {
1374 url: format!("ws://127.0.0.1:{}", server.port),
1375 headers: vec![("test".into(), "test".into())],
1376 message_handler: None,
1377 heartbeat: None,
1378 heartbeat_msg: None,
1379 ping_handler: None,
1380 reconnect_timeout_ms: None,
1381 reconnect_delay_initial_ms: None,
1382 reconnect_backoff_factor: None,
1383 reconnect_delay_max_ms: None,
1384 reconnect_jitter_ms: None,
1385 };
1386
1387 let client = WebSocketClient::connect(config, None, vec![("default".into(), quota)], None)
1388 .await
1389 .unwrap();
1390
1391 client.send_text("test1".into(), None).await.unwrap();
1393 client.send_text("test2".into(), None).await.unwrap();
1394
1395 client.send_text("test3".into(), None).await.unwrap();
1397
1398 client.disconnect().await;
1400 assert!(client.is_disconnected());
1401 }
1402
1403 #[tokio::test]
1404 async fn test_concurrent_writers() {
1405 let server = TestServer::setup().await;
1406 let client = Arc::new(setup_test_client(server.port).await);
1407
1408 let mut handles = vec![];
1409 for i in 0..10 {
1410 let client = client.clone();
1411 handles.push(task::spawn(async move {
1412 client.send_text(format!("test{i}"), None).await.unwrap();
1413 }));
1414 }
1415
1416 for handle in handles {
1417 handle.await.unwrap();
1418 }
1419
1420 client.disconnect().await;
1422 assert!(client.is_disconnected());
1423 }
1424}
1425
1426#[cfg(test)]
1431mod rust_tests {
1432 use futures_util::StreamExt;
1433 use rstest::rstest;
1434 use tokio::{
1435 net::TcpListener,
1436 task,
1437 time::{Duration, sleep},
1438 };
1439 use tokio_tungstenite::accept_async;
1440
1441 use super::*;
1442
1443 #[rstest]
1444 #[tokio::test]
1445 async fn test_reconnect_then_disconnect() {
1446 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1448 let port = listener.local_addr().unwrap().port();
1449
1450 let server = task::spawn(async move {
1452 let (stream, _) = listener.accept().await.unwrap();
1453 let ws = accept_async(stream).await.unwrap();
1454 drop(ws);
1455 sleep(Duration::from_secs(1)).await;
1457 });
1458
1459 let (handler, _rx) = channel_message_handler();
1461
1462 let config = WebSocketConfig {
1464 url: format!("ws://127.0.0.1:{port}"),
1465 headers: vec![],
1466 message_handler: Some(handler),
1467 heartbeat: None,
1468 heartbeat_msg: None,
1469 ping_handler: None,
1470 reconnect_timeout_ms: Some(1_000),
1471 reconnect_delay_initial_ms: Some(50),
1472 reconnect_delay_max_ms: Some(100),
1473 reconnect_backoff_factor: Some(1.0),
1474 reconnect_jitter_ms: Some(0),
1475 };
1476
1477 let client = WebSocketClient::connect(config, None, vec![], None)
1479 .await
1480 .unwrap();
1481
1482 sleep(Duration::from_millis(100)).await;
1484 client.disconnect().await;
1486 assert!(client.is_disconnected());
1487 server.abort();
1488 }
1489
1490 #[rstest]
1491 #[tokio::test]
1492 async fn test_reconnect_state_flips_when_reader_stops() {
1493 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1495 let port = listener.local_addr().unwrap().port();
1496
1497 let server = task::spawn(async move {
1498 if let Ok((stream, _)) = listener.accept().await
1499 && let Ok(ws) = accept_async(stream).await
1500 {
1501 drop(ws);
1502 }
1503 sleep(Duration::from_millis(50)).await;
1504 });
1505
1506 let (handler, _rx) = channel_message_handler();
1507
1508 let config = WebSocketConfig {
1509 url: format!("ws://127.0.0.1:{port}"),
1510 headers: vec![],
1511 message_handler: Some(handler),
1512 heartbeat: None,
1513 heartbeat_msg: None,
1514 ping_handler: None,
1515 reconnect_timeout_ms: Some(1_000),
1516 reconnect_delay_initial_ms: Some(50),
1517 reconnect_delay_max_ms: Some(100),
1518 reconnect_backoff_factor: Some(1.0),
1519 reconnect_jitter_ms: Some(0),
1520 };
1521
1522 let client = WebSocketClient::connect(config, None, vec![], None)
1523 .await
1524 .unwrap();
1525
1526 tokio::time::timeout(Duration::from_secs(2), async {
1527 loop {
1528 if client.is_reconnecting() {
1529 break;
1530 }
1531 tokio::time::sleep(Duration::from_millis(10)).await;
1532 }
1533 })
1534 .await
1535 .expect("client did not enter RECONNECT state");
1536
1537 client.disconnect().await;
1538 server.abort();
1539 }
1540
1541 #[rstest]
1542 #[tokio::test]
1543 async fn test_stream_mode_disables_auto_reconnect() {
1544 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1547 let port = listener.local_addr().unwrap().port();
1548
1549 let server = task::spawn(async move {
1550 if let Ok((stream, _)) = listener.accept().await
1551 && let Ok(_ws) = accept_async(stream).await
1552 {
1553 sleep(Duration::from_millis(100)).await;
1555 }
1556 });
1557
1558 let config = WebSocketConfig {
1559 url: format!("ws://127.0.0.1:{port}"),
1560 headers: vec![],
1561 message_handler: None, heartbeat: None,
1563 heartbeat_msg: None,
1564 ping_handler: None,
1565 reconnect_timeout_ms: Some(1_000),
1566 reconnect_delay_initial_ms: Some(50),
1567 reconnect_delay_max_ms: Some(100),
1568 reconnect_backoff_factor: Some(1.0),
1569 reconnect_jitter_ms: Some(0),
1570 };
1571
1572 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1574 .await
1575 .unwrap();
1576
1577 server.abort();
1585 }
1586
1587 #[rstest]
1588 #[tokio::test]
1589 async fn test_message_handler_mode_allows_auto_reconnect() {
1590 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1592 let port = listener.local_addr().unwrap().port();
1593
1594 let server = task::spawn(async move {
1595 if let Ok((stream, _)) = listener.accept().await
1597 && let Ok(ws) = accept_async(stream).await
1598 {
1599 drop(ws);
1600 }
1601 sleep(Duration::from_millis(50)).await;
1602 });
1603
1604 let (handler, _rx) = channel_message_handler();
1605
1606 let config = WebSocketConfig {
1607 url: format!("ws://127.0.0.1:{port}"),
1608 headers: vec![],
1609 message_handler: Some(handler), heartbeat: None,
1611 heartbeat_msg: None,
1612 ping_handler: None,
1613 reconnect_timeout_ms: Some(1_000),
1614 reconnect_delay_initial_ms: Some(50),
1615 reconnect_delay_max_ms: Some(100),
1616 reconnect_backoff_factor: Some(1.0),
1617 reconnect_jitter_ms: Some(0),
1618 };
1619
1620 let client = WebSocketClient::connect(config, None, vec![], None)
1621 .await
1622 .unwrap();
1623
1624 tokio::time::timeout(Duration::from_secs(2), async {
1626 loop {
1627 if client.is_reconnecting() || client.is_closed() {
1628 break;
1629 }
1630 tokio::time::sleep(Duration::from_millis(10)).await;
1631 }
1632 })
1633 .await
1634 .expect("client should attempt reconnection or close");
1635
1636 assert!(
1639 client.is_reconnecting() || client.is_closed(),
1640 "Client with message handler should attempt reconnection"
1641 );
1642
1643 client.disconnect().await;
1644 server.abort();
1645 }
1646
1647 #[rstest]
1648 #[tokio::test]
1649 async fn test_handler_mode_reconnect_with_new_connection() {
1650 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1652 let port = listener.local_addr().unwrap().port();
1653
1654 let server = task::spawn(async move {
1655 if let Ok((stream, _)) = listener.accept().await
1657 && let Ok(ws) = accept_async(stream).await
1658 {
1659 drop(ws);
1660 }
1661
1662 sleep(Duration::from_millis(100)).await;
1664
1665 if let Ok((stream, _)) = listener.accept().await
1667 && let Ok(mut ws) = accept_async(stream).await
1668 {
1669 use futures_util::SinkExt;
1670 let _ = ws
1671 .send(Message::Text("reconnected".to_string().into()))
1672 .await;
1673 sleep(Duration::from_secs(1)).await;
1674 }
1675 });
1676
1677 let (handler, mut rx) = channel_message_handler();
1678
1679 let config = WebSocketConfig {
1680 url: format!("ws://127.0.0.1:{port}"),
1681 headers: vec![],
1682 message_handler: Some(handler),
1683 heartbeat: None,
1684 heartbeat_msg: None,
1685 ping_handler: None,
1686 reconnect_timeout_ms: Some(2_000),
1687 reconnect_delay_initial_ms: Some(50),
1688 reconnect_delay_max_ms: Some(200),
1689 reconnect_backoff_factor: Some(1.5),
1690 reconnect_jitter_ms: Some(10),
1691 };
1692
1693 let client = WebSocketClient::connect(config, None, vec![], None)
1694 .await
1695 .unwrap();
1696
1697 let result = tokio::time::timeout(Duration::from_secs(5), async {
1699 loop {
1700 if let Ok(msg) = rx.try_recv()
1701 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1702 {
1703 return true;
1704 }
1705 tokio::time::sleep(Duration::from_millis(10)).await;
1706 }
1707 })
1708 .await;
1709
1710 assert!(
1711 result.is_ok(),
1712 "Should receive message after reconnection within timeout"
1713 );
1714
1715 client.disconnect().await;
1716 server.abort();
1717 }
1718
1719 #[rstest]
1720 #[tokio::test]
1721 async fn test_stream_mode_no_auto_reconnect() {
1722 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1725 let port = listener.local_addr().unwrap().port();
1726
1727 let server = task::spawn(async move {
1728 if let Ok((stream, _)) = listener.accept().await
1730 && let Ok(mut ws) = accept_async(stream).await
1731 {
1732 use futures_util::SinkExt;
1733 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1734 sleep(Duration::from_millis(50)).await;
1735 }
1737 });
1738
1739 let config = WebSocketConfig {
1740 url: format!("ws://127.0.0.1:{port}"),
1741 headers: vec![],
1742 message_handler: None, heartbeat: None,
1744 heartbeat_msg: None,
1745 ping_handler: None,
1746 reconnect_timeout_ms: Some(1_000),
1747 reconnect_delay_initial_ms: Some(50),
1748 reconnect_delay_max_ms: Some(100),
1749 reconnect_backoff_factor: Some(1.0),
1750 reconnect_jitter_ms: Some(0),
1751 };
1752
1753 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1754 .await
1755 .unwrap();
1756
1757 assert!(client.is_active(), "Client should start as active");
1759
1760 let msg = reader.next().await;
1762 assert!(
1763 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1764 "Should receive initial message"
1765 );
1766
1767 while let Some(msg) = reader.next().await {
1769 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1770 break;
1771 }
1772 }
1773
1774 sleep(Duration::from_millis(200)).await;
1777
1778 assert!(
1781 client.is_active() || client.is_closed(),
1782 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1783 );
1784 assert!(
1785 !client.is_reconnecting(),
1786 "Stream mode client should never attempt reconnection"
1787 );
1788
1789 client.disconnect().await;
1790 server.abort();
1791 }
1792
1793 #[rstest]
1794 #[tokio::test]
1795 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1796 use nautilus_common::testing::wait_until_async;
1799
1800 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1801 let port = listener.local_addr().unwrap().port();
1802
1803 let server = task::spawn(async move {
1804 if let Ok((stream, _)) = listener.accept().await
1806 && let Ok(ws) = accept_async(stream).await
1807 {
1808 drop(ws);
1809 }
1810 sleep(Duration::from_secs(60)).await;
1812 });
1813
1814 let (handler, _rx) = channel_message_handler();
1815
1816 let config = WebSocketConfig {
1818 url: format!("ws://127.0.0.1:{port}"),
1819 headers: vec![],
1820 message_handler: Some(handler),
1821 heartbeat: None,
1822 heartbeat_msg: None,
1823 ping_handler: None,
1824 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
1826 reconnect_delay_max_ms: Some(100),
1827 reconnect_backoff_factor: Some(1.0),
1828 reconnect_jitter_ms: Some(0),
1829 };
1830
1831 let client = WebSocketClient::connect(config, None, vec![], None)
1832 .await
1833 .unwrap();
1834
1835 wait_until_async(
1837 || async { client.is_reconnecting() },
1838 Duration::from_secs(3),
1839 )
1840 .await;
1841
1842 let start = std::time::Instant::now();
1844 let send_result = client.send_text("test".to_string(), None).await;
1845 let elapsed = start.elapsed();
1846
1847 assert!(
1848 send_result.is_err(),
1849 "Send should fail when client stuck in RECONNECT"
1850 );
1851 assert!(
1852 matches!(send_result, Err(crate::error::SendError::Timeout)),
1853 "Send should return Timeout error, got: {:?}",
1854 send_result
1855 );
1856 assert!(
1859 elapsed >= Duration::from_millis(1800),
1860 "Send should timeout after at least 2s (configured timeout), took {:?}",
1861 elapsed
1862 );
1863
1864 client.disconnect().await;
1865 server.abort();
1866 }
1867
1868 #[rstest]
1869 #[tokio::test]
1870 async fn test_send_waits_during_reconnection() {
1871 use nautilus_common::testing::wait_until_async;
1873
1874 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1875 let port = listener.local_addr().unwrap().port();
1876
1877 let server = task::spawn(async move {
1878 if let Ok((stream, _)) = listener.accept().await
1880 && let Ok(ws) = accept_async(stream).await
1881 {
1882 drop(ws);
1883 }
1884
1885 sleep(Duration::from_millis(500)).await;
1887
1888 if let Ok((stream, _)) = listener.accept().await
1890 && let Ok(mut ws) = accept_async(stream).await
1891 {
1892 while let Some(Ok(msg)) = ws.next().await {
1894 if ws.send(msg).await.is_err() {
1895 break;
1896 }
1897 }
1898 }
1899 });
1900
1901 let (handler, _rx) = channel_message_handler();
1902
1903 let config = WebSocketConfig {
1904 url: format!("ws://127.0.0.1:{port}"),
1905 headers: vec![],
1906 message_handler: Some(handler),
1907 heartbeat: None,
1908 heartbeat_msg: None,
1909 ping_handler: None,
1910 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
1912 reconnect_delay_max_ms: Some(200),
1913 reconnect_backoff_factor: Some(1.0),
1914 reconnect_jitter_ms: Some(0),
1915 };
1916
1917 let client = WebSocketClient::connect(config, None, vec![], None)
1918 .await
1919 .unwrap();
1920
1921 wait_until_async(
1923 || async { client.is_reconnecting() },
1924 Duration::from_secs(2),
1925 )
1926 .await;
1927
1928 let send_result = tokio::time::timeout(
1930 Duration::from_secs(3),
1931 client.send_text("test_message".to_string(), None),
1932 )
1933 .await;
1934
1935 assert!(
1936 send_result.is_ok() && send_result.unwrap().is_ok(),
1937 "Send should succeed after waiting for reconnection"
1938 );
1939
1940 client.disconnect().await;
1941 server.abort();
1942 }
1943
1944 #[rstest]
1945 #[tokio::test]
1946 async fn test_rate_limiter_before_active_wait() {
1947 use std::{num::NonZeroU32, sync::Arc};
1952
1953 use nautilus_common::testing::wait_until_async;
1954
1955 use crate::ratelimiter::quota::Quota;
1956
1957 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1958 let port = listener.local_addr().unwrap().port();
1959
1960 let server = task::spawn(async move {
1961 if let Ok((stream, _)) = listener.accept().await
1963 && let Ok(mut ws) = accept_async(stream).await
1964 {
1965 if let Some(Ok(_)) = ws.next().await {
1967 drop(ws);
1968 }
1969 }
1970
1971 sleep(Duration::from_millis(500)).await;
1973
1974 if let Ok((stream, _)) = listener.accept().await
1976 && let Ok(mut ws) = accept_async(stream).await
1977 {
1978 while let Some(Ok(msg)) = ws.next().await {
1979 if ws.send(msg).await.is_err() {
1980 break;
1981 }
1982 }
1983 }
1984 });
1985
1986 let (handler, _rx) = channel_message_handler();
1987
1988 let config = WebSocketConfig {
1989 url: format!("ws://127.0.0.1:{port}"),
1990 headers: vec![],
1991 message_handler: Some(handler),
1992 heartbeat: None,
1993 heartbeat_msg: None,
1994 ping_handler: None,
1995 reconnect_timeout_ms: Some(5_000),
1996 reconnect_delay_initial_ms: Some(50),
1997 reconnect_delay_max_ms: Some(100),
1998 reconnect_backoff_factor: Some(1.0),
1999 reconnect_jitter_ms: Some(0),
2000 };
2001
2002 let quota =
2004 Quota::per_second(NonZeroU32::new(1).unwrap()).allow_burst(NonZeroU32::new(1).unwrap());
2005
2006 let client = Arc::new(
2007 WebSocketClient::connect(config, None, vec![("test_key".to_string(), quota)], None)
2008 .await
2009 .unwrap(),
2010 );
2011
2012 client
2014 .send_text("msg1".to_string(), Some(vec!["test_key".to_string()]))
2015 .await
2016 .unwrap();
2017
2018 wait_until_async(
2020 || async { client.is_reconnecting() },
2021 Duration::from_secs(2),
2022 )
2023 .await;
2024
2025 let start = std::time::Instant::now();
2027 let send_result = client
2028 .send_text("msg2".to_string(), Some(vec!["test_key".to_string()]))
2029 .await;
2030 let elapsed = start.elapsed();
2031
2032 assert!(
2034 send_result.is_ok(),
2035 "Send should succeed after rate limit + reconnection, got: {:?}",
2036 send_result
2037 );
2038 assert!(
2042 elapsed >= Duration::from_millis(850),
2043 "Should wait for rate limit (~1s), waited {:?}",
2044 elapsed
2045 );
2046
2047 client.disconnect().await;
2048 server.abort();
2049 }
2050
2051 #[rstest]
2052 #[tokio::test]
2053 async fn test_disconnect_during_reconnect_exits_cleanly() {
2054 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2057 let port = listener.local_addr().unwrap().port();
2058
2059 let server = task::spawn(async move {
2060 if let Ok((stream, _)) = listener.accept().await
2062 && let Ok(ws) = accept_async(stream).await
2063 {
2064 drop(ws);
2065 }
2066 sleep(Duration::from_secs(60)).await;
2068 });
2069
2070 let (handler, _rx) = channel_message_handler();
2071
2072 let config = WebSocketConfig {
2073 url: format!("ws://127.0.0.1:{port}"),
2074 headers: vec![],
2075 message_handler: Some(handler),
2076 heartbeat: None,
2077 heartbeat_msg: None,
2078 ping_handler: None,
2079 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2081 reconnect_delay_max_ms: Some(200),
2082 reconnect_backoff_factor: Some(1.0),
2083 reconnect_jitter_ms: Some(0),
2084 };
2085
2086 let client = WebSocketClient::connect(config, None, vec![], None)
2087 .await
2088 .unwrap();
2089
2090 tokio::time::timeout(Duration::from_secs(2), async {
2092 while !client.is_reconnecting() {
2093 sleep(Duration::from_millis(10)).await;
2094 }
2095 })
2096 .await
2097 .expect("Client should enter RECONNECT state");
2098
2099 client.disconnect().await;
2101
2102 assert!(
2104 client.is_disconnected(),
2105 "Client should be cleanly disconnected"
2106 );
2107
2108 server.abort();
2109 }
2110}