1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 RECONNECTED,
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{
30 WebSocketClient, WebSocketConfig,
31 types::{MessageHandler, PingHandler, WriterCommand},
32 },
33};
34
35create_exception!(network, WebSocketClientError, PyException);
37
38fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
39 PyErr::new::<WebSocketClientError, _>(e.to_string())
40}
41
42#[pymethods]
43impl WebSocketConfig {
44 #[new]
45 #[allow(clippy::too_many_arguments)]
46 #[pyo3(signature = (
47 url,
48 handler,
49 headers,
50 heartbeat=None,
51 heartbeat_msg=None,
52 ping_handler=None,
53 reconnect_timeout_ms=10_000,
54 reconnect_delay_initial_ms=2_000,
55 reconnect_delay_max_ms=30_000,
56 reconnect_backoff_factor=1.5,
57 reconnect_jitter_ms=100,
58 reconnect_max_attempts=None,
59 ))]
60 fn py_new(
61 url: String,
62 handler: Py<PyAny>,
63 headers: Vec<(String, String)>,
64 heartbeat: Option<u64>,
65 heartbeat_msg: Option<String>,
66 ping_handler: Option<Py<PyAny>>,
67 reconnect_timeout_ms: Option<u64>,
68 reconnect_delay_initial_ms: Option<u64>,
69 reconnect_delay_max_ms: Option<u64>,
70 reconnect_backoff_factor: Option<f64>,
71 reconnect_jitter_ms: Option<u64>,
72 reconnect_max_attempts: Option<u32>,
73 ) -> Self {
74 let handler_clone = clone_py_object(&handler);
76 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
77 Python::attach(|py| {
78 let data = match msg {
79 Message::Binary(data) => data.to_vec(),
80 Message::Text(text) => {
81 if text == RECONNECTED {
82 return;
83 }
84 text.as_bytes().to_vec()
85 }
86 _ => return,
87 };
88 if let Err(e) = handler_clone.call1(py, (PyBytes::new(py, &data),)) {
89 tracing::error!("Error calling Python message handler: {e}");
90 }
91 });
92 });
93
94 let ping_handler_fn = ping_handler.map(|ping_handler| {
96 let ping_handler_clone = clone_py_object(&ping_handler);
97 let ping_handler_fn: PingHandler = std::sync::Arc::new(move |data: Vec<u8>| {
98 Python::attach(|py| {
99 if let Err(e) = ping_handler_clone.call1(py, (PyBytes::new(py, &data),)) {
100 tracing::error!("Error calling Python ping handler: {e}");
101 }
102 });
103 });
104 ping_handler_fn
105 });
106
107 Self {
108 url,
109 message_handler: Some(message_handler),
110 headers,
111 heartbeat,
112 heartbeat_msg,
113 ping_handler: ping_handler_fn,
114 reconnect_timeout_ms,
115 reconnect_delay_initial_ms,
116 reconnect_delay_max_ms,
117 reconnect_backoff_factor,
118 reconnect_jitter_ms,
119 reconnect_max_attempts,
120 }
121 }
122}
123
124#[pymethods]
125impl WebSocketClient {
126 #[staticmethod]
132 #[pyo3(name = "connect", signature = (config, post_reconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
133 fn py_connect(
134 config: WebSocketConfig,
135 post_reconnection: Option<Py<PyAny>>,
136 keyed_quotas: Vec<(String, Quota)>,
137 default_quota: Option<Quota>,
138 py: Python<'_>,
139 ) -> PyResult<Bound<'_, PyAny>> {
140 let post_reconnection_fn = post_reconnection.map(|callback| {
141 let callback_clone = clone_py_object(&callback);
142 Arc::new(move || {
143 Python::attach(|py| {
144 if let Err(e) = callback_clone.call0(py) {
145 tracing::error!("Error calling post_reconnection handler: {e}");
146 }
147 });
148 }) as std::sync::Arc<dyn Fn() + Send + Sync>
149 });
150
151 pyo3_async_runtimes::tokio::future_into_py(py, async move {
152 Self::connect(config, post_reconnection_fn, keyed_quotas, default_quota)
153 .await
154 .map_err(to_websocket_pyerr)
155 })
156 }
157
158 #[pyo3(name = "disconnect")]
168 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
169 let connection_mode = slf.connection_mode.clone();
170 let mode = ConnectionMode::from_atomic(&connection_mode);
171 tracing::debug!("Close from mode {mode}");
172
173 pyo3_async_runtimes::tokio::future_into_py(py, async move {
174 match ConnectionMode::from_atomic(&connection_mode) {
175 ConnectionMode::Closed => {
176 tracing::debug!("WebSocket already closed");
177 }
178 ConnectionMode::Disconnect => {
179 tracing::debug!("WebSocket already disconnecting");
180 }
181 _ => {
182 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
183 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
184 tokio::time::sleep(Duration::from_millis(10)).await;
185 }
186 }
187 }
188
189 Ok(())
190 })
191 }
192
193 #[pyo3(name = "is_active")]
203 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
204 !slf.controller_task.is_finished()
205 }
206
207 #[pyo3(name = "is_reconnecting")]
208 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
209 slf.is_reconnecting()
210 }
211
212 #[pyo3(name = "is_disconnecting")]
213 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
214 slf.is_disconnecting()
215 }
216
217 #[pyo3(name = "is_closed")]
218 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
219 slf.is_closed()
220 }
221
222 #[pyo3(name = "send")]
228 #[pyo3(signature = (data, keys=None))]
229 fn py_send<'py>(
230 slf: PyRef<'_, Self>,
231 data: Vec<u8>,
232 py: Python<'py>,
233 keys: Option<Vec<String>>,
234 ) -> PyResult<Bound<'py, PyAny>> {
235 let rate_limiter = slf.rate_limiter.clone();
236 let writer_tx = slf.writer_tx.clone();
237 let mode = slf.connection_mode.clone();
238
239 pyo3_async_runtimes::tokio::future_into_py(py, async move {
240 if !ConnectionMode::from_atomic(&mode).is_active() {
241 let msg = "Cannot send data: connection not active".to_string();
242 tracing::error!("{msg}");
243 return Err(to_pyruntime_err(std::io::Error::new(
244 std::io::ErrorKind::NotConnected,
245 msg,
246 )));
247 }
248 rate_limiter.await_keys_ready(keys).await;
249 tracing::trace!("Sending binary: {data:?}");
250
251 let msg = Message::Binary(data.into());
252 writer_tx
253 .send(WriterCommand::Send(msg))
254 .map_err(to_pyruntime_err)
255 })
256 }
257
258 #[pyo3(name = "send_text")]
272 #[pyo3(signature = (data, keys=None))]
273 fn py_send_text<'py>(
274 slf: PyRef<'_, Self>,
275 data: Vec<u8>,
276 py: Python<'py>,
277 keys: Option<Vec<String>>,
278 ) -> PyResult<Bound<'py, PyAny>> {
279 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
280 let data = Utf8Bytes::from(data_str);
281 let rate_limiter = slf.rate_limiter.clone();
282 let writer_tx = slf.writer_tx.clone();
283 let mode = slf.connection_mode.clone();
284
285 pyo3_async_runtimes::tokio::future_into_py(py, async move {
286 if !ConnectionMode::from_atomic(&mode).is_active() {
287 let e = std::io::Error::new(
288 std::io::ErrorKind::NotConnected,
289 "Cannot send text: connection not active",
290 );
291 return Err(to_pyruntime_err(e));
292 }
293 rate_limiter.await_keys_ready(keys).await;
294 tracing::trace!("Sending text: {data}");
295
296 let msg = Message::Text(data);
297 writer_tx
298 .send(WriterCommand::Send(msg))
299 .map_err(to_pyruntime_err)
300 })
301 }
302
303 #[pyo3(name = "send_pong")]
309 fn py_send_pong<'py>(
310 slf: PyRef<'_, Self>,
311 data: Vec<u8>,
312 py: Python<'py>,
313 ) -> PyResult<Bound<'py, PyAny>> {
314 let writer_tx = slf.writer_tx.clone();
315 let mode = slf.connection_mode.clone();
316 let data_len = data.len();
317
318 pyo3_async_runtimes::tokio::future_into_py(py, async move {
319 if !ConnectionMode::from_atomic(&mode).is_active() {
320 let e = std::io::Error::new(
321 std::io::ErrorKind::NotConnected,
322 "Cannot send pong: connection not active",
323 );
324 return Err(to_pyruntime_err(e));
325 }
326 tracing::trace!("Sending pong frame ({data_len} bytes)");
327
328 let msg = Message::Pong(data.into());
329 writer_tx
330 .send(WriterCommand::Send(msg))
331 .map_err(to_pyruntime_err)
332 })
333 }
334}
335
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prelude::*;
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server; the accept loop is aborted when this value is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` equal to `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking inside the handshake callback is how this test helper fails the test.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .expect("custom header missing from handshake request");
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and serves each connection in its own task.
        ///
        /// Text and binary frames are echoed back. A text frame `close-now` makes the server
        /// close the connection, and a close frame from the client is answered with a close.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames back so the client handler sees them.
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler counts `ping` messages and flags receipt of `heartbeat message`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Returns `counter.get_count()` from the Python test handler.
    fn get_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `checker.get_check()` from the Python test handler.
    fn get_check(checker: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        Python::initialize();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::attach(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Each `ping` is echoed by the server and counted by the Python handler.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);

        // Ask the server to close; the second batch below relies on the client having
        // re-established the connection during this sleep.
        client.send_close_message().await.unwrap();

        sleep(Duration::from_secs(2)).await;

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        Python::initialize();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::attach(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Wait long enough for at least one heartbeat to be sent and echoed back.
        sleep(Duration::from_secs(2)).await;
        assert!(get_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}