1use std::{
33 sync::{
34 atomic::{AtomicU8, Ordering},
35 Arc,
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 stream::{SplitSink, SplitStream},
42 SinkExt, StreamExt,
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use pyo3::{prelude::*, types::PyBytes};
47use tokio::{net::TcpStream, sync::Mutex};
48use tokio_tungstenite::{
49 connect_async,
50 tungstenite::{client::IntoClientRequest, http::HeaderValue, Error, Message},
51 MaybeTlsStream, WebSocketStream,
52};
53
54use crate::{
55 backoff::ExponentialBackoff,
56 mode::ConnectionMode,
57 ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter},
58};
59type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
60type SharedMessageWriter =
61 Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
62pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
63
/// Configuration for a WebSocket connection.
///
/// All `reconnect_*` options are optional; when `None`, the defaults applied by
/// `WebSocketClientInner::connect_url` are used (noted per field below).
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to (e.g. `ws://127.0.0.1:8080`).
    pub url: String,
    /// Extra HTTP headers as (name, value) pairs inserted into the opening handshake request.
    pub headers: Vec<(String, String)>,
    /// Optional Python callable invoked with every received text or binary message as `bytes`.
    /// When `None`, no read task is spawned for the connection.
    pub handler: Option<Arc<PyObject>>,
    /// Optional heartbeat interval in seconds; when `None`, no heartbeat task is spawned.
    pub heartbeat: Option<u64>,
    /// Optional text message sent as the heartbeat; when `None`, an empty ping frame is sent.
    pub heartbeat_msg: Option<String>,
    /// Optional Python callable invoked with the payload of every received ping as `bytes`.
    pub ping_handler: Option<Arc<PyObject>>,
    /// Timeout for a single reconnection attempt in milliseconds (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial delay passed to the reconnect `ExponentialBackoff`, in milliseconds (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum delay passed to the reconnect `ExponentialBackoff`, in milliseconds (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Factor passed to the reconnect `ExponentialBackoff` (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to the reconnect `ExponentialBackoff`, in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
93
/// Connection state driven by the controller task of a `WebSocketClient`.
///
/// Holds the live connection's background tasks and shared writer, together with the
/// configuration and backoff policy needed to re-establish the connection.
struct WebSocketClientInner {
    /// The configuration the connection was created from; reused on every reconnect.
    config: WebSocketConfig,
    /// Task forwarding incoming messages to `config.handler`; `None` when no handler is
    /// configured or after the task has been taken for shutdown.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task sending periodic heartbeats; `None` when `config.heartbeat` is unset or after
    /// the task has been taken for shutdown.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Write half of the current connection; its contents are replaced on reconnect.
    writer: SharedMessageWriter,
    /// Shared `ConnectionMode`, stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Backoff policy applied between failed reconnection attempts.
    backoff: ExponentialBackoff,
}
118
119impl WebSocketClientInner {
120 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
122 install_cryptographic_provider();
123
124 #[allow(unused_variables)]
125 let WebSocketConfig {
126 url,
127 handler,
128 heartbeat,
129 headers,
130 heartbeat_msg,
131 ping_handler,
132 reconnect_timeout_ms,
133 reconnect_delay_initial_ms,
134 reconnect_delay_max_ms,
135 reconnect_backoff_factor,
136 reconnect_jitter_ms,
137 } = &config;
138 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
139 let writer = Arc::new(Mutex::new(writer));
140
141 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
142
143 let read_task = handler
145 .as_ref()
146 .map(|handler| Self::spawn_read_task(reader, handler.clone(), ping_handler.clone()));
147
148 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
150 Self::spawn_heartbeat_task(
151 connection_mode.clone(),
152 *heartbeat_secs,
153 heartbeat_msg.clone(),
154 writer.clone(),
155 )
156 });
157
158 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
159 let backoff = ExponentialBackoff::new(
160 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
161 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
162 reconnect_backoff_factor.unwrap_or(1.5),
163 reconnect_jitter_ms.unwrap_or(100),
164 true, );
166
167 Ok(Self {
168 config,
169 read_task,
170 heartbeat_task,
171 writer,
172 connection_mode,
173 reconnect_timeout,
174 backoff,
175 })
176 }
177
178 #[inline]
180 pub async fn connect_with_server(
181 url: &str,
182 headers: Vec<(String, String)>,
183 ) -> Result<(MessageWriter, MessageReader), Error> {
184 let mut request = url.into_client_request()?;
185 let req_headers = request.headers_mut();
186
187 let mut header_names: Vec<HeaderName> = Vec::new();
188 for (key, val) in headers {
189 let header_value = HeaderValue::from_str(&val)?;
190 let header_name: HeaderName = key.parse()?;
191 header_names.push(header_name.clone());
192 req_headers.insert(header_name, header_value);
193 }
194
195 connect_async(request).await.map(|resp| resp.0.split())
196 }
197
198 pub async fn reconnect(&mut self) -> Result<(), Error> {
203 tracing::debug!("Reconnecting");
204
205 tokio::time::timeout(self.reconnect_timeout, async {
206 shutdown(
207 self.read_task.take(),
208 self.heartbeat_task.take(),
209 self.writer.clone(),
210 )
211 .await;
212
213 let (new_writer, reader) =
214 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
215
216 {
217 let mut guard = self.writer.lock().await;
218 *guard = new_writer;
219 drop(guard);
220 }
221
222 if let Some(ref handler) = self.config.handler {
224 self.read_task = Some(Self::spawn_read_task(
225 reader,
226 handler.clone(),
227 self.config.ping_handler.clone(),
228 ));
229 }
230
231 self.heartbeat_task = self.config.heartbeat.as_ref().map(|heartbeat_secs| {
233 Self::spawn_heartbeat_task(
234 self.connection_mode.clone(),
235 *heartbeat_secs,
236 self.config.heartbeat_msg.clone(),
237 self.writer.clone(),
238 )
239 });
240
241 self.connection_mode
242 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
243
244 tracing::debug!("Reconnect succeeded");
245 Ok(())
246 })
247 .await
248 .map_err(|_| {
249 Error::Io(std::io::Error::new(
250 std::io::ErrorKind::TimedOut,
251 format!(
252 "reconnection timed out after {}s",
253 self.reconnect_timeout.as_secs_f64()
254 ),
255 ))
256 })?
257 }
258
259 #[inline]
267 #[must_use]
268 pub fn is_alive(&self) -> bool {
269 match &self.read_task {
270 Some(read_task) => !read_task.is_finished(),
271 None => true, }
273 }
274
275 fn spawn_read_task(
276 mut reader: MessageReader,
277 handler: Arc<PyObject>,
278 ping_handler: Option<Arc<PyObject>>,
279 ) -> tokio::task::JoinHandle<()> {
280 tracing::debug!("Started task 'read'");
281
282 tokio::task::spawn(async move {
283 loop {
284 match reader.next().await {
285 Some(Ok(Message::Binary(data))) => {
286 tracing::trace!("Received message <binary> {} bytes", data.len());
287 if let Err(e) =
288 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
289 {
290 tracing::error!("Error calling handler: {e}");
291 break;
292 }
293 continue;
294 }
295 Some(Ok(Message::Text(data))) => {
296 tracing::trace!("Received message: {data}");
297 if let Err(e) = Python::with_gil(|py| {
298 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
299 }) {
300 tracing::error!("Error calling handler: {e}");
301 break;
302 }
303 continue;
304 }
305 Some(Ok(Message::Ping(ping))) => {
306 tracing::trace!("Received ping: {ping:?}",);
307 if let Some(ref handler) = ping_handler {
308 if let Err(e) =
309 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
310 {
311 tracing::error!("Error calling handler: {e}");
312 break;
313 }
314 }
315 continue;
316 }
317 Some(Ok(Message::Pong(_))) => {
318 tracing::trace!("Received pong");
319 }
320 Some(Ok(Message::Close(_))) => {
321 tracing::debug!("Received close message - terminating");
322 break;
323 }
324 Some(Ok(_)) => (),
325 Some(Err(e)) => {
326 tracing::error!("Received error message - terminating: {e}");
327 break;
328 }
329 None => {
332 tracing::debug!("No message received - terminating");
333 break;
334 }
335 }
336 }
337 })
338 }
339
340 fn spawn_heartbeat_task(
341 connection_state: Arc<AtomicU8>,
342 heartbeat_secs: u64,
343 message: Option<String>,
344 writer: SharedMessageWriter,
345 ) -> tokio::task::JoinHandle<()> {
346 tracing::debug!("Started task 'heartbeat'");
347
348 tokio::task::spawn(async move {
349 let interval = Duration::from_secs(heartbeat_secs);
350 loop {
351 tokio::time::sleep(interval).await;
352
353 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
354 ConnectionMode::Active => {
355 let mut guard = writer.lock().await;
356 let guard_send_response = match message.clone() {
357 Some(msg) => guard.send(Message::Text(msg.into())).await,
358 None => guard.send(Message::Ping(vec![].into())).await,
359 };
360
361 match guard_send_response {
362 Ok(()) => tracing::trace!("Sent ping"),
363 Err(e) => tracing::error!("Error sending ping: {e}"),
364 }
365 }
366 ConnectionMode::Reconnect => continue,
367 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
368 }
369 }
370
371 tracing::debug!("Completed task 'heartbeat'");
372 })
373 }
374}
375
376async fn shutdown(
386 read_task: Option<tokio::task::JoinHandle<()>>,
387 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
388 writer: SharedMessageWriter,
389) {
390 tracing::debug!("Closing");
391
392 let timeout = Duration::from_secs(5);
393 if tokio::time::timeout(timeout, async {
394 let mut write_half = writer.lock().await;
396 if let Err(e) = write_half.send(Message::Close(None)).await {
397 tracing::debug!("Error sending close frame: {e}");
399 }
400 drop(write_half);
401
402 tokio::time::sleep(Duration::from_millis(100)).await;
403
404 if let Some(task) = read_task {
406 if !task.is_finished() {
407 task.abort();
408 tracing::debug!("Aborted read task");
409 }
410 }
411 if let Some(task) = heartbeat_task {
412 if !task.is_finished() {
413 task.abort();
414 tracing::debug!("Aborted heartbeat task");
415 }
416 }
417
418 let mut write_half = writer.lock().await;
420 if let Err(e) = write_half.close().await {
421 tracing::error!("Error closing writer: {e}");
422 }
423 })
424 .await
425 .is_err()
426 {
427 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
428 }
429
430 tracing::debug!("Closed");
431}
432
433impl Drop for WebSocketClientInner {
434 fn drop(&mut self) {
435 if let Some(ref read_task) = self.read_task.take() {
436 if !read_task.is_finished() {
437 read_task.abort();
438 }
439 }
440
441 if let Some(ref handle) = self.heartbeat_task.take() {
443 if !handle.is_finished() {
444 handle.abort();
445 }
446 }
447 }
448}
449
/// WebSocket client handle with a background controller task and rate-limited sends.
///
/// The controller task owns the connection state; this handle shares the connection
/// mode with it and can request a disconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Write half used by the `send_*` methods.
    pub(crate) writer: SharedMessageWriter,
    /// Background task handling reconnects and disconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode`, stored as its `u8` representation.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Rate limiter awaited before each `send_text`/`send_bytes`.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
464
465impl WebSocketClient {
466 #[allow(clippy::too_many_arguments)]
468 pub async fn connect_stream(
469 config: WebSocketConfig,
470 keyed_quotas: Vec<(String, Quota)>,
471 default_quota: Option<Quota>,
472 ) -> Result<(MessageReader, Self), Error> {
473 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
474 let (writer, reader) = ws_stream.split();
475 let writer = Arc::new(Mutex::new(writer));
476
477 let inner = WebSocketClientInner::connect_url(config).await?;
478 let connection_mode = inner.connection_mode.clone();
479 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
480
481 let controller_task = Self::spawn_controller_task(
482 inner,
483 connection_mode.clone(),
484 None, None, );
487
488 Ok((
489 reader,
490 Self {
491 writer: writer.clone(),
492 controller_task,
493 connection_mode,
494 rate_limiter,
495 },
496 ))
497 }
498
499 pub async fn connect(
504 config: WebSocketConfig,
505 post_connection: Option<PyObject>,
506 post_reconnection: Option<PyObject>,
507 post_disconnection: Option<PyObject>,
508 keyed_quotas: Vec<(String, Quota)>,
509 default_quota: Option<Quota>,
510 ) -> Result<Self, Error> {
511 tracing::debug!("Connecting");
512 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
513 let writer = inner.writer.clone();
514 let connection_mode = inner.connection_mode.clone();
515
516 let controller_task = Self::spawn_controller_task(
517 inner,
518 connection_mode.clone(),
519 post_reconnection,
520 post_disconnection,
521 );
522 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
523
524 if let Some(handler) = post_connection {
525 Python::with_gil(|py| match handler.call0(py) {
526 Ok(_) => tracing::debug!("Called `post_connection` handler"),
527 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
528 });
529 };
530
531 Ok(Self {
532 writer,
533 controller_task,
534 connection_mode: connection_mode.clone(),
535 rate_limiter,
536 })
537 }
538
539 pub fn connection_mode(&self) -> ConnectionMode {
541 ConnectionMode::from_atomic(&self.connection_mode)
542 }
543
544 #[inline]
549 #[must_use]
550 pub fn is_active(&self) -> bool {
551 self.connection_mode().is_active()
552 }
553
554 #[must_use]
556 pub fn is_disconnected(&self) -> bool {
557 self.controller_task.is_finished()
558 }
559
560 #[inline]
565 #[must_use]
566 pub fn is_reconnecting(&self) -> bool {
567 self.connection_mode().is_reconnect()
568 }
569
570 #[inline]
574 #[must_use]
575 pub fn is_disconnecting(&self) -> bool {
576 self.connection_mode().is_disconnect()
577 }
578
579 #[inline]
585 #[must_use]
586 pub fn is_closed(&self) -> bool {
587 self.connection_mode().is_closed()
588 }
589
590 pub async fn disconnect(&self) {
595 tracing::debug!("Disconnecting");
596 self.connection_mode
597 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
598
599 match tokio::time::timeout(Duration::from_secs(5), async {
600 while !self.is_disconnected() {
601 tokio::time::sleep(Duration::from_millis(10)).await;
602 }
603
604 if !self.controller_task.is_finished() {
605 self.controller_task.abort();
606 tracing::debug!("Aborted controller task");
607 }
608 })
609 .await
610 {
611 Ok(()) => {
612 tracing::debug!("Controller task finished");
613 }
614 Err(_) => {
615 tracing::error!("Timeout waiting for controller task to finish");
616 }
617 }
618 }
619
620 pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) -> Result<(), Error> {
621 self.rate_limiter.await_keys_ready(keys).await;
622 tracing::trace!("Sending text: {data:?}");
623 let mut guard = self.writer.lock().await;
624 guard.send(Message::Text(data.into())).await
625 }
626
627 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) -> Result<(), Error> {
628 self.rate_limiter.await_keys_ready(keys).await;
629 tracing::trace!("Sending bytes: {data:?}");
630 let mut guard = self.writer.lock().await;
631 guard.send(Message::Binary(data.into())).await
632 }
633
634 pub async fn send_close_message(&self) {
635 let mut guard = self.writer.lock().await;
636 match guard.send(Message::Close(None)).await {
637 Ok(()) => tracing::debug!("Sent close message"),
638 Err(e) => tracing::error!("Error sending close message: {e}"),
639 }
640 }
641
642 fn spawn_controller_task(
643 mut inner: WebSocketClientInner,
644 connection_mode: Arc<AtomicU8>,
645 post_reconnection: Option<PyObject>,
646 post_disconnection: Option<PyObject>,
647 ) -> tokio::task::JoinHandle<()> {
648 tokio::task::spawn(async move {
649 tracing::debug!("Starting task 'controller'");
650
651 let check_interval = Duration::from_millis(10);
652
653 loop {
654 tokio::time::sleep(check_interval).await;
655 let mode = ConnectionMode::from_atomic(&connection_mode);
656
657 if mode.is_disconnect() {
658 tracing::debug!("Disconnecting");
659 shutdown(
660 inner.read_task.take(),
661 inner.heartbeat_task.take(),
662 inner.writer.clone(),
663 )
664 .await;
665
666 if let Some(ref handler) = post_disconnection {
667 Python::with_gil(|py| match handler.call0(py) {
668 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
669 Err(e) => {
670 tracing::error!("Error calling `post_disconnection` handler: {e}")
671 }
672 });
673 }
674 break; }
676
677 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
678 match inner.reconnect().await {
679 Ok(()) => {
680 tracing::debug!("Reconnected successfully");
681 inner.backoff.reset();
682
683 if let Some(ref handler) = post_reconnection {
684 Python::with_gil(|py| match handler.call0(py) {
685 Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
686 Err(e) => tracing::error!(
687 "Error calling `post_reconnection` handler: {e}"
688 ),
689 });
690 }
691 }
692 Err(e) => {
693 let duration = inner.backoff.next_duration();
694 tracing::warn!("Reconnect attempt failed: {e}");
695 if !duration.is_zero() {
696 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
697 }
698 tokio::time::sleep(duration).await;
699 }
700 }
701 }
702 }
703 inner
704 .connection_mode
705 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
706 })
707 }
708}
709
#[cfg(test)]
#[cfg(target_os = "linux")] // Network tests are only compiled on Linux
mod tests {
    use std::{num::NonZeroU32, sync::Arc, time::Duration};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Echo server listening on an ephemeral localhost port; aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting the client sent header `key` with exactly `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected `{}` header",
                self.key
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Starts the echo server.
        ///
        /// Every handshake must carry header `test: test`. Text and binary messages are echoed
        /// back, the text message `close-now` makes the server close the connection, and a
        /// close frame from the client is answered with a close.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let callback = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, callback.clone()).await.unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Returns a config for `url` with the given handshake headers and every optional
    /// setting left unset.
    fn test_config(url: String, headers: Vec<(String, String)>) -> WebSocketConfig {
        WebSocketConfig {
            url,
            headers,
            handler: None,
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    /// Returns the config the test server accepts, for the server on `port`.
    fn server_config(port: u16) -> WebSocketConfig {
        test_config(
            format!("ws://127.0.0.1:{port}"),
            vec![("test".into(), "test".into())],
        )
    }

    /// Connects a client without handlers or heartbeat to the test server on `port`.
    async fn setup_test_client(port: u16) -> WebSocketClient {
        WebSocketClient::connect(server_config(port), None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..server_config(server.port)
        };
        let client = WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .expect("Failed to connect");

        // Let a few heartbeats go out
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it immediately so nothing listens there
        let port = std::net::TcpListener::bind("127.0.0.1:0")
            .unwrap()
            .local_addr()
            .unwrap()
            .port();
        let config = test_config(format!("ws://127.0.0.1:{port}"), vec![]);

        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to close the connection from its side
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            server_config(server.port),
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // Route sends through the keyed quota so the limiter is actually exercised
        let keys = || Some(vec!["default".to_string()]);
        client.send_text("test1".into(), keys()).await.unwrap();
        client.send_text("test2".into(), keys()).await.unwrap();
        client.send_text("test3".into(), keys()).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let client = client.clone();
                task::spawn(async move {
                    client.send_text(format!("test{i}"), None).await.unwrap();
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}