1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 RECONNECTED,
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{MessageHandler, PingHandler, WebSocketClient, WebSocketConfig, WriterCommand},
30};
31
32create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36 PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41 #[new]
42 #[allow(clippy::too_many_arguments)]
43 #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
44 fn py_new(
45 url: String,
46 handler: PyObject,
47 headers: Vec<(String, String)>,
48 heartbeat: Option<u64>,
49 heartbeat_msg: Option<String>,
50 ping_handler: Option<PyObject>,
51 reconnect_timeout_ms: Option<u64>,
52 reconnect_delay_initial_ms: Option<u64>,
53 reconnect_delay_max_ms: Option<u64>,
54 reconnect_backoff_factor: Option<f64>,
55 reconnect_jitter_ms: Option<u64>,
56 ) -> Self {
57 let handler_clone = clone_py_object(&handler);
59 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
60 Python::with_gil(|py| {
61 let data = match msg {
62 Message::Binary(data) => data.to_vec(),
63 Message::Text(text) => {
64 if text == RECONNECTED {
66 return;
67 }
68 text.as_bytes().to_vec()
69 }
70 _ => return, };
72 if let Err(e) = handler_clone.call1(py, (PyBytes::new(py, &data),)) {
73 tracing::error!("Error calling Python message handler: {e}");
74 }
75 });
76 });
77
78 let ping_handler_fn = ping_handler.map(|ping_handler| {
80 let ping_handler_clone = clone_py_object(&ping_handler);
81 let ping_handler_fn: PingHandler = std::sync::Arc::new(move |data: Vec<u8>| {
82 Python::with_gil(|py| {
83 if let Err(e) = ping_handler_clone.call1(py, (PyBytes::new(py, &data),)) {
84 tracing::error!("Error calling Python ping handler: {e}");
85 }
86 });
87 });
88 ping_handler_fn
89 });
90
91 Self {
92 url,
93 message_handler: Some(message_handler),
94 headers,
95 heartbeat,
96 heartbeat_msg,
97 ping_handler: ping_handler_fn,
98 reconnect_timeout_ms,
99 reconnect_delay_initial_ms,
100 reconnect_delay_max_ms,
101 reconnect_backoff_factor,
102 reconnect_jitter_ms,
103 }
104 }
105}
106
107#[pymethods]
108impl WebSocketClient {
109 #[staticmethod]
115 #[pyo3(name = "connect", signature = (config, post_reconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
116 fn py_connect(
117 config: WebSocketConfig,
118 post_reconnection: Option<PyObject>,
119 keyed_quotas: Vec<(String, Quota)>,
120 default_quota: Option<Quota>,
121 py: Python<'_>,
122 ) -> PyResult<Bound<'_, PyAny>> {
123 let post_reconnection_fn = post_reconnection.map(|callback| {
125 let callback_clone = clone_py_object(&callback);
126 Arc::new(move || {
127 Python::with_gil(|py| {
128 if let Err(e) = callback_clone.call0(py) {
129 tracing::error!("Error calling post_reconnection handler: {e}");
130 }
131 });
132 }) as std::sync::Arc<dyn Fn() + Send + Sync>
133 });
134
135 pyo3_async_runtimes::tokio::future_into_py(py, async move {
136 Self::connect(config, post_reconnection_fn, keyed_quotas, default_quota)
137 .await
138 .map_err(to_websocket_pyerr)
139 })
140 }
141
142 #[pyo3(name = "disconnect")]
152 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
153 let connection_mode = slf.connection_mode.clone();
154 let mode = ConnectionMode::from_atomic(&connection_mode);
155 tracing::debug!("Close from mode {mode}");
156
157 pyo3_async_runtimes::tokio::future_into_py(py, async move {
158 match ConnectionMode::from_atomic(&connection_mode) {
159 ConnectionMode::Closed => {
160 tracing::debug!("WebSocket already closed");
161 }
162 ConnectionMode::Disconnect => {
163 tracing::debug!("WebSocket already disconnecting");
164 }
165 _ => {
166 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
167 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
168 tokio::time::sleep(Duration::from_millis(10)).await;
169 }
170 }
171 }
172
173 Ok(())
174 })
175 }
176
177 #[pyo3(name = "is_active")]
187 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
188 !slf.controller_task.is_finished()
189 }
190
191 #[pyo3(name = "is_reconnecting")]
192 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
193 slf.is_reconnecting()
194 }
195
196 #[pyo3(name = "is_disconnecting")]
197 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
198 slf.is_disconnecting()
199 }
200
201 #[pyo3(name = "is_closed")]
202 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
203 slf.is_closed()
204 }
205
206 #[pyo3(name = "send")]
212 #[pyo3(signature = (data, keys=None))]
213 fn py_send<'py>(
214 slf: PyRef<'_, Self>,
215 data: Vec<u8>,
216 py: Python<'py>,
217 keys: Option<Vec<String>>,
218 ) -> PyResult<Bound<'py, PyAny>> {
219 let rate_limiter = slf.rate_limiter.clone();
220 let writer_tx = slf.writer_tx.clone();
221 let mode = slf.connection_mode.clone();
222
223 pyo3_async_runtimes::tokio::future_into_py(py, async move {
224 if !ConnectionMode::from_atomic(&mode).is_active() {
225 let msg = "Cannot send data: connection not active".to_string();
226 tracing::error!("{msg}");
227 return Err(to_pyruntime_err(std::io::Error::new(
228 std::io::ErrorKind::NotConnected,
229 msg,
230 )));
231 }
232 rate_limiter.await_keys_ready(keys).await;
233 tracing::trace!("Sending binary: {data:?}");
234
235 let msg = Message::Binary(data.into());
236 writer_tx
237 .send(WriterCommand::Send(msg))
238 .map_err(to_pyruntime_err)
239 })
240 }
241
242 #[pyo3(name = "send_text")]
256 #[pyo3(signature = (data, keys=None))]
257 fn py_send_text<'py>(
258 slf: PyRef<'_, Self>,
259 data: Vec<u8>,
260 py: Python<'py>,
261 keys: Option<Vec<String>>,
262 ) -> PyResult<Bound<'py, PyAny>> {
263 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
264 let data = Utf8Bytes::from(data_str);
265 let rate_limiter = slf.rate_limiter.clone();
266 let writer_tx = slf.writer_tx.clone();
267 let mode = slf.connection_mode.clone();
268
269 pyo3_async_runtimes::tokio::future_into_py(py, async move {
270 if !ConnectionMode::from_atomic(&mode).is_active() {
271 let err = std::io::Error::new(
272 std::io::ErrorKind::NotConnected,
273 "Cannot send text: connection not active",
274 );
275 return Err(to_pyruntime_err(err));
276 }
277 rate_limiter.await_keys_ready(keys).await;
278 tracing::trace!("Sending text: {data}");
279
280 let msg = Message::Text(data);
281 writer_tx
282 .send(WriterCommand::Send(msg))
283 .map_err(to_pyruntime_err)
284 })
285 }
286
287 #[pyo3(name = "send_pong")]
293 fn py_send_pong<'py>(
294 slf: PyRef<'_, Self>,
295 data: Vec<u8>,
296 py: Python<'py>,
297 ) -> PyResult<Bound<'py, PyAny>> {
298 let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
299 let writer_tx = slf.writer_tx.clone();
300 let mode = slf.connection_mode.clone();
301
302 pyo3_async_runtimes::tokio::future_into_py(py, async move {
303 if !ConnectionMode::from_atomic(&mode).is_active() {
304 let err = std::io::Error::new(
305 std::io::ErrorKind::NotConnected,
306 "Cannot send pong: connection not active",
307 );
308 return Err(to_pyruntime_err(err));
309 }
310 tracing::trace!("Sending pong: {data_str}");
311
312 let msg = Message::Pong(data.into());
313 writer_tx
314 .send(WriterCommand::Send(msg))
315 .map_err(to_pyruntime_err)
316 })
317 }
318}
319
#[cfg(test)]
#[cfg(target_os = "linux")] // These network round-trip tests are only compiled on Linux.
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server on an ephemeral `127.0.0.1` port; its accept task is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback that asserts a specific request header has the expected value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .unwrap_or_else(|| panic!("missing handshake header `{}`", self.key));
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Starts an echo server that checks every handshake for the header `key: value`.
        ///
        /// Text and binary frames are echoed back. A `close-now` text frame or a close frame
        /// from the client makes the server close that connection.
        async fn setup(key: String, value: String) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let callback = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = listener.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, callback.clone()).await.unwrap();

                    // Serve each connection on its own task so the accept loop keeps
                    // taking new (e.g. reconnecting) clients.
                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Loads a Python `Counter` object and returns `(counter, counter.handler)`.
    ///
    /// The handler increments `count` for each `ping` message and sets `check` when it sees
    /// `heartbeat message`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Calls the zero-argument Python method `name` on `obj` and extracts the result as `T`.
    fn call_getter<T>(obj: &PyObject, name: &str) -> T
    where
        T: for<'py> FromPyObject<'py>,
    {
        Python::with_gil(|py| {
            obj.getattr(py, name)
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Each `ping` is echoed by the server back into the Python handler.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Allow the echoes to arrive before checking the Python-side count.
        sleep(Duration::from_secs(1)).await;
        assert_eq!(call_getter::<usize>(&counter, "get_count"), success_count);

        // Send a close message, then give the client time to become usable again.
        // The sends below expect the connection to be re-established.
        client.send_close_message().await.unwrap();
        sleep(Duration::from_secs(2)).await;

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(call_getter::<usize>(&counter, "get_count"), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // With a 1-second heartbeat, the echoed heartbeat message should reach the handler.
        sleep(Duration::from_secs(2)).await;
        assert!(call_getter::<bool>(&checker, "get_check"));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}