1use std::{
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use pyo3::{prelude::*, types::PyBytes};
47use tokio::{net::TcpStream, sync::Mutex};
48use tokio_tungstenite::{
49 MaybeTlsStream, WebSocketStream, connect_async,
50 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
51};
52
53use crate::{
54 backoff::ExponentialBackoff,
55 mode::ConnectionMode,
56 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
57};
58type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
59type SharedMessageWriter =
60 Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
61pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
62
/// Configuration for a WebSocket client connection.
///
/// The optional reconnect settings fall back to the defaults applied in
/// `WebSocketClientInner::connect_url` when left as `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// Header name/value pairs inserted into the handshake request.
    pub headers: Vec<(String, String)>,
    /// Python callable invoked with the payload bytes of each received text or
    /// binary message. When `None`, no read task is spawned.
    pub handler: Option<Arc<PyObject>>,
    /// Heartbeat interval in seconds. When `None`, no heartbeat task is spawned.
    pub heartbeat: Option<u64>,
    /// Text sent as the heartbeat; when `None` a ping frame with an empty payload is sent.
    pub heartbeat_msg: Option<String>,
    /// Python callable invoked with the payload bytes of each received ping frame.
    pub ping_handler: Option<Arc<PyObject>>,
    /// Timeout for a single reconnect attempt, in milliseconds (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay, in milliseconds (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay, in milliseconds (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Multiplier passed to the exponential backoff (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter passed to the exponential backoff, in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
92
/// Connection state owned by the controller task.
///
/// Keeps the handles of the background read/heartbeat tasks together with the
/// shared writer, so a reconnect can replace the underlying stream in place.
struct WebSocketClientInner {
    /// Configuration the connection was created from; reused on reconnect.
    config: WebSocketConfig,
    /// Handle of the read task; `None` when no message handler is configured.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the heartbeat task; `None` when no heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared write half of the stream; its contents are swapped on reconnect.
    writer: SharedMessageWriter,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff state used between failed reconnect attempts.
    backoff: ExponentialBackoff,
}
117
118impl WebSocketClientInner {
119 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
121 install_cryptographic_provider();
122
123 #[allow(unused_variables)]
124 let WebSocketConfig {
125 url,
126 handler,
127 heartbeat,
128 headers,
129 heartbeat_msg,
130 ping_handler,
131 reconnect_timeout_ms,
132 reconnect_delay_initial_ms,
133 reconnect_delay_max_ms,
134 reconnect_backoff_factor,
135 reconnect_jitter_ms,
136 } = &config;
137 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
138 let writer = Arc::new(Mutex::new(writer));
139
140 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
141
142 let read_task = handler
144 .as_ref()
145 .map(|handler| Self::spawn_read_task(reader, handler.clone(), ping_handler.clone()));
146
147 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
149 Self::spawn_heartbeat_task(
150 connection_mode.clone(),
151 *heartbeat_secs,
152 heartbeat_msg.clone(),
153 writer.clone(),
154 )
155 });
156
157 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
158 let backoff = ExponentialBackoff::new(
159 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
160 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
161 reconnect_backoff_factor.unwrap_or(1.5),
162 reconnect_jitter_ms.unwrap_or(100),
163 true, );
165
166 Ok(Self {
167 config,
168 read_task,
169 heartbeat_task,
170 writer,
171 connection_mode,
172 reconnect_timeout,
173 backoff,
174 })
175 }
176
177 #[inline]
179 pub async fn connect_with_server(
180 url: &str,
181 headers: Vec<(String, String)>,
182 ) -> Result<(MessageWriter, MessageReader), Error> {
183 let mut request = url.into_client_request()?;
184 let req_headers = request.headers_mut();
185
186 let mut header_names: Vec<HeaderName> = Vec::new();
187 for (key, val) in headers {
188 let header_value = HeaderValue::from_str(&val)?;
189 let header_name: HeaderName = key.parse()?;
190 header_names.push(header_name.clone());
191 req_headers.insert(header_name, header_value);
192 }
193
194 connect_async(request).await.map(|resp| resp.0.split())
195 }
196
197 pub async fn reconnect(&mut self) -> Result<(), Error> {
202 tracing::debug!("Reconnecting");
203
204 tokio::time::timeout(self.reconnect_timeout, async {
205 shutdown(
206 self.read_task.take(),
207 self.heartbeat_task.take(),
208 self.writer.clone(),
209 )
210 .await;
211
212 let (new_writer, reader) =
213 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
214
215 {
216 let mut guard = self.writer.lock().await;
217 *guard = new_writer;
218 drop(guard);
219 }
220
221 if let Some(ref handler) = self.config.handler {
223 self.read_task = Some(Self::spawn_read_task(
224 reader,
225 handler.clone(),
226 self.config.ping_handler.clone(),
227 ));
228 }
229
230 self.heartbeat_task = self.config.heartbeat.as_ref().map(|heartbeat_secs| {
232 Self::spawn_heartbeat_task(
233 self.connection_mode.clone(),
234 *heartbeat_secs,
235 self.config.heartbeat_msg.clone(),
236 self.writer.clone(),
237 )
238 });
239
240 self.connection_mode
241 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
242
243 tracing::debug!("Reconnect succeeded");
244 Ok(())
245 })
246 .await
247 .map_err(|_| {
248 Error::Io(std::io::Error::new(
249 std::io::ErrorKind::TimedOut,
250 format!(
251 "reconnection timed out after {}s",
252 self.reconnect_timeout.as_secs_f64()
253 ),
254 ))
255 })?
256 }
257
258 #[inline]
266 #[must_use]
267 pub fn is_alive(&self) -> bool {
268 match &self.read_task {
269 Some(read_task) => !read_task.is_finished(),
270 None => true, }
272 }
273
274 fn spawn_read_task(
275 mut reader: MessageReader,
276 handler: Arc<PyObject>,
277 ping_handler: Option<Arc<PyObject>>,
278 ) -> tokio::task::JoinHandle<()> {
279 tracing::debug!("Started task 'read'");
280
281 tokio::task::spawn(async move {
282 loop {
283 match reader.next().await {
284 Some(Ok(Message::Binary(data))) => {
285 tracing::trace!("Received message <binary> {} bytes", data.len());
286 if let Err(e) =
287 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
288 {
289 tracing::error!("Error calling handler: {e}");
290 break;
291 }
292 continue;
293 }
294 Some(Ok(Message::Text(data))) => {
295 tracing::trace!("Received message: {data}");
296 if let Err(e) = Python::with_gil(|py| {
297 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
298 }) {
299 tracing::error!("Error calling handler: {e}");
300 break;
301 }
302 continue;
303 }
304 Some(Ok(Message::Ping(ping))) => {
305 tracing::trace!("Received ping: {ping:?}",);
306 if let Some(ref handler) = ping_handler {
307 if let Err(e) =
308 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
309 {
310 tracing::error!("Error calling handler: {e}");
311 break;
312 }
313 }
314 continue;
315 }
316 Some(Ok(Message::Pong(_))) => {
317 tracing::trace!("Received pong");
318 }
319 Some(Ok(Message::Close(_))) => {
320 tracing::debug!("Received close message - terminating");
321 break;
322 }
323 Some(Ok(_)) => (),
324 Some(Err(e)) => {
325 tracing::error!("Received error message - terminating: {e}");
326 break;
327 }
328 None => {
331 tracing::debug!("No message received - terminating");
332 break;
333 }
334 }
335 }
336 })
337 }
338
339 fn spawn_heartbeat_task(
340 connection_state: Arc<AtomicU8>,
341 heartbeat_secs: u64,
342 message: Option<String>,
343 writer: SharedMessageWriter,
344 ) -> tokio::task::JoinHandle<()> {
345 tracing::debug!("Started task 'heartbeat'");
346
347 tokio::task::spawn(async move {
348 let interval = Duration::from_secs(heartbeat_secs);
349 loop {
350 tokio::time::sleep(interval).await;
351
352 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
353 ConnectionMode::Active => {
354 let mut guard = writer.lock().await;
355 let guard_send_response = match message.clone() {
356 Some(msg) => guard.send(Message::Text(msg.into())).await,
357 None => guard.send(Message::Ping(vec![].into())).await,
358 };
359
360 match guard_send_response {
361 Ok(()) => tracing::trace!("Sent ping"),
362 Err(e) => tracing::error!("Error sending ping: {e}"),
363 }
364 }
365 ConnectionMode::Reconnect => continue,
366 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
367 }
368 }
369
370 tracing::debug!("Completed task 'heartbeat'");
371 })
372 }
373}
374
375async fn shutdown(
385 read_task: Option<tokio::task::JoinHandle<()>>,
386 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
387 writer: SharedMessageWriter,
388) {
389 tracing::debug!("Closing");
390
391 let timeout = Duration::from_secs(5);
392 if tokio::time::timeout(timeout, async {
393 let mut write_half = writer.lock().await;
395 if let Err(e) = write_half.send(Message::Close(None)).await {
396 tracing::debug!("Error sending close frame: {e}");
398 }
399 drop(write_half);
400
401 tokio::time::sleep(Duration::from_millis(100)).await;
402
403 if let Some(task) = read_task {
405 if !task.is_finished() {
406 task.abort();
407 tracing::debug!("Aborted read task");
408 }
409 }
410 if let Some(task) = heartbeat_task {
411 if !task.is_finished() {
412 task.abort();
413 tracing::debug!("Aborted heartbeat task");
414 }
415 }
416
417 let mut write_half = writer.lock().await;
419 if let Err(e) = write_half.close().await {
420 tracing::error!("Error closing writer: {e}");
421 }
422 })
423 .await
424 .is_err()
425 {
426 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
427 }
428
429 tracing::debug!("Closed");
430}
431
432impl Drop for WebSocketClientInner {
433 fn drop(&mut self) {
434 if let Some(ref read_task) = self.read_task.take() {
435 if !read_task.is_finished() {
436 read_task.abort();
437 }
438 }
439
440 if let Some(ref handle) = self.heartbeat_task.take() {
442 if !handle.is_finished() {
443 handle.abort();
444 }
445 }
446 }
447}
448
/// Handle to a managed WebSocket connection.
///
/// Holds the shared writer used by the `send_*` methods, the handle of the
/// background controller task, the shared connection-mode flag, and the rate
/// limiter awaited before text and binary sends.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Shared write half used for outgoing frames.
    pub(crate) writer: SharedMessageWriter,
    /// Handle of the controller task; its completion is what `is_disconnected` reports.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Keyed rate limiter awaited before each text/binary send.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
463
464impl WebSocketClient {
465 #[allow(clippy::too_many_arguments)]
471 pub async fn connect_stream(
472 config: WebSocketConfig,
473 keyed_quotas: Vec<(String, Quota)>,
474 default_quota: Option<Quota>,
475 ) -> Result<(MessageReader, Self), Error> {
476 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
477 let (writer, reader) = ws_stream.split();
478 let writer = Arc::new(Mutex::new(writer));
479
480 let inner = WebSocketClientInner::connect_url(config).await?;
481 let connection_mode = inner.connection_mode.clone();
482 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
483
484 let controller_task = Self::spawn_controller_task(
485 inner,
486 connection_mode.clone(),
487 None, None, );
490
491 Ok((
492 reader,
493 Self {
494 writer: writer.clone(),
495 controller_task,
496 connection_mode,
497 rate_limiter,
498 },
499 ))
500 }
501
502 pub async fn connect(
511 config: WebSocketConfig,
512 post_connection: Option<PyObject>,
513 post_reconnection: Option<PyObject>,
514 post_disconnection: Option<PyObject>,
515 keyed_quotas: Vec<(String, Quota)>,
516 default_quota: Option<Quota>,
517 ) -> Result<Self, Error> {
518 tracing::debug!("Connecting");
519 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
520 let writer = inner.writer.clone();
521 let connection_mode = inner.connection_mode.clone();
522
523 let controller_task = Self::spawn_controller_task(
524 inner,
525 connection_mode.clone(),
526 post_reconnection,
527 post_disconnection,
528 );
529 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
530
531 if let Some(handler) = post_connection {
532 Python::with_gil(|py| match handler.call0(py) {
533 Ok(_) => tracing::debug!("Called `post_connection` handler"),
534 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
535 });
536 };
537
538 Ok(Self {
539 writer,
540 controller_task,
541 connection_mode,
542 rate_limiter,
543 })
544 }
545
/// Returns the current connection mode, decoded from the shared atomic.
#[must_use]
pub fn connection_mode(&self) -> ConnectionMode {
    ConnectionMode::from_atomic(&self.connection_mode)
}
551
/// Returns `true` when the current connection mode reports itself as active.
#[inline]
#[must_use]
pub fn is_active(&self) -> bool {
    self.connection_mode().is_active()
}
561
/// Returns `true` once the controller task has finished.
///
/// This is derived from the task handle, not from the connection mode.
#[must_use]
pub fn is_disconnected(&self) -> bool {
    self.controller_task.is_finished()
}
567
/// Returns `true` when the current connection mode is the reconnect mode.
#[inline]
#[must_use]
pub fn is_reconnecting(&self) -> bool {
    self.connection_mode().is_reconnect()
}
577
/// Returns `true` when the current connection mode is the disconnect mode,
/// i.e. a disconnect has been requested via the shared flag.
#[inline]
#[must_use]
pub fn is_disconnecting(&self) -> bool {
    self.connection_mode().is_disconnect()
}
586
/// Returns `true` when the current connection mode is the closed mode, which
/// the controller task stores just before it exits.
#[inline]
#[must_use]
pub fn is_closed(&self) -> bool {
    self.connection_mode().is_closed()
}
597
598 pub async fn disconnect(&self) {
603 tracing::debug!("Disconnecting");
604 self.connection_mode
605 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
606
607 match tokio::time::timeout(Duration::from_secs(5), async {
608 while !self.is_disconnected() {
609 tokio::time::sleep(Duration::from_millis(10)).await;
610 }
611
612 if !self.controller_task.is_finished() {
613 self.controller_task.abort();
614 tracing::debug!("Aborted controller task");
615 }
616 })
617 .await
618 {
619 Ok(()) => {
620 tracing::debug!("Controller task finished");
621 }
622 Err(_) => {
623 tracing::error!("Timeout waiting for controller task to finish");
624 }
625 }
626 }
627
628 pub async fn send_text(&self, data: String, keys: Option<Vec<String>>) {
630 self.rate_limiter.await_keys_ready(keys).await;
631 tracing::trace!("Sending text: {data:?}");
632 let mut guard = self.writer.lock().await;
633 if let Err(e) = guard.send(Message::Text(data.into())).await {
634 tracing::error!("Error sending message: {e}");
635 }
636 }
637
638 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<Vec<String>>) {
640 self.rate_limiter.await_keys_ready(keys).await;
641 tracing::trace!("Sending bytes: {data:?}");
642 let mut guard = self.writer.lock().await;
643 if let Err(e) = guard.send(Message::Binary(data.into())).await {
644 tracing::error!("Error sending message: {e}");
645 }
646 }
647
648 pub async fn send_close_message(&self) {
650 let mut guard = self.writer.lock().await;
651 match guard.send(Message::Close(None)).await {
652 Ok(()) => tracing::debug!("Sent close message"),
653 Err(e) => tracing::error!("Error sending close message: {e}"),
654 }
655 }
656
657 fn spawn_controller_task(
658 mut inner: WebSocketClientInner,
659 connection_mode: Arc<AtomicU8>,
660 post_reconnection: Option<PyObject>,
661 post_disconnection: Option<PyObject>,
662 ) -> tokio::task::JoinHandle<()> {
663 tokio::task::spawn(async move {
664 tracing::debug!("Started task 'controller'");
665
666 let check_interval = Duration::from_millis(10);
667
668 loop {
669 tokio::time::sleep(check_interval).await;
670 let mode = ConnectionMode::from_atomic(&connection_mode);
671
672 if mode.is_disconnect() {
673 tracing::debug!("Disconnecting");
674 shutdown(
675 inner.read_task.take(),
676 inner.heartbeat_task.take(),
677 inner.writer.clone(),
678 )
679 .await;
680
681 if let Some(ref handler) = post_disconnection {
682 Python::with_gil(|py| match handler.call0(py) {
683 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
684 Err(e) => {
685 tracing::error!("Error calling `post_disconnection` handler: {e}");
686 }
687 });
688 }
689 break; }
691
692 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
693 match inner.reconnect().await {
694 Ok(()) => {
695 tracing::debug!("Reconnected successfully");
696 inner.backoff.reset();
697
698 if let Some(ref handler) = post_reconnection {
699 Python::with_gil(|py| match handler.call0(py) {
700 Ok(_) => tracing::debug!("Called `post_reconnection` handler"),
701 Err(e) => tracing::error!(
702 "Error calling `post_reconnection` handler: {e}"
703 ),
704 });
705 }
706 }
707 Err(e) => {
708 let duration = inner.backoff.next_duration();
709 tracing::warn!("Reconnect attempt failed: {e}");
710 if !duration.is_zero() {
711 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
712 }
713 tokio::time::sleep(duration).await;
714 }
715 }
716 }
717 }
718 inner
719 .connection_mode
720 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
721 })
722 }
723}
724
725#[cfg(test)]
729#[cfg(target_os = "linux")] mod tests {
731 use std::{num::NonZeroU32, sync::Arc};
732
733 use futures_util::{SinkExt, StreamExt};
734 use tokio::{
735 net::TcpListener,
736 task::{self, JoinHandle},
737 };
738 use tokio_tungstenite::{
739 accept_hdr_async,
740 tungstenite::{
741 handshake::server::{self, Callback},
742 http::HeaderValue,
743 },
744 };
745
746 use crate::{
747 ratelimiter::quota::Quota,
748 websocket::{WebSocketClient, WebSocketConfig},
749 };
750
/// Local echo server used by the tests in this module.
struct TestServer {
    /// Handle of the accept loop; aborted when the server is dropped.
    task: JoinHandle<()>,
    /// Port the listener was bound to.
    port: u16,
}
755
/// Handshake callback asserting that a request carries an expected header.
#[derive(Debug, Clone)]
struct TestCallback {
    /// Name of the header that must be present.
    key: String,
    /// Value the header must have.
    value: HeaderValue,
}
761
762 impl Callback for TestCallback {
763 fn on_request(
764 self,
765 request: &server::Request,
766 response: server::Response,
767 ) -> Result<server::Response, server::ErrorResponse> {
768 let _ = response;
769 let value = request.headers().get(&self.key);
770 assert!(value.is_some());
771
772 if let Some(value) = request.headers().get(&self.key) {
773 assert_eq!(value, self.value);
774 }
775
776 Ok(response)
777 }
778 }
779
780 impl TestServer {
781 async fn setup() -> Self {
782 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
783 let port = TcpListener::local_addr(&server).unwrap().port();
784
785 let header_key = "test".to_string();
786 let header_value = "test".to_string();
787
788 let test_call_back = TestCallback {
789 key: header_key,
790 value: HeaderValue::from_str(&header_value).unwrap(),
791 };
792
793 let task = task::spawn(async move {
794 loop {
796 let (conn, _) = server.accept().await.unwrap();
797 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
798 .await
799 .unwrap();
800
801 task::spawn(async move {
802 while let Some(Ok(msg)) = websocket.next().await {
803 match msg {
804 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
805 if txt == "close-now" =>
806 {
807 tracing::debug!("Forcibly closing from server side");
808 let _ = websocket.close(None).await;
810 break;
811 }
812 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
814 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
815 if websocket.send(msg).await.is_err() {
816 break;
817 }
818 }
819 tokio_tungstenite::tungstenite::protocol::Message::Close(
821 _frame,
822 ) => {
823 let _ = websocket.close(None).await;
824 break;
825 }
826 _ => {}
828 }
829 }
830 });
831 }
832 });
833
834 Self { task, port }
835 }
836 }
837
/// Stops the accept loop when the test server goes out of scope.
impl Drop for TestServer {
    fn drop(&mut self) {
        self.task.abort();
    }
}
843
844 async fn setup_test_client(port: u16) -> WebSocketClient {
845 let config = WebSocketConfig {
846 url: format!("ws://127.0.0.1:{port}"),
847 headers: vec![("test".into(), "test".into())],
848 handler: None,
849 heartbeat: None,
850 heartbeat_msg: None,
851 ping_handler: None,
852 reconnect_timeout_ms: None,
853 reconnect_delay_initial_ms: None,
854 reconnect_backoff_factor: None,
855 reconnect_delay_max_ms: None,
856 reconnect_jitter_ms: None,
857 };
858 WebSocketClient::connect(config, None, None, None, vec![], None)
859 .await
860 .expect("Failed to connect")
861 }
862
863 #[tokio::test]
864 async fn test_websocket_basic() {
865 let server = TestServer::setup().await;
866 let client = setup_test_client(server.port).await;
867
868 assert!(!client.is_disconnected());
869
870 client.disconnect().await;
871 assert!(client.is_disconnected());
872 }
873
874 #[tokio::test]
875 async fn test_websocket_heartbeat() {
876 let server = TestServer::setup().await;
877 let client = setup_test_client(server.port).await;
878
879 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
881
882 client.disconnect().await;
884 assert!(client.is_disconnected());
885 }
886
887 #[tokio::test]
888 async fn test_websocket_reconnect_exhausted() {
889 let config = WebSocketConfig {
890 url: "ws://127.0.0.1:9997".into(), headers: vec![],
892 handler: None,
893 heartbeat: None,
894 heartbeat_msg: None,
895 ping_handler: None,
896 reconnect_timeout_ms: None,
897 reconnect_delay_initial_ms: None,
898 reconnect_backoff_factor: None,
899 reconnect_delay_max_ms: None,
900 reconnect_jitter_ms: None,
901 };
902 let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
903 assert!(res.is_err(), "Should fail quickly with no server");
904 }
905
906 #[tokio::test]
907 async fn test_websocket_forced_close_reconnect() {
908 let server = TestServer::setup().await;
909 let client = setup_test_client(server.port).await;
910
911 client.send_text("Hello".into(), None).await;
913
914 client.send_text("close-now".into(), None).await;
916
917 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
919
920 assert!(!client.is_disconnected());
922
923 client.disconnect().await;
925 assert!(client.is_disconnected());
926 }
927
928 #[tokio::test]
929 async fn test_rate_limiter() {
930 let server = TestServer::setup().await;
931 let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
932
933 let config = WebSocketConfig {
934 url: format!("ws://127.0.0.1:{}", server.port),
935 headers: vec![("test".into(), "test".into())],
936 handler: None,
937 heartbeat: None,
938 heartbeat_msg: None,
939 ping_handler: None,
940 reconnect_timeout_ms: None,
941 reconnect_delay_initial_ms: None,
942 reconnect_backoff_factor: None,
943 reconnect_delay_max_ms: None,
944 reconnect_jitter_ms: None,
945 };
946
947 let client = WebSocketClient::connect(
948 config,
949 None,
950 None,
951 None,
952 vec![("default".into(), quota)],
953 None,
954 )
955 .await
956 .unwrap();
957
958 client.send_text("test1".into(), None).await;
960 client.send_text("test2".into(), None).await;
961
962 client.send_text("test3".into(), None).await;
964
965 client.disconnect().await;
967 assert!(client.is_disconnected());
968 }
969
970 #[tokio::test]
971 async fn test_concurrent_writers() {
972 let server = TestServer::setup().await;
973 let client = Arc::new(setup_test_client(server.port).await);
974
975 let mut handles = vec![];
976 for i in 0..10 {
977 let client = client.clone();
978 handles.push(task::spawn(async move {
979 client.send_text(format!("test{i}"), None).await;
980 }));
981 }
982
983 for handle in handles {
984 handle.await.unwrap();
985 }
986
987 client.disconnect().await;
989 assert!(client.is_disconnected());
990 }
991}