1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 RECONNECTED,
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{
30 WebSocketClient, WebSocketConfig,
31 types::{MessageHandler, PingHandler, WriterCommand},
32 },
33};
34
35create_exception!(network, WebSocketClientError, PyException);
36
37fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
38 PyErr::new::<WebSocketClientError, _>(e.to_string())
39}
40
41#[pymethods]
42impl WebSocketConfig {
43 #[new]
45 #[allow(clippy::too_many_arguments)]
46 #[pyo3(signature = (
47 url,
48 headers,
49 heartbeat=None,
50 heartbeat_msg=None,
51 reconnect_timeout_ms=10_000,
52 reconnect_delay_initial_ms=2_000,
53 reconnect_delay_max_ms=30_000,
54 reconnect_backoff_factor=1.5,
55 reconnect_jitter_ms=100,
56 reconnect_max_attempts=None,
57 ))]
58 fn py_new(
59 url: String,
60 headers: Vec<(String, String)>,
61 heartbeat: Option<u64>,
62 heartbeat_msg: Option<String>,
63 reconnect_timeout_ms: Option<u64>,
64 reconnect_delay_initial_ms: Option<u64>,
65 reconnect_delay_max_ms: Option<u64>,
66 reconnect_backoff_factor: Option<f64>,
67 reconnect_jitter_ms: Option<u64>,
68 reconnect_max_attempts: Option<u32>,
69 ) -> Self {
70 Self {
71 url,
72 headers,
73 heartbeat,
74 heartbeat_msg,
75 reconnect_timeout_ms,
76 reconnect_delay_initial_ms,
77 reconnect_delay_max_ms,
78 reconnect_backoff_factor,
79 reconnect_jitter_ms,
80 reconnect_max_attempts,
81 }
82 }
83}
84
85#[pymethods]
86impl WebSocketClient {
87 #[staticmethod]
99 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
100 #[allow(clippy::too_many_arguments)]
101 fn py_connect(
102 loop_: Py<PyAny>,
103 config: WebSocketConfig,
104 handler: Py<PyAny>,
105 ping_handler: Option<Py<PyAny>>,
106 post_reconnection: Option<Py<PyAny>>,
107 keyed_quotas: Vec<(String, Quota)>,
108 default_quota: Option<Quota>,
109 py: Python<'_>,
110 ) -> PyResult<Bound<'_, PyAny>> {
111 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
112 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
113 let handler_clone = clone_py_object(&handler);
114
115 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
116 if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
117 return;
118 }
119
120 Python::attach(|py| {
121 let py_bytes = match &msg {
122 Message::Binary(data) => PyBytes::new(py, data),
123 Message::Text(text) => PyBytes::new(py, text.as_bytes()),
124 _ => return,
125 };
126
127 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
128 tracing::error!("Error scheduling message handler on event loop: {e}");
129 }
130 });
131 });
132
133 let ping_handler_fn = ping_handler.map(|ping_handler| {
134 let ping_handler_clone = clone_py_object(&ping_handler);
135 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
136
137 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
138 Python::attach(|py| {
139 let py_bytes = PyBytes::new(py, &data);
140 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
141 tracing::error!("Error scheduling ping handler on event loop: {e}");
142 }
143 });
144 });
145 ping_handler_fn
146 });
147
148 let post_reconnection_fn = post_reconnection.map(|callback| {
149 let callback_clone = clone_py_object(&callback);
150 Arc::new(move || {
151 Python::attach(|py| {
152 if let Err(e) = callback_clone.call0(py) {
153 tracing::error!("Error calling post_reconnection handler: {e}");
154 }
155 });
156 }) as std::sync::Arc<dyn Fn() + Send + Sync>
157 });
158
159 pyo3_async_runtimes::tokio::future_into_py(py, async move {
160 Self::connect(
161 config,
162 Some(message_handler),
163 ping_handler_fn,
164 post_reconnection_fn,
165 keyed_quotas,
166 default_quota,
167 )
168 .await
169 .map_err(to_websocket_pyerr)
170 })
171 }
172
173 #[pyo3(name = "disconnect")]
183 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184 let connection_mode = slf.connection_mode.clone();
185 let mode = ConnectionMode::from_atomic(&connection_mode);
186 tracing::debug!("Close from mode {mode}");
187
188 pyo3_async_runtimes::tokio::future_into_py(py, async move {
189 match ConnectionMode::from_atomic(&connection_mode) {
190 ConnectionMode::Closed => {
191 tracing::debug!("WebSocket already closed");
192 }
193 ConnectionMode::Disconnect => {
194 tracing::debug!("WebSocket already disconnecting");
195 }
196 _ => {
197 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
198 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
199 tokio::time::sleep(Duration::from_millis(10)).await;
200 }
201 }
202 }
203
204 Ok(())
205 })
206 }
207
208 #[pyo3(name = "is_active")]
218 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
219 !slf.controller_task.is_finished()
220 }
221
222 #[pyo3(name = "is_reconnecting")]
223 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
224 slf.is_reconnecting()
225 }
226
227 #[pyo3(name = "is_disconnecting")]
228 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
229 slf.is_disconnecting()
230 }
231
232 #[pyo3(name = "is_closed")]
233 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
234 slf.is_closed()
235 }
236
237 #[pyo3(name = "send")]
243 #[pyo3(signature = (data, keys=None))]
244 fn py_send<'py>(
245 slf: PyRef<'_, Self>,
246 data: Vec<u8>,
247 py: Python<'py>,
248 keys: Option<Vec<String>>,
249 ) -> PyResult<Bound<'py, PyAny>> {
250 let rate_limiter = slf.rate_limiter.clone();
251 let writer_tx = slf.writer_tx.clone();
252 let mode = slf.connection_mode.clone();
253
254 pyo3_async_runtimes::tokio::future_into_py(py, async move {
255 if !ConnectionMode::from_atomic(&mode).is_active() {
256 let msg = "Cannot send data: connection not active".to_string();
257 tracing::error!("{msg}");
258 return Err(to_pyruntime_err(std::io::Error::new(
259 std::io::ErrorKind::NotConnected,
260 msg,
261 )));
262 }
263 rate_limiter.await_keys_ready(keys).await;
264 tracing::trace!("Sending binary: {data:?}");
265
266 let msg = Message::Binary(data.into());
267 writer_tx
268 .send(WriterCommand::Send(msg))
269 .map_err(to_pyruntime_err)
270 })
271 }
272
273 #[pyo3(name = "send_text")]
287 #[pyo3(signature = (data, keys=None))]
288 fn py_send_text<'py>(
289 slf: PyRef<'_, Self>,
290 data: Vec<u8>,
291 py: Python<'py>,
292 keys: Option<Vec<String>>,
293 ) -> PyResult<Bound<'py, PyAny>> {
294 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
295 let data = Utf8Bytes::from(data_str);
296 let rate_limiter = slf.rate_limiter.clone();
297 let writer_tx = slf.writer_tx.clone();
298 let mode = slf.connection_mode.clone();
299
300 pyo3_async_runtimes::tokio::future_into_py(py, async move {
301 if !ConnectionMode::from_atomic(&mode).is_active() {
302 let e = std::io::Error::new(
303 std::io::ErrorKind::NotConnected,
304 "Cannot send text: connection not active",
305 );
306 return Err(to_pyruntime_err(e));
307 }
308 rate_limiter.await_keys_ready(keys).await;
309 tracing::trace!("Sending text: {data}");
310
311 let msg = Message::Text(data);
312 writer_tx
313 .send(WriterCommand::Send(msg))
314 .map_err(to_pyruntime_err)
315 })
316 }
317
318 #[pyo3(name = "send_pong")]
324 fn py_send_pong<'py>(
325 slf: PyRef<'_, Self>,
326 data: Vec<u8>,
327 py: Python<'py>,
328 ) -> PyResult<Bound<'py, PyAny>> {
329 let writer_tx = slf.writer_tx.clone();
330 let mode = slf.connection_mode.clone();
331 let data_len = data.len();
332
333 pyo3_async_runtimes::tokio::future_into_py(py, async move {
334 if !ConnectionMode::from_atomic(&mode).is_active() {
335 let e = std::io::Error::new(
336 std::io::ErrorKind::NotConnected,
337 "Cannot send pong: connection not active",
338 );
339 return Err(to_pyruntime_err(e));
340 }
341 tracing::trace!("Sending pong frame ({data_len} bytes)");
342
343 let msg = Message::Pong(data.into());
344 writer_tx
345 .send(WriterCommand::Send(msg))
346 .map_err(to_pyruntime_err)
347 })
348 }
349}
350
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, types::PyBytes};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};

    /// Local echo server for the tests. The accept loop task is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // The assertions below are how a missing or mismatched header fails the test.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .unwrap_or_else(|| panic!("missing expected header `{}`", self.key));
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and spawns an echo server.
        ///
        /// Every handshake is checked for header `key` with `value`. Text and binary messages
        /// are echoed back. The text message `close-now` or a close frame makes the server
        /// close that connection.
        async fn setup(key: String, value: String) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let callback = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Accept connections until the task is aborted; each one gets its own echo task.
                loop {
                    let (conn, _) = listener.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, callback.clone()).await.unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames back to the client.
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Creates a Python `Counter` instance and returns `(counter, counter.handler)`.
    ///
    /// The handler decodes its `bytes` argument. It increments the count for `ping` and sets
    /// the check flag for `heartbeat message`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Wraps a Python callable in a [`MessageHandler`] that calls it synchronously with text
    /// and binary payloads as `bytes`. Other message kinds are ignored.
    fn create_message_handler(handler: Py<PyAny>) -> MessageHandler {
        std::sync::Arc::new(move |msg: Message| {
            Python::attach(|py| {
                let data = match msg {
                    Message::Binary(data) => data.to_vec(),
                    Message::Text(text) => text.as_bytes().to_vec(),
                    _ => return,
                };
                let py_bytes = PyBytes::new(py, &data);
                if let Err(e) = handler.call1(py, (py_bytes,)) {
                    tracing::error!("Error calling handler: {e}");
                }
            });
        })
    }

    /// Builds a config for `ws://127.0.0.1:{port}` with the given headers and heartbeat
    /// settings. All reconnect settings are `None`.
    fn create_config(
        port: u16,
        headers: Vec<(String, String)>,
        heartbeat: Option<u64>,
        heartbeat_msg: Option<String>,
    ) -> WebSocketConfig {
        WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{port}"),
            headers,
            heartbeat,
            heartbeat_msg,
            None,
            None,
            None,
            None,
            None,
            None,
        )
    }

    /// Returns `counter.get_count()` from the Python test handler.
    fn counter_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `counter.get_check()` from the Python test handler.
    fn counter_check(counter: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            counter
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        Python::initialize();

        const N: usize = 10;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();
        let config = create_config(server.port, vec![(header_key, header_value)], None, None);

        let client = WebSocketClient::connect(
            config,
            Some(create_message_handler(handler)),
            None,
            None,
            vec![],
            None,
        )
        .await
        .unwrap();

        // First batch: every ping is echoed back and counted by the Python handler.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }
        sleep(Duration::from_secs(1)).await;
        assert_eq!(counter_count(&counter), N);

        // Send a close frame; the server closes its side of this connection in response.
        client.send_close_message().await.unwrap();
        sleep(Duration::from_secs(2)).await;

        // Second batch must also be sent and echoed back.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }
        sleep(Duration::from_secs(1)).await;
        assert_eq!(counter_count(&counter), N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        Python::initialize();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = create_config(
            server.port,
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
        );

        let client = WebSocketClient::connect(
            config,
            Some(create_message_handler(handler)),
            None,
            None,
            vec![],
            None,
        )
        .await
        .unwrap();

        // The heartbeat message is echoed by the server and flips the handler's check flag.
        sleep(Duration::from_secs(2)).await;
        assert!(counter_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}