1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::{
22 collections::into_ustr_vec,
23 python::{clone_py_object, to_pyruntime_err, to_pyvalue_err},
24};
25use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
26use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
27
28use crate::{
29 RECONNECTED,
30 mode::ConnectionMode,
31 ratelimiter::quota::Quota,
32 websocket::{
33 WebSocketClient, WebSocketConfig,
34 types::{MessageHandler, PingHandler, WriterCommand},
35 },
36};
37
38create_exception!(network, WebSocketClientError, PyException);
39
40fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
41 PyErr::new::<WebSocketClientError, _>(e.to_string())
42}
43
44#[pymethods]
45impl WebSocketConfig {
46 #[new]
48 #[allow(clippy::too_many_arguments)]
49 #[pyo3(signature = (
50 url,
51 headers,
52 heartbeat=None,
53 heartbeat_msg=None,
54 reconnect_timeout_ms=10_000,
55 reconnect_delay_initial_ms=2_000,
56 reconnect_delay_max_ms=30_000,
57 reconnect_backoff_factor=1.5,
58 reconnect_jitter_ms=100,
59 reconnect_max_attempts=None,
60 ))]
61 fn py_new(
62 url: String,
63 headers: Vec<(String, String)>,
64 heartbeat: Option<u64>,
65 heartbeat_msg: Option<String>,
66 reconnect_timeout_ms: Option<u64>,
67 reconnect_delay_initial_ms: Option<u64>,
68 reconnect_delay_max_ms: Option<u64>,
69 reconnect_backoff_factor: Option<f64>,
70 reconnect_jitter_ms: Option<u64>,
71 reconnect_max_attempts: Option<u32>,
72 ) -> Self {
73 Self {
74 url,
75 headers,
76 heartbeat,
77 heartbeat_msg,
78 reconnect_timeout_ms,
79 reconnect_delay_initial_ms,
80 reconnect_delay_max_ms,
81 reconnect_backoff_factor,
82 reconnect_jitter_ms,
83 reconnect_max_attempts,
84 }
85 }
86}
87
88#[pymethods]
89impl WebSocketClient {
90 #[staticmethod]
102 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
103 #[allow(clippy::too_many_arguments)]
104 fn py_connect(
105 loop_: Py<PyAny>,
106 config: WebSocketConfig,
107 handler: Py<PyAny>,
108 ping_handler: Option<Py<PyAny>>,
109 post_reconnection: Option<Py<PyAny>>,
110 keyed_quotas: Vec<(String, Quota)>,
111 default_quota: Option<Quota>,
112 py: Python<'_>,
113 ) -> PyResult<Bound<'_, PyAny>> {
114 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
115 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
116 let handler_clone = clone_py_object(&handler);
117
118 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
119 if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
120 return;
121 }
122
123 Python::attach(|py| {
124 let py_bytes = match &msg {
125 Message::Binary(data) => PyBytes::new(py, data),
126 Message::Text(text) => PyBytes::new(py, text.as_bytes()),
127 _ => return,
128 };
129
130 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
131 log::error!("Error scheduling message handler on event loop: {e}");
132 }
133 });
134 });
135
136 let ping_handler_fn = ping_handler.map(|ping_handler| {
137 let ping_handler_clone = clone_py_object(&ping_handler);
138 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
139
140 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
141 Python::attach(|py| {
142 let py_bytes = PyBytes::new(py, &data);
143 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
144 log::error!("Error scheduling ping handler on event loop: {e}");
145 }
146 });
147 });
148 ping_handler_fn
149 });
150
151 let post_reconnection_fn = post_reconnection.map(|callback| {
152 let callback_clone = clone_py_object(&callback);
153 Arc::new(move || {
154 Python::attach(|py| {
155 if let Err(e) = callback_clone.call0(py) {
156 log::error!("Error calling post_reconnection handler: {e}");
157 }
158 });
159 }) as std::sync::Arc<dyn Fn() + Send + Sync>
160 });
161
162 pyo3_async_runtimes::tokio::future_into_py(py, async move {
163 Self::connect(
164 config,
165 Some(message_handler),
166 ping_handler_fn,
167 post_reconnection_fn,
168 keyed_quotas,
169 default_quota,
170 )
171 .await
172 .map_err(to_websocket_pyerr)
173 })
174 }
175
176 #[pyo3(name = "disconnect")]
186 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
187 let connection_mode = slf.connection_mode.clone();
188 let mode = ConnectionMode::from_atomic(&connection_mode);
189 log::debug!("Close from mode {mode}");
190
191 pyo3_async_runtimes::tokio::future_into_py(py, async move {
192 match ConnectionMode::from_atomic(&connection_mode) {
193 ConnectionMode::Closed => {
194 log::debug!("WebSocket already closed");
195 }
196 ConnectionMode::Disconnect => {
197 log::debug!("WebSocket already disconnecting");
198 }
199 _ => {
200 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
201 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
202 tokio::time::sleep(Duration::from_millis(10)).await;
203 }
204 }
205 }
206
207 Ok(())
208 })
209 }
210
211 #[pyo3(name = "is_active")]
221 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
222 !slf.controller_task.is_finished()
223 }
224
225 #[pyo3(name = "is_reconnecting")]
226 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
227 slf.is_reconnecting()
228 }
229
230 #[pyo3(name = "is_disconnecting")]
231 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
232 slf.is_disconnecting()
233 }
234
235 #[pyo3(name = "is_closed")]
236 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
237 slf.is_closed()
238 }
239
240 #[pyo3(name = "send")]
246 #[pyo3(signature = (data, keys=None))]
247 fn py_send<'py>(
248 slf: PyRef<'_, Self>,
249 data: Vec<u8>,
250 py: Python<'py>,
251 keys: Option<Vec<String>>,
252 ) -> PyResult<Bound<'py, PyAny>> {
253 let rate_limiter = slf.rate_limiter.clone();
254 let writer_tx = slf.writer_tx.clone();
255 let mode = slf.connection_mode.clone();
256 let keys = keys.map(into_ustr_vec);
257
258 pyo3_async_runtimes::tokio::future_into_py(py, async move {
259 if !ConnectionMode::from_atomic(&mode).is_active() {
260 let msg = "Cannot send data: connection not active".to_string();
261 log::error!("{msg}");
262 return Err(to_pyruntime_err(std::io::Error::new(
263 std::io::ErrorKind::NotConnected,
264 msg,
265 )));
266 }
267 rate_limiter.await_keys_ready(keys.as_deref()).await;
268 log::trace!("Sending binary: {data:?}");
269
270 let msg = Message::Binary(data.into());
271 writer_tx
272 .send(WriterCommand::Send(msg))
273 .map_err(to_pyruntime_err)
274 })
275 }
276
277 #[pyo3(name = "send_text")]
291 #[pyo3(signature = (data, keys=None))]
292 fn py_send_text<'py>(
293 slf: PyRef<'_, Self>,
294 data: Vec<u8>,
295 py: Python<'py>,
296 keys: Option<Vec<String>>,
297 ) -> PyResult<Bound<'py, PyAny>> {
298 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
299 let data = Utf8Bytes::from(data_str);
300 let rate_limiter = slf.rate_limiter.clone();
301 let writer_tx = slf.writer_tx.clone();
302 let mode = slf.connection_mode.clone();
303 let keys = keys.map(into_ustr_vec);
304
305 pyo3_async_runtimes::tokio::future_into_py(py, async move {
306 if !ConnectionMode::from_atomic(&mode).is_active() {
307 let e = std::io::Error::new(
308 std::io::ErrorKind::NotConnected,
309 "Cannot send text: connection not active",
310 );
311 return Err(to_pyruntime_err(e));
312 }
313 rate_limiter.await_keys_ready(keys.as_deref()).await;
314 log::trace!("Sending text: {data}");
315
316 let msg = Message::Text(data);
317 writer_tx
318 .send(WriterCommand::Send(msg))
319 .map_err(to_pyruntime_err)
320 })
321 }
322
323 #[pyo3(name = "send_pong")]
329 fn py_send_pong<'py>(
330 slf: PyRef<'_, Self>,
331 data: Vec<u8>,
332 py: Python<'py>,
333 ) -> PyResult<Bound<'py, PyAny>> {
334 let writer_tx = slf.writer_tx.clone();
335 let mode = slf.connection_mode.clone();
336 let data_len = data.len();
337
338 pyo3_async_runtimes::tokio::future_into_py(py, async move {
339 if !ConnectionMode::from_atomic(&mode).is_active() {
340 let e = std::io::Error::new(
341 std::io::ErrorKind::NotConnected,
342 "Cannot send pong: connection not active",
343 );
344 return Err(to_pyruntime_err(e));
345 }
346 log::trace!("Sending pong frame ({data_len} bytes)");
347
348 let msg = Message::Pong(data.into());
349 writer_tx
350 .send(WriterCommand::Send(msg))
351 .map_err(to_pyruntime_err)
352 })
353 }
354}
355
356#[cfg(test)]
357#[cfg(target_os = "linux")] mod tests {
359 use std::ffi::CString;
360
361 use futures_util::{SinkExt, StreamExt};
362 use nautilus_core::python::IntoPyObjectNautilusExt;
363 use pyo3::{prelude::*, types::PyBytes};
364 use tokio::{
365 net::TcpListener,
366 task::{self, JoinHandle},
367 time::{Duration, sleep},
368 };
369 use tokio_tungstenite::{
370 accept_hdr_async,
371 tungstenite::{
372 Message,
373 handshake::server::{self, Callback},
374 http::HeaderValue,
375 },
376 };
377
378 use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};
379
380 struct TestServer {
381 task: JoinHandle<()>,
382 port: u16,
383 }
384
385 #[derive(Debug, Clone)]
386 struct TestCallback {
387 key: String,
388 value: HeaderValue,
389 }
390
391 impl Callback for TestCallback {
392 #[allow(clippy::panic_in_result_fn)]
393 fn on_request(
394 self,
395 request: &server::Request,
396 response: server::Response,
397 ) -> Result<server::Response, server::ErrorResponse> {
398 let _ = response;
399 let value = request.headers().get(&self.key);
400 assert!(value.is_some());
401
402 if let Some(value) = request.headers().get(&self.key) {
403 assert_eq!(value, self.value);
404 }
405
406 Ok(response)
407 }
408 }
409
410 impl TestServer {
411 async fn setup(key: String, value: String) -> Self {
412 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
413 let port = TcpListener::local_addr(&server).unwrap().port();
414
415 let test_call_back = TestCallback {
416 key,
417 value: HeaderValue::from_str(&value).unwrap(),
418 };
419
420 let task = task::spawn(async move {
422 loop {
424 let (conn, _) = server.accept().await.unwrap();
425 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
426 .await
427 .unwrap();
428
429 task::spawn(async move {
430 while let Some(Ok(msg)) = websocket.next().await {
431 match msg {
432 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
433 if txt == "close-now" =>
434 {
435 log::debug!("Forcibly closing from server side");
436 let _ = websocket.close(None).await;
438 break;
439 }
440 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
442 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
443 if websocket.send(msg).await.is_err() {
444 break;
445 }
446 }
447 tokio_tungstenite::tungstenite::protocol::Message::Close(
449 _frame,
450 ) => {
451 let _ = websocket.close(None).await;
452 break;
453 }
454 _ => {}
456 }
457 }
458 });
459 }
460 });
461
462 Self { task, port }
463 }
464 }
465
466 impl Drop for TestServer {
467 fn drop(&mut self) {
468 self.task.abort();
469 }
470 }
471
472 fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
473 let code_raw = r"
474class Counter:
475 def __init__(self):
476 self.count = 0
477 self.check = False
478
479 def handler(self, bytes):
480 msg = bytes.decode()
481 if msg == 'ping':
482 self.count += 1
483 elif msg == 'heartbeat message':
484 self.check = True
485
486 def get_check(self):
487 return self.check
488
489 def get_count(self):
490 return self.count
491
492counter = Counter()
493";
494
495 let code = CString::new(code_raw).unwrap();
496 let filename = CString::new("test".to_string()).unwrap();
497 let module = CString::new("test".to_string()).unwrap();
498 Python::attach(|py| {
499 let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
500
501 let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
502 let handler = counter
503 .getattr(py, "handler")
504 .unwrap()
505 .into_py_any_unwrap(py);
506
507 (counter, handler)
508 })
509 }
510
511 #[tokio::test]
512 async fn basic_client_test() {
513 const N: usize = 10;
514
515 Python::initialize();
516
517 let mut success_count = 0;
518 let header_key = "hello-custom-key".to_string();
519 let header_value = "hello-custom-value".to_string();
520
521 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
522 let (counter, handler) = create_test_handler();
523
524 let config = WebSocketConfig::py_new(
525 format!("ws://127.0.0.1:{}", server.port),
526 vec![(header_key, header_value)],
527 None,
528 None,
529 None,
530 None,
531 None,
532 None,
533 None,
534 None,
535 );
536
537 let handler_clone = Python::attach(|py| handler.clone_ref(py));
538
539 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
540 Python::attach(|py| {
541 let data = match msg {
542 Message::Binary(data) => data.to_vec(),
543 Message::Text(text) => text.as_bytes().to_vec(),
544 _ => return,
545 };
546 let py_bytes = PyBytes::new(py, &data);
547 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
548 log::error!("Error calling handler: {e}");
549 }
550 });
551 });
552
553 let client =
554 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
555 .await
556 .unwrap();
557
558 for _ in 0..N {
559 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
560 success_count += 1;
561 }
562
563 sleep(Duration::from_secs(1)).await;
564 let count_value: usize = Python::attach(|py| {
565 counter
566 .getattr(py, "get_count")
567 .unwrap()
568 .call0(py)
569 .unwrap()
570 .extract(py)
571 .unwrap()
572 });
573 assert_eq!(count_value, success_count);
574
575 client.send_close_message().await.unwrap();
577
578 sleep(Duration::from_secs(2)).await;
580 for _ in 0..N {
581 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
582 success_count += 1;
583 }
584
585 sleep(Duration::from_secs(1)).await;
586 let count_value: usize = Python::attach(|py| {
587 counter
588 .getattr(py, "get_count")
589 .unwrap()
590 .call0(py)
591 .unwrap()
592 .extract(py)
593 .unwrap()
594 });
595 assert_eq!(count_value, success_count);
596 assert_eq!(success_count, N + N);
597
598 client.disconnect().await;
599 assert!(client.is_disconnected());
600 }
601
602 #[tokio::test]
603 async fn message_ping_test() {
604 Python::initialize();
605
606 let header_key = "hello-custom-key".to_string();
607 let header_value = "hello-custom-value".to_string();
608
609 let (checker, handler) = create_test_handler();
610
611 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
612 let config = WebSocketConfig::py_new(
613 format!("ws://127.0.0.1:{}", server.port),
614 vec![(header_key, header_value)],
615 Some(1),
616 Some("heartbeat message".to_string()),
617 None,
618 None,
619 None,
620 None,
621 None,
622 None,
623 );
624
625 let handler_clone = Python::attach(|py| handler.clone_ref(py));
626
627 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
628 Python::attach(|py| {
629 let data = match msg {
630 Message::Binary(data) => data.to_vec(),
631 Message::Text(text) => text.as_bytes().to_vec(),
632 _ => return,
633 };
634 let py_bytes = PyBytes::new(py, &data);
635 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
636 log::error!("Error calling handler: {e}");
637 }
638 });
639 });
640
641 let client =
642 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
643 .await
644 .unwrap();
645
646 sleep(Duration::from_secs(2)).await;
647 let check_value: bool = Python::attach(|py| {
648 checker
649 .getattr(py, "get_check")
650 .unwrap()
651 .call0(py)
652 .unwrap()
653 .extract(py)
654 .unwrap()
655 });
656 assert!(check_value);
657
658 client.disconnect().await;
659 assert!(client.is_disconnected());
660 }
661}