1use std::{
32 sync::{
33 Arc,
34 atomic::{AtomicU8, Ordering},
35 },
36 time::Duration,
37};
38
39use futures_util::{
40 SinkExt, StreamExt,
41 stream::{SplitSink, SplitStream},
42};
43use http::HeaderName;
44use nautilus_cryptography::providers::install_cryptographic_provider;
45use pyo3::{prelude::*, types::PyBytes};
46use tokio::{
47 net::TcpStream,
48 sync::mpsc::{self, Receiver, Sender},
49};
50use tokio_tungstenite::{
51 MaybeTlsStream, WebSocketStream, connect_async,
52 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
53};
54
55use crate::{
56 backoff::ExponentialBackoff,
57 mode::ConnectionMode,
58 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
59};
/// Write half of a split WebSocket stream over a (possibly TLS-wrapped) TCP connection.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split WebSocket stream over a (possibly TLS-wrapped) TCP connection.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
62
/// Destination for messages read from the WebSocket connection.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// Python callable invoked with the payload of each text/binary message as `bytes`.
    /// `None` means no read task is spawned at all.
    Python(Option<Arc<PyObject>>),
    /// Channel that receives every incoming `Message` as-is.
    Rust(Sender<Message>),
}
68
69impl Consumer {
70 #[must_use]
71 pub fn rust_consumer() -> (Self, Receiver<Message>) {
72 let (tx, rx) = mpsc::channel(100);
73 (Self::Rust(tx), rx)
74 }
75}
76
/// Configuration for establishing and maintaining a WebSocket connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// Extra HTTP headers (name, value) added to the handshake request.
    pub headers: Vec<(String, String)>,
    /// Where incoming messages are delivered.
    pub handler: Consumer,
    /// Heartbeat interval in seconds; `None` disables the heartbeat task.
    pub heartbeat: Option<u64>,
    /// Text sent as the heartbeat; when `None` an empty Ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Python callable invoked with the payload of each received Ping (Python consumer only).
    pub ping_handler: Option<Arc<PyObject>>,
    /// Timeout for a single reconnect attempt in milliseconds (default 10 000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay in milliseconds (default 2 000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay in milliseconds (default 30 000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Multiplier applied to the backoff delay after each failed attempt (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to the backoff, in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
106
/// Commands processed by the write task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer; the previous writer is closed first.
    Update(MessageWriter),
    /// Send a message through the active writer.
    Send(Message),
}
115
/// Owns the live connection: its configuration, background tasks and reconnect policy.
struct WebSocketClientInner {
    /// Configuration used to establish and re-establish the connection.
    config: WebSocketConfig,
    /// Task reading incoming messages; `None` when the handler is `Consumer::Python(None)`.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task draining `WriterCommand`s and writing to the socket.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the write task's command channel.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Periodic heartbeat task, if a heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state, stored as an encoded `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
}
141
142impl WebSocketClientInner {
143 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
145 install_cryptographic_provider();
146
147 #[allow(unused_variables)]
148 let WebSocketConfig {
149 url,
150 handler,
151 heartbeat,
152 headers,
153 heartbeat_msg,
154 ping_handler,
155 reconnect_timeout_ms,
156 reconnect_delay_initial_ms,
157 reconnect_delay_max_ms,
158 reconnect_backoff_factor,
159 reconnect_jitter_ms,
160 } = &config;
161 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
162
163 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
164
165 let read_task = match &handler {
166 Consumer::Python(handler) => handler.as_ref().map(|handler| {
167 Self::spawn_python_callback_task(
168 connection_mode.clone(),
169 reader,
170 handler.clone(),
171 ping_handler.clone(),
172 )
173 }),
174 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
175 connection_mode.clone(),
176 reader,
177 sender.clone(),
178 )),
179 };
180
181 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
182 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
183
184 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
186 Self::spawn_heartbeat_task(
187 connection_mode.clone(),
188 *heartbeat_secs,
189 heartbeat_msg.clone(),
190 writer_tx.clone(),
191 )
192 });
193
194 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
195 let backoff = ExponentialBackoff::new(
196 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
197 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
198 reconnect_backoff_factor.unwrap_or(1.5),
199 reconnect_jitter_ms.unwrap_or(100),
200 true, );
202
203 Ok(Self {
204 config,
205 read_task,
206 write_task,
207 writer_tx,
208 heartbeat_task,
209 connection_mode,
210 reconnect_timeout,
211 backoff,
212 })
213 }
214
215 #[inline]
217 pub async fn connect_with_server(
218 url: &str,
219 headers: Vec<(String, String)>,
220 ) -> Result<(MessageWriter, MessageReader), Error> {
221 let mut request = url.into_client_request()?;
222 let req_headers = request.headers_mut();
223
224 let mut header_names: Vec<HeaderName> = Vec::new();
225 for (key, val) in headers {
226 let header_value = HeaderValue::from_str(&val)?;
227 let header_name: HeaderName = key.parse()?;
228 header_names.push(header_name.clone());
229 req_headers.insert(header_name, header_value);
230 }
231
232 connect_async(request).await.map(|resp| resp.0.split())
233 }
234
235 pub async fn reconnect(&mut self) -> Result<(), Error> {
240 tracing::debug!("Reconnecting");
241
242 tokio::time::timeout(self.reconnect_timeout, async {
243 let (new_writer, reader) =
244 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
245
246 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
247 tracing::error!("{e}");
248 }
249
250 tokio::time::sleep(Duration::from_millis(100)).await;
252
253 if let Some(ref read_task) = self.read_task.take() {
254 if !read_task.is_finished() {
255 read_task.abort();
256 tracing::debug!("Aborted task 'read'");
257 }
258 }
259
260 self.connection_mode
261 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
262
263 self.read_task = match &self.config.handler {
264 Consumer::Python(handler) => handler.as_ref().map(|handler| {
265 Self::spawn_python_callback_task(
266 self.connection_mode.clone(),
267 reader,
268 handler.clone(),
269 self.config.ping_handler.clone(),
270 )
271 }),
272 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
273 self.connection_mode.clone(),
274 reader,
275 sender.clone(),
276 )),
277 };
278
279 tracing::debug!("Reconnect succeeded");
280 Ok(())
281 })
282 .await
283 .map_err(|_| {
284 Error::Io(std::io::Error::new(
285 std::io::ErrorKind::TimedOut,
286 format!(
287 "reconnection timed out after {}s",
288 self.reconnect_timeout.as_secs_f64()
289 ),
290 ))
291 })?
292 }
293
294 #[inline]
302 #[must_use]
303 pub fn is_alive(&self) -> bool {
304 match &self.read_task {
305 Some(read_task) => !read_task.is_finished(),
306 None => true, }
308 }
309
310 fn spawn_rust_streaming_task(
311 connection_state: Arc<AtomicU8>,
312 mut reader: MessageReader,
313 sender: Sender<Message>,
314 ) -> tokio::task::JoinHandle<()> {
315 tracing::debug!("Started streaming task 'read'");
316
317 let check_interval = Duration::from_millis(10);
318
319 tokio::task::spawn(async move {
320 loop {
321 if !ConnectionMode::from_atomic(&connection_state).is_active() {
322 break;
323 }
324
325 match tokio::time::timeout(check_interval, reader.next()).await {
326 Ok(Some(Ok(message))) => {
327 if let Err(e) = sender.send(message).await {
328 tracing::error!("Failed to send message: {e}");
329 }
330 }
331 Ok(Some(Err(e))) => {
332 tracing::error!("Received error message - terminating: {e}");
333 break;
334 }
335 Ok(None) => {
336 tracing::debug!("No message received - terminating");
337 break;
338 }
339 Err(_) => {
340 continue;
342 }
343 }
344 }
345 })
346 }
347
348 fn spawn_python_callback_task(
349 connection_state: Arc<AtomicU8>,
350 mut reader: MessageReader,
351 handler: Arc<PyObject>,
352 ping_handler: Option<Arc<PyObject>>,
353 ) -> tokio::task::JoinHandle<()> {
354 tracing::debug!("Started task 'read'");
355
356 let check_interval = Duration::from_millis(10);
358
359 tokio::task::spawn(async move {
360 loop {
361 if !ConnectionMode::from_atomic(&connection_state).is_active() {
362 break;
363 }
364
365 match tokio::time::timeout(check_interval, reader.next()).await {
366 Ok(Some(Ok(Message::Binary(data)))) => {
367 tracing::trace!("Received message <binary> {} bytes", data.len());
368 if let Err(e) =
369 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
370 {
371 tracing::error!("Error calling handler: {e}");
372 break;
373 }
374 continue;
375 }
376 Ok(Some(Ok(Message::Text(data)))) => {
377 tracing::trace!("Received message: {data}");
378 if let Err(e) = Python::with_gil(|py| {
379 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
380 }) {
381 tracing::error!("Error calling handler: {e}");
382 break;
383 }
384 continue;
385 }
386 Ok(Some(Ok(Message::Ping(ping)))) => {
387 tracing::trace!("Received ping: {ping:?}");
388 if let Some(ref handler) = ping_handler {
389 if let Err(e) =
390 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
391 {
392 tracing::error!("Error calling handler: {e}");
393 break;
394 }
395 }
396 continue;
397 }
398 Ok(Some(Ok(Message::Pong(_)))) => {
399 tracing::trace!("Received pong");
400 }
401 Ok(Some(Ok(Message::Close(_)))) => {
402 tracing::debug!("Received close message - terminating");
403 break;
404 }
405 Ok(Some(Ok(_))) => (),
406 Ok(Some(Err(e))) => {
407 tracing::error!("Received error message - terminating: {e}");
408 break;
409 }
410 Ok(None) => {
413 tracing::debug!("No message received - terminating");
414 break;
415 }
416 Err(_) => {
417 continue;
419 }
420 }
421 }
422 })
423 }
424
425 fn spawn_write_task(
426 connection_state: Arc<AtomicU8>,
427 writer: MessageWriter,
428 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
429 ) -> tokio::task::JoinHandle<()> {
430 tracing::debug!("Started task 'write'");
431
432 let check_interval = Duration::from_millis(10);
434
435 tokio::task::spawn(async move {
436 let mut active_writer = writer;
437
438 loop {
439 match ConnectionMode::from_atomic(&connection_state) {
440 ConnectionMode::Disconnect => {
441 _ = active_writer.close().await;
444 break;
445 }
446 ConnectionMode::Closed => break,
447 _ => {}
448 }
449
450 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
451 Ok(Some(msg)) => {
452 let mode = ConnectionMode::from_atomic(&connection_state);
454 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
455 break;
456 }
457
458 match msg {
459 WriterCommand::Update(new_writer) => {
460 tracing::debug!("Received new writer");
461
462 tokio::time::sleep(Duration::from_millis(100)).await;
464
465 _ = active_writer.close().await;
468
469 active_writer = new_writer;
470 tracing::debug!("Updated writer");
471 }
472 _ if mode.is_reconnect() => {
473 tracing::warn!("Skipping message while reconnecting, {msg:?}");
474 continue;
475 }
476 WriterCommand::Send(msg) => {
477 if let Err(e) = active_writer.send(msg).await {
478 tracing::error!("Failed to send message: {e}");
479 tracing::warn!("Writer triggering reconnect");
481 connection_state
482 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
483 }
484 }
485 }
486 }
487 Ok(None) => {
488 tracing::debug!("Writer channel closed, terminating writer task");
490 break;
491 }
492 Err(_) => {
493 continue;
495 }
496 }
497 }
498
499 _ = active_writer.close().await;
502
503 tracing::debug!("Completed task 'write'");
504 })
505 }
506
507 fn spawn_heartbeat_task(
508 connection_state: Arc<AtomicU8>,
509 heartbeat_secs: u64,
510 message: Option<String>,
511 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
512 ) -> tokio::task::JoinHandle<()> {
513 tracing::debug!("Started task 'heartbeat'");
514
515 tokio::task::spawn(async move {
516 let interval = Duration::from_secs(heartbeat_secs);
517
518 loop {
519 tokio::time::sleep(interval).await;
520
521 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
522 ConnectionMode::Active => {
523 let msg = match &message {
524 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
525 None => WriterCommand::Send(Message::Ping(vec![].into())),
526 };
527
528 match writer_tx.send(msg) {
529 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
530 Err(e) => {
531 tracing::error!("Failed to send heartbeat to writer task: {e}");
532 }
533 }
534 }
535 ConnectionMode::Reconnect => continue,
536 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
537 }
538 }
539
540 tracing::debug!("Completed task 'heartbeat'");
541 })
542 }
543}
544
545impl Drop for WebSocketClientInner {
546 fn drop(&mut self) {
547 if let Some(ref read_task) = self.read_task.take() {
548 if !read_task.is_finished() {
549 read_task.abort();
550 tracing::debug!("Aborted task 'read'");
551 }
552 }
553
554 if !self.write_task.is_finished() {
555 self.write_task.abort();
556 tracing::debug!("Aborted task 'write'");
557 }
558
559 if let Some(ref handle) = self.heartbeat_task.take() {
560 if !handle.is_finished() {
561 handle.abort();
562 tracing::debug!("Aborted task 'heartbeat'");
563 }
564 }
565 }
566}
567
/// Handle to a managed WebSocket connection.
///
/// The connection itself is owned by a background controller task that reconnects on
/// failure and tears everything down on disconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Background task supervising reconnection and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection state, stored as an encoded `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender into the write task's command channel.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited by `send_text` / `send_bytes` before sending.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
582
583impl WebSocketClient {
584 #[allow(clippy::too_many_arguments)]
590 pub async fn connect_stream(
591 config: WebSocketConfig,
592 keyed_quotas: Vec<(String, Quota)>,
593 default_quota: Option<Quota>,
594 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
595 ) -> Result<(MessageReader, Self), Error> {
596 install_cryptographic_provider();
597 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
598 let (writer, reader) = ws_stream.split();
599 let inner = WebSocketClientInner::connect_url(config).await?;
600
601 let connection_mode = inner.connection_mode.clone();
602 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
603 let writer_tx = inner.writer_tx.clone();
604 if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
605 tracing::error!("{e}");
606 }
607
608 let controller_task = Self::spawn_controller_task(
609 inner,
610 connection_mode.clone(),
611 post_reconnect,
612 None, None, );
615
616 Ok((
617 reader,
618 Self {
619 controller_task,
620 connection_mode,
621 writer_tx,
622 rate_limiter,
623 },
624 ))
625 }
626
627 pub async fn connect(
636 config: WebSocketConfig,
637 post_connection: Option<PyObject>,
638 post_reconnection: Option<PyObject>,
639 post_disconnection: Option<PyObject>,
640 keyed_quotas: Vec<(String, Quota)>,
641 default_quota: Option<Quota>,
642 ) -> Result<Self, Error> {
643 tracing::debug!("Connecting");
644 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
645 let connection_mode = inner.connection_mode.clone();
646 let writer_tx = inner.writer_tx.clone();
647
648 let controller_task = Self::spawn_controller_task(
649 inner,
650 connection_mode.clone(),
651 None, post_reconnection, post_disconnection, );
655 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
656
657 if let Some(handler) = post_connection {
658 Python::with_gil(|py| match handler.call0(py) {
659 Ok(_) => tracing::debug!("Called `post_connection` handler"),
660 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
661 });
662 }
663
664 Ok(Self {
665 controller_task,
666 connection_mode,
667 writer_tx,
668 rate_limiter,
669 })
670 }
671
672 #[must_use]
674 pub fn connection_mode(&self) -> ConnectionMode {
675 ConnectionMode::from_atomic(&self.connection_mode)
676 }
677
678 #[inline]
683 #[must_use]
684 pub fn is_active(&self) -> bool {
685 self.connection_mode().is_active()
686 }
687
688 #[must_use]
690 pub fn is_disconnected(&self) -> bool {
691 self.controller_task.is_finished()
692 }
693
694 #[inline]
699 #[must_use]
700 pub fn is_reconnecting(&self) -> bool {
701 self.connection_mode().is_reconnect()
702 }
703
704 #[inline]
708 #[must_use]
709 pub fn is_disconnecting(&self) -> bool {
710 self.connection_mode().is_disconnect()
711 }
712
713 #[inline]
719 #[must_use]
720 pub fn is_closed(&self) -> bool {
721 self.connection_mode().is_closed()
722 }
723
724 pub async fn disconnect(&self) {
729 tracing::debug!("Disconnecting");
730 self.connection_mode
731 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
732
733 match tokio::time::timeout(Duration::from_secs(5), async {
734 while !self.is_disconnected() {
735 tokio::time::sleep(Duration::from_millis(10)).await;
736 }
737
738 if !self.controller_task.is_finished() {
739 self.controller_task.abort();
740 tracing::debug!("Aborted task 'controller'");
741 }
742 })
743 .await
744 {
745 Ok(()) => {
746 tracing::debug!("Controller task finished");
747 }
748 Err(_) => {
749 tracing::error!("Timeout waiting for controller task to finish");
750 }
751 }
752 }
753
754 pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) {
756 self.rate_limiter.await_keys_ready(keys).await;
757
758 if !self.is_active() {
759 tracing::error!("Cannot send data - connection not active");
760 return;
761 }
762
763 tracing::trace!("Sending text: {data:?}");
764
765 let msg = Message::Text(data.into());
766 if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
767 tracing::error!("Error sending message: {e}");
768 }
769 }
770
771 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) {
773 self.rate_limiter.await_keys_ready(keys).await;
774
775 if !self.is_active() {
776 tracing::error!("Cannot send data - connection not active");
777 return;
778 }
779
780 tracing::trace!("Sending bytes: {data:?}");
781
782 let msg = Message::Binary(data.into());
783 if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
784 tracing::error!("Error sending message: {e}");
785 }
786 }
787
788 pub async fn send_close_message(&self) {
790 if !self.is_active() {
791 tracing::error!("Cannot send close message - connection not active");
792 return;
793 }
794
795 let msg = Message::Close(None);
796 if let Err(e) = self.writer_tx.send(WriterCommand::Send(msg)) {
797 tracing::error!("Error sending close message: {e}");
798 }
799 }
800
801 fn spawn_controller_task(
802 mut inner: WebSocketClientInner,
803 connection_mode: Arc<AtomicU8>,
804 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
805 py_post_reconnection: Option<PyObject>, py_post_disconnection: Option<PyObject>, ) -> tokio::task::JoinHandle<()> {
808 tokio::task::spawn(async move {
809 tracing::debug!("Started task 'controller'");
810
811 let check_interval = Duration::from_millis(10);
812
813 loop {
814 tokio::time::sleep(check_interval).await;
815 let mode = ConnectionMode::from_atomic(&connection_mode);
816
817 if mode.is_disconnect() {
818 tracing::debug!("Disconnecting");
819
820 let timeout = Duration::from_secs(5);
821 if tokio::time::timeout(timeout, async {
822 tokio::time::sleep(Duration::from_millis(100)).await;
824
825 if let Some(task) = &inner.read_task {
826 if !task.is_finished() {
827 task.abort();
828 tracing::debug!("Aborted task 'read'");
829 }
830 }
831
832 if let Some(task) = &inner.heartbeat_task {
833 if !task.is_finished() {
834 task.abort();
835 tracing::debug!("Aborted task 'heartbeat'");
836 }
837 }
838 })
839 .await
840 .is_err()
841 {
842 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
843 }
844
845 tracing::debug!("Closed");
846
847 if let Some(ref handler) = py_post_disconnection {
848 Python::with_gil(|py| match handler.call0(py) {
849 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
850 Err(e) => {
851 tracing::error!("Error calling `post_disconnection` handler: {e}");
852 }
853 });
854 }
855 break; }
857
858 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
859 match inner.reconnect().await {
860 Ok(()) => {
861 tracing::debug!("Reconnected successfully");
862 inner.backoff.reset();
863
864 if let Some(ref callback) = post_reconnection {
865 callback();
866 }
867
868 if let Some(ref handler) = py_post_reconnection {
870 Python::with_gil(|py| match handler.call0(py) {
871 Ok(_) => {
872 tracing::debug!("Called `post_reconnection` handler");
873 }
874 Err(e) => tracing::error!(
875 "Error calling `post_reconnection` handler: {e}"
876 ),
877 });
878 }
879 }
880 Err(e) => {
881 let duration = inner.backoff.next_duration();
882 tracing::warn!("Reconnect attempt failed: {e}");
883 if !duration.is_zero() {
884 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
885 }
886 tokio::time::sleep(duration).await;
887 }
888 }
889 }
890 }
891 inner
892 .connection_mode
893 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
894
895 tracing::debug!("Completed task 'controller'");
896 })
897 }
898}
899
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{Consumer, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server; the accept loop is aborted when the server is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` equal to `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected '{}' header",
                self.key
            );

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral port and serves connections that require header `test: test`.
        ///
        /// Text/binary messages are echoed. The text `close-now` or a Close frame closes
        /// the socket from the server side.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a config for the local test server, including the header it requires.
    fn test_config(port: u16, heartbeat: Option<u64>) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        WebSocketClient::connect(test_config(port, None), None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        // Actually enable heartbeats (1 s Ping frames) so they are exercised during the wait.
        let client = WebSocketClient::connect(
            test_config(server.port, Some(1)),
            None,
            None,
            None,
            vec![],
            None,
        )
        .await
        .expect("Failed to connect");

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it right away, so that (barring a race)
        // nothing is listening there. This avoids assuming a hard-coded port is free.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await;

        // The server closes the socket on receiving this message.
        client.send_text("close-now".into(), None).await;

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // The controller should still be running (reconnecting rather than terminated).
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port, None),
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // NOTE(review): these sends pass `None` keys, so the keyed "default" quota may not
        // be applied. Confirm against `RateLimiter::await_keys_ready`.
        client.send_text("test1".into(), None).await;
        client.send_text("test2".into(), None).await;

        client.send_text("test3".into(), None).await;

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await;
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}