1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 RECONNECTED,
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{
30 WebSocketClient, WebSocketConfig,
31 types::{MessageHandler, PingHandler, WriterCommand},
32 },
33};
34
35create_exception!(network, WebSocketClientError, PyException);
36
37fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
38 PyErr::new::<WebSocketClientError, _>(e.to_string())
39}
40
41#[pymethods]
42impl WebSocketConfig {
43 #[new]
45 #[allow(clippy::too_many_arguments)]
46 #[pyo3(signature = (
47 url,
48 headers,
49 heartbeat=None,
50 heartbeat_msg=None,
51 reconnect_timeout_ms=10_000,
52 reconnect_delay_initial_ms=2_000,
53 reconnect_delay_max_ms=30_000,
54 reconnect_backoff_factor=1.5,
55 reconnect_jitter_ms=100,
56 reconnect_max_attempts=None,
57 ))]
58 fn py_new(
59 url: String,
60 headers: Vec<(String, String)>,
61 heartbeat: Option<u64>,
62 heartbeat_msg: Option<String>,
63 reconnect_timeout_ms: Option<u64>,
64 reconnect_delay_initial_ms: Option<u64>,
65 reconnect_delay_max_ms: Option<u64>,
66 reconnect_backoff_factor: Option<f64>,
67 reconnect_jitter_ms: Option<u64>,
68 reconnect_max_attempts: Option<u32>,
69 ) -> Self {
70 Self {
71 url,
72 headers,
73 heartbeat,
74 heartbeat_msg,
75 reconnect_timeout_ms,
76 reconnect_delay_initial_ms,
77 reconnect_delay_max_ms,
78 reconnect_backoff_factor,
79 reconnect_jitter_ms,
80 reconnect_max_attempts,
81 }
82 }
83}
84
85#[pymethods]
86impl WebSocketClient {
87 #[staticmethod]
99 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
100 #[allow(clippy::too_many_arguments)]
101 fn py_connect(
102 loop_: Py<PyAny>,
103 config: WebSocketConfig,
104 handler: Py<PyAny>,
105 ping_handler: Option<Py<PyAny>>,
106 post_reconnection: Option<Py<PyAny>>,
107 keyed_quotas: Vec<(String, Quota)>,
108 default_quota: Option<Quota>,
109 py: Python<'_>,
110 ) -> PyResult<Bound<'_, PyAny>> {
111 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
112 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
113 let handler_clone = clone_py_object(&handler);
114
115 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
116 if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
117 return;
118 }
119
120 Python::attach(|py| {
121 let py_bytes = match &msg {
122 Message::Binary(data) => PyBytes::new(py, data),
123 Message::Text(text) => PyBytes::new(py, text.as_bytes()),
124 _ => return,
125 };
126
127 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
128 log::error!("Error scheduling message handler on event loop: {e}");
129 }
130 });
131 });
132
133 let ping_handler_fn = ping_handler.map(|ping_handler| {
134 let ping_handler_clone = clone_py_object(&ping_handler);
135 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
136
137 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
138 Python::attach(|py| {
139 let py_bytes = PyBytes::new(py, &data);
140 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
141 log::error!("Error scheduling ping handler on event loop: {e}");
142 }
143 });
144 });
145 ping_handler_fn
146 });
147
148 let post_reconnection_fn = post_reconnection.map(|callback| {
149 let callback_clone = clone_py_object(&callback);
150 Arc::new(move || {
151 Python::attach(|py| {
152 if let Err(e) = callback_clone.call0(py) {
153 log::error!("Error calling post_reconnection handler: {e}");
154 }
155 });
156 }) as std::sync::Arc<dyn Fn() + Send + Sync>
157 });
158
159 pyo3_async_runtimes::tokio::future_into_py(py, async move {
160 Self::connect(
161 config,
162 Some(message_handler),
163 ping_handler_fn,
164 post_reconnection_fn,
165 keyed_quotas,
166 default_quota,
167 )
168 .await
169 .map_err(to_websocket_pyerr)
170 })
171 }
172
173 #[pyo3(name = "disconnect")]
183 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184 let connection_mode = slf.connection_mode.clone();
185 let mode = ConnectionMode::from_atomic(&connection_mode);
186 log::debug!("Close from mode {mode}");
187
188 pyo3_async_runtimes::tokio::future_into_py(py, async move {
189 match ConnectionMode::from_atomic(&connection_mode) {
190 ConnectionMode::Closed => {
191 log::debug!("WebSocket already closed");
192 }
193 ConnectionMode::Disconnect => {
194 log::debug!("WebSocket already disconnecting");
195 }
196 _ => {
197 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
198 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
199 tokio::time::sleep(Duration::from_millis(10)).await;
200 }
201 }
202 }
203
204 Ok(())
205 })
206 }
207
208 #[pyo3(name = "is_active")]
218 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
219 !slf.controller_task.is_finished()
220 }
221
222 #[pyo3(name = "is_reconnecting")]
223 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
224 slf.is_reconnecting()
225 }
226
227 #[pyo3(name = "is_disconnecting")]
228 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
229 slf.is_disconnecting()
230 }
231
232 #[pyo3(name = "is_closed")]
233 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
234 slf.is_closed()
235 }
236
237 #[pyo3(name = "send")]
243 #[pyo3(signature = (data, keys=None))]
244 fn py_send<'py>(
245 slf: PyRef<'_, Self>,
246 data: Vec<u8>,
247 py: Python<'py>,
248 keys: Option<Vec<String>>,
249 ) -> PyResult<Bound<'py, PyAny>> {
250 let rate_limiter = slf.rate_limiter.clone();
251 let writer_tx = slf.writer_tx.clone();
252 let mode = slf.connection_mode.clone();
253
254 pyo3_async_runtimes::tokio::future_into_py(py, async move {
255 if !ConnectionMode::from_atomic(&mode).is_active() {
256 let msg = "Cannot send data: connection not active".to_string();
257 log::error!("{msg}");
258 return Err(to_pyruntime_err(std::io::Error::new(
259 std::io::ErrorKind::NotConnected,
260 msg,
261 )));
262 }
263 rate_limiter.await_keys_ready(keys).await;
264 log::trace!("Sending binary: {data:?}");
265
266 let msg = Message::Binary(data.into());
267 writer_tx
268 .send(WriterCommand::Send(msg))
269 .map_err(to_pyruntime_err)
270 })
271 }
272
273 #[pyo3(name = "send_text")]
287 #[pyo3(signature = (data, keys=None))]
288 fn py_send_text<'py>(
289 slf: PyRef<'_, Self>,
290 data: Vec<u8>,
291 py: Python<'py>,
292 keys: Option<Vec<String>>,
293 ) -> PyResult<Bound<'py, PyAny>> {
294 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
295 let data = Utf8Bytes::from(data_str);
296 let rate_limiter = slf.rate_limiter.clone();
297 let writer_tx = slf.writer_tx.clone();
298 let mode = slf.connection_mode.clone();
299
300 pyo3_async_runtimes::tokio::future_into_py(py, async move {
301 if !ConnectionMode::from_atomic(&mode).is_active() {
302 let e = std::io::Error::new(
303 std::io::ErrorKind::NotConnected,
304 "Cannot send text: connection not active",
305 );
306 return Err(to_pyruntime_err(e));
307 }
308 rate_limiter.await_keys_ready(keys).await;
309 log::trace!("Sending text: {data}");
310
311 let msg = Message::Text(data);
312 writer_tx
313 .send(WriterCommand::Send(msg))
314 .map_err(to_pyruntime_err)
315 })
316 }
317
318 #[pyo3(name = "send_pong")]
324 fn py_send_pong<'py>(
325 slf: PyRef<'_, Self>,
326 data: Vec<u8>,
327 py: Python<'py>,
328 ) -> PyResult<Bound<'py, PyAny>> {
329 let writer_tx = slf.writer_tx.clone();
330 let mode = slf.connection_mode.clone();
331 let data_len = data.len();
332
333 pyo3_async_runtimes::tokio::future_into_py(py, async move {
334 if !ConnectionMode::from_atomic(&mode).is_active() {
335 let e = std::io::Error::new(
336 std::io::ErrorKind::NotConnected,
337 "Cannot send pong: connection not active",
338 );
339 return Err(to_pyruntime_err(e));
340 }
341 log::trace!("Sending pong frame ({data_len} bytes)");
342
343 let msg = Message::Pong(data.into());
344 writer_tx
345 .send(WriterCommand::Send(msg))
346 .map_err(to_pyruntime_err)
347 })
348 }
349}
350
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, types::PyBytes};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};

    /// Local echo server for the tests. Its accept loop is aborted when dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent the header `key: value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking is the purpose of this callback: a missing or wrong header must fail the test.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "handshake request missing expected `{}` header value",
                self.key,
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Starts an echo server on an ephemeral localhost port. Every handshake must carry
        /// the `key: value` header.
        ///
        /// Text and binary frames are echoed back. A text frame `close-now`, or a close frame
        /// from the client, makes the server close that connection.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    // One task per connection so reconnecting clients are served too.
                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler increments `count` on `ping` and sets `check` on `heartbeat message`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Wraps a Python callable as a [`MessageHandler`]. The callable is called directly with
    /// binary and text frames as `bytes`; all other frame types are ignored.
    fn python_message_handler(handler: &Py<PyAny>) -> MessageHandler {
        let handler = Python::attach(|py| handler.clone_ref(py));
        std::sync::Arc::new(move |msg: Message| {
            Python::attach(|py| {
                let data = match msg {
                    Message::Binary(data) => data.to_vec(),
                    Message::Text(text) => text.as_bytes().to_vec(),
                    _ => return,
                };
                let py_bytes = PyBytes::new(py, &data);
                if let Err(e) = handler.call1(py, (py_bytes,)) {
                    log::error!("Error calling handler: {e}");
                }
            });
        })
    }

    /// Reads `counter.get_count()` from Python.
    fn read_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Reads `checker.get_check()` from Python.
    fn read_check(checker: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    async fn basic_client_test() {
        const N: usize = 10;

        Python::initialize();

        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );

        let message_handler = python_message_handler(&handler);

        let client =
            WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
                .await
                .unwrap();

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), success_count);

        // Closing from the client side makes the server drop the connection, so the sends
        // below go through the client's reconnection path.
        client.send_close_message().await.unwrap();

        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn message_ping_test() {
        Python::initialize();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );

        let message_handler = python_message_handler(&handler);

        let client =
            WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
                .await
                .unwrap();

        // The echo server reflects the heartbeat text back to the handler.
        sleep(Duration::from_secs(2)).await;
        assert!(read_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}