1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_core::CleanDrop;
46use nautilus_cryptography::providers::install_cryptographic_provider;
47use tokio::net::TcpStream;
48use tokio_tungstenite::{
49 MaybeTlsStream, WebSocketStream, connect_async,
50 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
51};
52
53use crate::{
54 RECONNECTED,
55 backoff::ExponentialBackoff,
56 error::SendError,
57 logging::{log_task_aborted, log_task_started, log_task_stopped},
58 mode::ConnectionMode,
59 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
60};
61
/// Polling interval (milliseconds) used by the background tasks in this module
/// when they re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Short pause (milliseconds) applied before tasks are swapped or aborted
/// during reconnect and shutdown.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (seconds) for closing a writer or waiting on shutdown.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

/// Write half of a split WebSocket stream.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split WebSocket stream.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// Shared callback invoked with each received data message.
pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;

/// Shared callback invoked with the payload of each received ping frame.
pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
75
76#[must_use]
80pub fn channel_message_handler() -> (
81 MessageHandler,
82 tokio::sync::mpsc::UnboundedReceiver<Message>,
83) {
84 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
85 let handler = Arc::new(move |msg: Message| {
86 if let Err(e) = tx.send(msg) {
87 tracing::error!("Failed to send message to channel: {e}");
88 }
89 });
90 (handler, rx)
91}
92
/// Configuration for a WebSocket client connection.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// URL to connect to.
    pub url: String,
    /// `(name, value)` pairs added to the handshake request headers.
    pub headers: Vec<(String, String)>,
    /// Callback for received messages; when `None` no read task is spawned.
    pub message_handler: Option<MessageHandler>,
    /// Heartbeat interval in seconds; when `None` no heartbeat task is spawned.
    pub heartbeat: Option<u64>,
    /// Text sent as the heartbeat; when `None` an empty ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Callback invoked with the payload of each received ping frame.
    pub ping_handler: Option<PingHandler>,
    /// Timeout for a single reconnect attempt in milliseconds (defaults to 10,000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay in milliseconds (defaults to 2,000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay in milliseconds (defaults to 30,000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Multiplier passed to the exponential backoff (defaults to 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to the exponential backoff, in milliseconds (defaults to 100).
    pub reconnect_jitter_ms: Option<u64>,
}
121
122impl Debug for WebSocketConfig {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("WebSocketConfig")
125 .field("url", &self.url)
126 .field("headers", &self.headers)
127 .field(
128 "message_handler",
129 &self.message_handler.as_ref().map(|_| "<function>"),
130 )
131 .field("heartbeat", &self.heartbeat)
132 .field("heartbeat_msg", &self.heartbeat_msg)
133 .field(
134 "ping_handler",
135 &self.ping_handler.as_ref().map(|_| "<function>"),
136 )
137 .field("reconnect_timeout_ms", &self.reconnect_timeout_ms)
138 .field(
139 "reconnect_delay_initial_ms",
140 &self.reconnect_delay_initial_ms,
141 )
142 .field("reconnect_delay_max_ms", &self.reconnect_delay_max_ms)
143 .field("reconnect_backoff_factor", &self.reconnect_backoff_factor)
144 .field("reconnect_jitter_ms", &self.reconnect_jitter_ms)
145 .finish()
146 }
147}
148
149impl Clone for WebSocketConfig {
150 fn clone(&self) -> Self {
151 Self {
152 url: self.url.clone(),
153 headers: self.headers.clone(),
154 message_handler: self.message_handler.clone(),
155 heartbeat: self.heartbeat,
156 heartbeat_msg: self.heartbeat_msg.clone(),
157 ping_handler: self.ping_handler.clone(),
158 reconnect_timeout_ms: self.reconnect_timeout_ms,
159 reconnect_delay_initial_ms: self.reconnect_delay_initial_ms,
160 reconnect_delay_max_ms: self.reconnect_delay_max_ms,
161 reconnect_backoff_factor: self.reconnect_backoff_factor,
162 reconnect_jitter_ms: self.reconnect_jitter_ms,
163 }
164 }
165}
166
/// Commands consumed by the write task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer with a new one (sent after a reconnect).
    Update(MessageWriter),
    /// Send a message over the active writer.
    Send(Message),
}
175
/// Internal connection state: owns the background tasks and everything
/// needed to re-establish the connection.
struct WebSocketClientInner {
    /// Configuration used for the initial connection and for reconnects.
    config: WebSocketConfig,
    /// Task reading incoming messages; `None` when no message handler is set.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task that owns the socket writer and processes [`WriterCommand`]s.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel used to submit commands to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Task sending periodic heartbeats; `None` when heartbeats are disabled.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff state used between failed reconnect attempts.
    backoff: ExponentialBackoff,
}
201
202impl WebSocketClientInner {
203 pub async fn new_with_writer(
205 config: WebSocketConfig,
206 writer: MessageWriter,
207 ) -> Result<Self, Error> {
208 install_cryptographic_provider();
209
210 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
211
212 let read_task = None;
214
215 let backoff = ExponentialBackoff::new(
216 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
217 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
218 config.reconnect_backoff_factor.unwrap_or(1.5),
219 config.reconnect_jitter_ms.unwrap_or(100),
220 true, )
222 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
223
224 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
225 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
226
227 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
228 Some(Self::spawn_heartbeat_task(
229 connection_mode.clone(),
230 heartbeat_interval,
231 config.heartbeat_msg.clone(),
232 writer_tx.clone(),
233 ))
234 } else {
235 None
236 };
237
238 Ok(Self {
239 config: config.clone(),
240 writer_tx,
241 connection_mode,
242 reconnect_timeout: Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000)),
243 heartbeat_task,
244 read_task,
245 write_task,
246 backoff,
247 })
248 }
249
250 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
252 install_cryptographic_provider();
253
254 let WebSocketConfig {
255 url,
256 message_handler,
257 heartbeat,
258 headers,
259 heartbeat_msg,
260 ping_handler,
261 reconnect_timeout_ms,
262 reconnect_delay_initial_ms,
263 reconnect_delay_max_ms,
264 reconnect_backoff_factor,
265 reconnect_jitter_ms,
266 } = &config;
267 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
268
269 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
270
271 let read_task = if message_handler.is_some() {
272 Some(Self::spawn_message_handler_task(
273 connection_mode.clone(),
274 reader,
275 message_handler.as_ref(),
276 ping_handler.as_ref(),
277 ))
278 } else {
279 None
280 };
281
282 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
283 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
284
285 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
287 Self::spawn_heartbeat_task(
288 connection_mode.clone(),
289 *heartbeat_secs,
290 heartbeat_msg.clone(),
291 writer_tx.clone(),
292 )
293 });
294
295 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
296 let backoff = ExponentialBackoff::new(
297 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
298 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
299 reconnect_backoff_factor.unwrap_or(1.5),
300 reconnect_jitter_ms.unwrap_or(100),
301 true, )
303 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
304
305 Ok(Self {
306 config,
307 read_task,
308 write_task,
309 writer_tx,
310 heartbeat_task,
311 connection_mode,
312 reconnect_timeout,
313 backoff,
314 })
315 }
316
317 #[inline]
319 pub async fn connect_with_server(
320 url: &str,
321 headers: Vec<(String, String)>,
322 ) -> Result<(MessageWriter, MessageReader), Error> {
323 let mut request = url.into_client_request()?;
324 let req_headers = request.headers_mut();
325
326 let mut header_names: Vec<HeaderName> = Vec::new();
327 for (key, val) in headers {
328 let header_value = HeaderValue::from_str(&val)?;
329 let header_name: HeaderName = key.parse()?;
330 header_names.push(header_name.clone());
331 req_headers.insert(header_name, header_value);
332 }
333
334 connect_async(request).await.map(|resp| resp.0.split())
335 }
336
337 pub async fn reconnect(&mut self) -> Result<(), Error> {
342 tracing::debug!("Reconnecting");
343
344 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
345 tracing::debug!("Reconnect aborted due to disconnect state");
346 return Ok(());
347 }
348
349 tokio::time::timeout(self.reconnect_timeout, async {
350 let (new_writer, reader) =
352 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
353
354 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
355 tracing::debug!("Reconnect aborted mid-flight (after connect)");
356 return Ok(());
357 }
358
359 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
360 tracing::error!("{e}");
361 }
362
363 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
365
366 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
367 tracing::debug!("Reconnect aborted mid-flight (after delay)");
368 return Ok(());
369 }
370
371 if let Some(ref read_task) = self.read_task.take()
372 && !read_task.is_finished()
373 {
374 read_task.abort();
375 log_task_aborted("read");
376 }
377
378 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
380 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
381 return Ok(());
382 }
383
384 self.connection_mode
386 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
387
388 self.read_task = if self.config.message_handler.is_some() {
389 Some(Self::spawn_message_handler_task(
390 self.connection_mode.clone(),
391 reader,
392 self.config.message_handler.as_ref(),
393 self.config.ping_handler.as_ref(),
394 ))
395 } else {
396 None
397 };
398
399 tracing::debug!("Reconnect succeeded");
400 Ok(())
401 })
402 .await
403 .map_err(|_| {
404 Error::Io(std::io::Error::new(
405 std::io::ErrorKind::TimedOut,
406 format!(
407 "reconnection timed out after {}s",
408 self.reconnect_timeout.as_secs_f64()
409 ),
410 ))
411 })?
412 }
413
414 #[inline]
422 #[must_use]
423 pub fn is_alive(&self) -> bool {
424 match &self.read_task {
425 Some(read_task) => !read_task.is_finished(),
426 None => true, }
428 }
429
430 fn spawn_message_handler_task(
431 connection_state: Arc<AtomicU8>,
432 mut reader: MessageReader,
433 message_handler: Option<&MessageHandler>,
434 ping_handler: Option<&PingHandler>,
435 ) -> tokio::task::JoinHandle<()> {
436 tracing::debug!("Started message handler task 'read'");
437
438 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
439
440 let message_handler = message_handler.cloned();
442 let ping_handler = ping_handler.cloned();
443
444 tokio::task::spawn(async move {
445 loop {
446 if !ConnectionMode::from_atomic(&connection_state).is_active() {
447 break;
448 }
449
450 match tokio::time::timeout(check_interval, reader.next()).await {
451 Ok(Some(Ok(Message::Binary(data)))) => {
452 tracing::trace!("Received message <binary> {} bytes", data.len());
453 if let Some(ref handler) = message_handler {
454 handler(Message::Binary(data));
455 }
456 }
457 Ok(Some(Ok(Message::Text(data)))) => {
458 tracing::trace!("Received message: {data}");
459 if let Some(ref handler) = message_handler {
460 handler(Message::Text(data));
461 }
462 }
463 Ok(Some(Ok(Message::Ping(ping_data)))) => {
464 tracing::trace!("Received ping: {ping_data:?}");
465 if let Some(ref handler) = ping_handler {
466 handler(ping_data.to_vec());
467 }
468 }
469 Ok(Some(Ok(Message::Pong(_)))) => {
470 tracing::trace!("Received pong");
471 }
472 Ok(Some(Ok(Message::Close(_)))) => {
473 tracing::debug!("Received close message - terminating");
474 break;
475 }
476 Ok(Some(Ok(_))) => (),
477 Ok(Some(Err(e))) => {
478 tracing::error!("Received error message - terminating: {e}");
479 break;
480 }
481 Ok(None) => {
482 tracing::debug!("No message received - terminating");
483 break;
484 }
485 Err(_) => {
486 continue;
488 }
489 }
490 }
491 })
492 }
493
494 fn spawn_write_task(
495 connection_state: Arc<AtomicU8>,
496 writer: MessageWriter,
497 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
498 ) -> tokio::task::JoinHandle<()> {
499 log_task_started("write");
500
501 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
503
504 tokio::task::spawn(async move {
505 let mut active_writer = writer;
506
507 loop {
508 match ConnectionMode::from_atomic(&connection_state) {
509 ConnectionMode::Disconnect => {
510 _ = tokio::time::timeout(
513 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
514 active_writer.close(),
515 )
516 .await;
517 break;
518 }
519 ConnectionMode::Closed => break,
520 _ => {}
521 }
522
523 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
524 Ok(Some(msg)) => {
525 let mode = ConnectionMode::from_atomic(&connection_state);
527 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
528 break;
529 }
530
531 match msg {
532 WriterCommand::Update(new_writer) => {
533 tracing::debug!("Received new writer");
534
535 tokio::time::sleep(Duration::from_millis(100)).await;
537
538 _ = tokio::time::timeout(
541 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
542 active_writer.close(),
543 )
544 .await;
545
546 active_writer = new_writer;
547 tracing::debug!("Updated writer");
548 }
549 _ if mode.is_reconnect() => {
550 tracing::warn!("Skipping message while reconnecting, {msg:?}");
551 continue;
552 }
553 WriterCommand::Send(msg) => {
554 if let Err(e) = active_writer.send(msg).await {
555 tracing::error!("Failed to send message: {e}");
556 tracing::warn!("Writer triggering reconnect");
558 connection_state
559 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
560 }
561 }
562 }
563 }
564 Ok(None) => {
565 tracing::debug!("Writer channel closed, terminating writer task");
567 break;
568 }
569 Err(_) => {
570 continue;
572 }
573 }
574 }
575
576 _ = tokio::time::timeout(
579 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
580 active_writer.close(),
581 )
582 .await;
583
584 log_task_stopped("write");
585 })
586 }
587
588 fn spawn_heartbeat_task(
589 connection_state: Arc<AtomicU8>,
590 heartbeat_secs: u64,
591 message: Option<String>,
592 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
593 ) -> tokio::task::JoinHandle<()> {
594 log_task_started("heartbeat");
595
596 tokio::task::spawn(async move {
597 let interval = Duration::from_secs(heartbeat_secs);
598
599 loop {
600 tokio::time::sleep(interval).await;
601
602 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
603 ConnectionMode::Active => {
604 let msg = match &message {
605 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
606 None => WriterCommand::Send(Message::Ping(vec![].into())),
607 };
608
609 match writer_tx.send(msg) {
610 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
611 Err(e) => {
612 tracing::error!("Failed to send heartbeat to writer task: {e}");
613 }
614 }
615 }
616 ConnectionMode::Reconnect => continue,
617 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
618 }
619 }
620
621 log_task_stopped("heartbeat");
622 })
623 }
624}
625
impl Drop for WebSocketClientInner {
    /// Delegates to [`CleanDrop::clean_drop`] so background tasks are aborted
    /// and handler references are released when the value is dropped.
    fn drop(&mut self) {
        self.clean_drop();
    }
}
632
633impl CleanDrop for WebSocketClientInner {
634 fn clean_drop(&mut self) {
635 if let Some(ref read_task) = self.read_task.take()
636 && !read_task.is_finished()
637 {
638 read_task.abort();
639 log_task_aborted("read");
640 }
641
642 if !self.write_task.is_finished() {
643 self.write_task.abort();
644 log_task_aborted("write");
645 }
646
647 if let Some(ref handle) = self.heartbeat_task.take()
648 && !handle.is_finished()
649 {
650 handle.abort();
651 log_task_aborted("heartbeat");
652 }
653
654 self.config.message_handler = None;
656 self.config.ping_handler = None;
657 }
658}
659
/// Public WebSocket client handle.
///
/// The connection itself is driven by a controller task; this handle shares
/// the connection mode and the writer channel with it.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task that supervises reconnection and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Channel used to submit commands to the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited before text and binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
674
675impl Debug for WebSocketClient {
676 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677 f.debug_struct(stringify!(WebSocketClient)).finish()
678 }
679}
680
681impl WebSocketClient {
682 #[allow(clippy::too_many_arguments)]
688 pub async fn connect_stream(
689 config: WebSocketConfig,
690 keyed_quotas: Vec<(String, Quota)>,
691 default_quota: Option<Quota>,
692 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
693 ) -> Result<(MessageReader, Self), Error> {
694 install_cryptographic_provider();
695
696 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
698 let (writer, reader) = ws_stream.split();
699
700 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
702
703 let connection_mode = inner.connection_mode.clone();
704 let writer_tx = inner.writer_tx.clone();
705
706 let controller_task =
707 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
708
709 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
710
711 Ok((
712 reader,
713 Self {
714 controller_task,
715 connection_mode,
716 writer_tx,
717 rate_limiter,
718 },
719 ))
720 }
721
722 pub async fn connect(
731 config: WebSocketConfig,
732 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
733 keyed_quotas: Vec<(String, Quota)>,
734 default_quota: Option<Quota>,
735 ) -> Result<Self, Error> {
736 tracing::debug!("Connecting");
737 let inner = WebSocketClientInner::connect_url(config).await?;
738 let connection_mode = inner.connection_mode.clone();
739 let writer_tx = inner.writer_tx.clone();
740
741 let controller_task =
742 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
743
744 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
745
746 Ok(Self {
747 controller_task,
748 connection_mode,
749 writer_tx,
750 rate_limiter,
751 })
752 }
753
    /// Returns the current connection mode, read from the shared atomic.
    #[must_use]
    pub fn connection_mode(&self) -> ConnectionMode {
        ConnectionMode::from_atomic(&self.connection_mode)
    }
759
    /// Returns `true` if the current connection mode reports active.
    #[inline]
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.connection_mode().is_active()
    }
769
    /// Returns `true` once the controller task has finished.
    ///
    /// Unlike the other status checks, this looks at the task handle rather
    /// than the connection mode.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.controller_task.is_finished()
    }
775
    /// Returns `true` if the current connection mode reports reconnect.
    #[inline]
    #[must_use]
    pub fn is_reconnecting(&self) -> bool {
        self.connection_mode().is_reconnect()
    }
785
    /// Returns `true` if the current connection mode reports disconnect,
    /// i.e. a disconnect has been requested.
    #[inline]
    #[must_use]
    pub fn is_disconnecting(&self) -> bool {
        self.connection_mode().is_disconnect()
    }
794
    /// Returns `true` if the current connection mode reports closed.
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.connection_mode().is_closed()
    }
805
806 pub async fn disconnect(&self) {
811 tracing::debug!("Disconnecting");
812 self.connection_mode
813 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
814
815 match tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
816 while !self.is_disconnected() {
817 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
818 }
819
820 if !self.controller_task.is_finished() {
821 self.controller_task.abort();
822 log_task_aborted("controller");
823 }
824 })
825 .await
826 {
827 Ok(()) => {
828 tracing::debug!("Controller task finished");
829 }
830 Err(_) => {
831 tracing::error!("Timeout waiting for controller task to finish");
832 }
833 }
834 }
835
836 #[allow(unused_variables)]
842 pub async fn send_text(
843 &self,
844 data: String,
845 keys: Option<Vec<String>>,
846 ) -> std::result::Result<(), SendError> {
847 self.rate_limiter.await_keys_ready(keys).await;
848
849 if !self.is_active() {
850 return Err(SendError::Closed);
851 }
852
853 tracing::trace!("Sending text: {data:?}");
854
855 let msg = Message::Text(data.into());
856 self.writer_tx
857 .send(WriterCommand::Send(msg))
858 .map_err(|e| SendError::BrokenPipe(e.to_string()))
859 }
860
861 #[allow(unused_variables)]
867 pub async fn send_bytes(
868 &self,
869 data: Vec<u8>,
870 keys: Option<Vec<String>>,
871 ) -> std::result::Result<(), SendError> {
872 self.rate_limiter.await_keys_ready(keys).await;
873
874 if !self.is_active() {
875 return Err(SendError::Closed);
876 }
877
878 tracing::trace!("Sending bytes: {data:?}");
879
880 let msg = Message::Binary(data.into());
881 self.writer_tx
882 .send(WriterCommand::Send(msg))
883 .map_err(|e| SendError::BrokenPipe(e.to_string()))
884 }
885
886 pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
892 if !self.is_active() {
893 return Err(SendError::Closed);
894 }
895
896 let msg = Message::Close(None);
897 self.writer_tx
898 .send(WriterCommand::Send(msg))
899 .map_err(|e| SendError::BrokenPipe(e.to_string()))
900 }
901
    /// Spawns the controller task that supervises `inner`.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` the task inspects the
    /// connection mode:
    /// - On disconnect it aborts the read and heartbeat tasks after a short
    ///   delay and exits the loop.
    /// - On reconnect, or when the mode is active but the read task has
    ///   finished, it calls `inner.reconnect()`. On success it resets the
    ///   backoff and, if the mode is active, sends a `RECONNECTED` text
    ///   message to the message handler and calls `post_reconnection`. On
    ///   failure it sleeps for the next backoff duration.
    ///
    /// After the loop the mode is set to closed. `inner` is owned by the
    /// task, so it is dropped when the task ends or is aborted.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Short pause before the remaining tasks are aborted.
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");
                    break; // Leave the control loop; the mode is set to closed below
                }

                // Reconnect on explicit request, or when the read task died
                // while the connection was still considered active.
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // `reconnect` also returns `Ok` when it was aborted
                            // by a disconnect, so re-check before notifying.
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.config.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    tracing::debug!("Sent reconnected message to handler");
                                }

                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    tracing::debug!("Called `post_reconnection` handler");
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // `inner.connection_mode` is the same shared atomic the public
            // client handle reads.
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
993}
994
995impl Drop for WebSocketClient {
997 fn drop(&mut self) {
998 if !self.controller_task.is_finished() {
999 self.controller_task.abort();
1000 log_task_aborted("controller");
1001 }
1002 }
1003}
1004
#[cfg(test)]
#[cfg(target_os = "linux")] // Network-backed tests are only compiled on Linux
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server bound to an ephemeral port.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that a given header is present with the
    /// expected value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // One lookup covers both the presence check and the value check.
            let value = request.headers().get(&self.key);
            assert_eq!(value, Some(&self.value));

            Ok(response)
        }
    }

    impl TestServer {
        /// Starts the server: each accepted connection echoes text and binary
        /// messages, and closes when it receives the text `close-now` or a
        /// close frame.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting so reconnecting clients are served too.
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text and binary messages back.
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Other frame kinds are ignored.
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Baseline config pointing at the local test server, carrying the header
    /// that `TestCallback` asserts on.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            message_handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        WebSocketClient::connect(test_config(port), None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;

        // Enable a 1s heartbeat so the heartbeat task actually runs.
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = WebSocketClient::connect(config, None, vec![], None)
            .await
            .expect("Failed to connect");

        // Long enough for a few heartbeat ticks.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // A failed heartbeat send would have switched the mode to reconnect.
        assert!(client.is_active());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // NOTE(review): assumes nothing is listening on port 9997
        let config = WebSocketConfig {
            url: "ws://127.0.0.1:9997".into(),
            headers: vec![],
            message_handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        let res = WebSocketClient::connect(config, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to close the connection from its side.
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // The controller must still be running after the server-side close.
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // NOTE(review): sends pass `keys = None`, so the "default" keyed quota
        // may never be consulted; confirm against `RateLimiter::await_keys_ready`.
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1266
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// A client whose connection was dropped by the server must still be able
    /// to disconnect cleanly.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // Bind an ephemeral port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Server accepts one connection, completes the handshake, then drops
        // the socket immediately.
        let server = task::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            drop(ws);
            // Keep the task (and the listener it owns) alive for a while.
            sleep(Duration::from_secs(1)).await;
        });

        // A message handler is set so the client spawns a read task; the
        // receiver is kept alive in `_rx` but never read.
        let (handler, _rx) = channel_message_handler();

        // Short reconnect timings keep the test fast.
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            message_handler: Some(handler),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        let client = WebSocketClient::connect(config, None, vec![], None)
            .await
            .unwrap();

        // Give the client time to notice the dropped connection, then
        // disconnect.
        sleep(Duration::from_millis(100)).await;
        client.disconnect().await;
        assert!(client.is_disconnected());
        server.abort();
    }
}