1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 RECONNECTED,
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{
30 WebSocketClient, WebSocketConfig,
31 types::{MessageHandler, PingHandler, WriterCommand},
32 },
33};
34
35create_exception!(network, WebSocketClientError, PyException);
37
38fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
39 PyErr::new::<WebSocketClientError, _>(e.to_string())
40}
41
42#[pymethods]
43impl WebSocketConfig {
44 #[new]
45 #[allow(clippy::too_many_arguments)]
46 #[pyo3(signature = (
47 url,
48 handler,
49 headers,
50 heartbeat=None,
51 heartbeat_msg=None,
52 ping_handler=None,
53 reconnect_timeout_ms=10_000,
54 reconnect_delay_initial_ms=2_000,
55 reconnect_delay_max_ms=30_000,
56 reconnect_backoff_factor=1.5,
57 reconnect_jitter_ms=100,
58 reconnect_max_attempts=None,
59 ))]
60 fn py_new(
61 url: String,
62 handler: Py<PyAny>,
63 headers: Vec<(String, String)>,
64 heartbeat: Option<u64>,
65 heartbeat_msg: Option<String>,
66 ping_handler: Option<Py<PyAny>>,
67 reconnect_timeout_ms: Option<u64>,
68 reconnect_delay_initial_ms: Option<u64>,
69 reconnect_delay_max_ms: Option<u64>,
70 reconnect_backoff_factor: Option<f64>,
71 reconnect_jitter_ms: Option<u64>,
72 reconnect_max_attempts: Option<u32>,
73 ) -> Self {
74 let handler_clone = clone_py_object(&handler);
76 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
77 Python::attach(|py| {
78 let data = match msg {
79 Message::Binary(data) => data.to_vec(),
80 Message::Text(text) => {
81 if text == RECONNECTED {
82 return;
83 }
84 text.as_bytes().to_vec()
85 }
86 _ => return,
87 };
88 if let Err(e) = handler_clone.call1(py, (PyBytes::new(py, &data),)) {
89 tracing::error!("Error calling Python message handler: {e}");
90 }
91 });
92 });
93
94 let ping_handler_fn = ping_handler.map(|ping_handler| {
96 let ping_handler_clone = clone_py_object(&ping_handler);
97 let ping_handler_fn: PingHandler = std::sync::Arc::new(move |data: Vec<u8>| {
98 Python::attach(|py| {
99 if let Err(e) = ping_handler_clone.call1(py, (PyBytes::new(py, &data),)) {
100 tracing::error!("Error calling Python ping handler: {e}");
101 }
102 });
103 });
104 ping_handler_fn
105 });
106
107 Self {
108 url,
109 message_handler: Some(message_handler),
110 headers,
111 heartbeat,
112 heartbeat_msg,
113 ping_handler: ping_handler_fn,
114 reconnect_timeout_ms,
115 reconnect_delay_initial_ms,
116 reconnect_delay_max_ms,
117 reconnect_backoff_factor,
118 reconnect_jitter_ms,
119 reconnect_max_attempts,
120 }
121 }
122}
123
124#[pymethods]
125impl WebSocketClient {
126 #[staticmethod]
132 #[pyo3(name = "connect", signature = (config, post_reconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
133 fn py_connect(
134 config: WebSocketConfig,
135 post_reconnection: Option<Py<PyAny>>,
136 keyed_quotas: Vec<(String, Quota)>,
137 default_quota: Option<Quota>,
138 py: Python<'_>,
139 ) -> PyResult<Bound<'_, PyAny>> {
140 let post_reconnection_fn = post_reconnection.map(|callback| {
141 let callback_clone = clone_py_object(&callback);
142 Arc::new(move || {
143 Python::attach(|py| {
144 if let Err(e) = callback_clone.call0(py) {
145 tracing::error!("Error calling post_reconnection handler: {e}");
146 }
147 });
148 }) as std::sync::Arc<dyn Fn() + Send + Sync>
149 });
150
151 pyo3_async_runtimes::tokio::future_into_py(py, async move {
152 Self::connect(config, post_reconnection_fn, keyed_quotas, default_quota)
153 .await
154 .map_err(to_websocket_pyerr)
155 })
156 }
157
158 #[pyo3(name = "disconnect")]
168 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
169 let connection_mode = slf.connection_mode.clone();
170 let mode = ConnectionMode::from_atomic(&connection_mode);
171 tracing::debug!("Close from mode {mode}");
172
173 pyo3_async_runtimes::tokio::future_into_py(py, async move {
174 match ConnectionMode::from_atomic(&connection_mode) {
175 ConnectionMode::Closed => {
176 tracing::debug!("WebSocket already closed");
177 }
178 ConnectionMode::Disconnect => {
179 tracing::debug!("WebSocket already disconnecting");
180 }
181 _ => {
182 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
183 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
184 tokio::time::sleep(Duration::from_millis(10)).await;
185 }
186 }
187 }
188
189 Ok(())
190 })
191 }
192
193 #[pyo3(name = "is_active")]
203 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
204 !slf.controller_task.is_finished()
205 }
206
207 #[pyo3(name = "is_reconnecting")]
208 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
209 slf.is_reconnecting()
210 }
211
212 #[pyo3(name = "is_disconnecting")]
213 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
214 slf.is_disconnecting()
215 }
216
217 #[pyo3(name = "is_closed")]
218 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
219 slf.is_closed()
220 }
221
222 #[pyo3(name = "send")]
228 #[pyo3(signature = (data, keys=None))]
229 fn py_send<'py>(
230 slf: PyRef<'_, Self>,
231 data: Vec<u8>,
232 py: Python<'py>,
233 keys: Option<Vec<String>>,
234 ) -> PyResult<Bound<'py, PyAny>> {
235 let rate_limiter = slf.rate_limiter.clone();
236 let writer_tx = slf.writer_tx.clone();
237 let mode = slf.connection_mode.clone();
238
239 pyo3_async_runtimes::tokio::future_into_py(py, async move {
240 if !ConnectionMode::from_atomic(&mode).is_active() {
241 let msg = "Cannot send data: connection not active".to_string();
242 tracing::error!("{msg}");
243 return Err(to_pyruntime_err(std::io::Error::new(
244 std::io::ErrorKind::NotConnected,
245 msg,
246 )));
247 }
248 rate_limiter.await_keys_ready(keys).await;
249 tracing::trace!("Sending binary: {data:?}");
250
251 let msg = Message::Binary(data.into());
252 writer_tx
253 .send(WriterCommand::Send(msg))
254 .map_err(to_pyruntime_err)
255 })
256 }
257
258 #[pyo3(name = "send_text")]
272 #[pyo3(signature = (data, keys=None))]
273 fn py_send_text<'py>(
274 slf: PyRef<'_, Self>,
275 data: Vec<u8>,
276 py: Python<'py>,
277 keys: Option<Vec<String>>,
278 ) -> PyResult<Bound<'py, PyAny>> {
279 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
280 let data = Utf8Bytes::from(data_str);
281 let rate_limiter = slf.rate_limiter.clone();
282 let writer_tx = slf.writer_tx.clone();
283 let mode = slf.connection_mode.clone();
284
285 pyo3_async_runtimes::tokio::future_into_py(py, async move {
286 if !ConnectionMode::from_atomic(&mode).is_active() {
287 let e = std::io::Error::new(
288 std::io::ErrorKind::NotConnected,
289 "Cannot send text: connection not active",
290 );
291 return Err(to_pyruntime_err(e));
292 }
293 rate_limiter.await_keys_ready(keys).await;
294 tracing::trace!("Sending text: {data}");
295
296 let msg = Message::Text(data);
297 writer_tx
298 .send(WriterCommand::Send(msg))
299 .map_err(to_pyruntime_err)
300 })
301 }
302
303 #[pyo3(name = "send_pong")]
309 fn py_send_pong<'py>(
310 slf: PyRef<'_, Self>,
311 data: Vec<u8>,
312 py: Python<'py>,
313 ) -> PyResult<Bound<'py, PyAny>> {
314 let writer_tx = slf.writer_tx.clone();
315 let mode = slf.connection_mode.clone();
316 let data_len = data.len();
317
318 pyo3_async_runtimes::tokio::future_into_py(py, async move {
319 if !ConnectionMode::from_atomic(&mode).is_active() {
320 let e = std::io::Error::new(
321 std::io::ErrorKind::NotConnected,
322 "Cannot send pong: connection not active",
323 );
324 return Err(to_pyruntime_err(e));
325 }
326 tracing::trace!("Sending pong frame ({data_len} bytes)");
327
328 let msg = Message::Pong(data.into());
329 writer_tx
330 .send(WriterCommand::Send(msg))
331 .map_err(to_pyruntime_err)
332 })
333 }
334}
335
#[cfg(test)]
#[cfg(target_os = "linux")] // These socket-based tests are only compiled on Linux
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prelude::*;
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server; its accept loop is aborted when the value is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that a given header carries a given value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        /// Panics unless the request contains header `key` equal to `value`;
        /// otherwise passes the response through untouched.
        #[allow(clippy::panic_in_result_fn)] // Assertions are the point of this test callback
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert!(value.is_some());

            if let Some(value) = value {
                assert_eq!(value, self.value);
            }

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds to an OS-assigned port on localhost and spawns the accept loop.
        ///
        /// Each accepted connection must present header `key` with `value`.
        /// Text and binary frames are echoed back; the text `close-now` or a
        /// close frame makes the server close the connection.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting so a client can connect more than once
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    // Serve each connection on its own task
                    task::spawn(async move {
                        loop {
                            // Stop on end of stream or on the first read error
                            let Some(Ok(msg)) = websocket.next().await else {
                                break;
                            };

                            let close_requested =
                                matches!(&msg, Message::Text(txt) if *txt == "close-now");
                            if close_requested {
                                tracing::debug!("Forcibly closing from server side");
                                let _ = websocket.close(None).await;
                                break;
                            }

                            match msg {
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    // Echo the frame back; give up if the peer is gone
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // Other frame kinds are ignored
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        /// Aborts the accept loop task.
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` object and returns `(counter, counter.handler)`.
    ///
    /// The handler increments `count` for each `ping` payload and sets `check`
    /// when it sees `heartbeat message`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();

        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Calls `get_count()` on the Python counter object.
    fn read_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Calls `get_check()` on the Python counter object.
    fn read_check(checker: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Sends `n` binary `ping` payloads, panicking on the first failure, and
    /// returns the number sent.
    async fn send_pings(client: &WebSocketClient, n: usize) -> usize {
        for _ in 0..n {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }
        n
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        Python::initialize();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::attach(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // First batch: every echoed ping should bump the Python counter
        success_count += send_pings(&client, N).await;

        sleep(Duration::from_secs(1)).await;
        let count_value: usize = read_count(&counter);
        assert_eq!(count_value, success_count);

        // Send a close message, then give the client time before sending again
        client.send_close_message().await.unwrap();

        sleep(Duration::from_secs(2)).await;

        // Second batch, sent after the close message
        success_count += send_pings(&client, N).await;

        sleep(Duration::from_secs(1)).await;
        let count_value: usize = read_count(&counter);
        assert_eq!(count_value, success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        Python::initialize();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::attach(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Wait long enough for the heartbeat text to be echoed back to the handler
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = read_check(&checker);
        assert!(check_value);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}