1use std::{
17 sync::{atomic::Ordering, Arc},
18 time::Duration,
19};
20
21use futures::SinkExt;
22use nautilus_core::python::to_pyvalue_err;
23use pyo3::{create_exception, exceptions::PyException, prelude::*};
24use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
25
26use crate::{
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{WebSocketClient, WebSocketConfig},
30};
31
// Declares the Python exception type `WebSocketClientError` (a subclass of Python's
// `Exception`, with `network` as its module name). `to_websocket_pyerr` raises it for
// WebSocket transport errors.
create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36 PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41 #[new]
42 #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43 #[allow(clippy::too_many_arguments)]
44 fn py_new(
45 url: String,
46 handler: PyObject,
47 headers: Vec<(String, String)>,
48 heartbeat: Option<u64>,
49 heartbeat_msg: Option<String>,
50 ping_handler: Option<PyObject>,
51 reconnect_timeout_ms: Option<u64>,
52 reconnect_delay_initial_ms: Option<u64>,
53 reconnect_delay_max_ms: Option<u64>,
54 reconnect_backoff_factor: Option<f64>,
55 reconnect_jitter_ms: Option<u64>,
56 ) -> Self {
57 Self {
58 url,
59 handler: Some(Arc::new(handler)),
60 headers,
61 heartbeat,
62 heartbeat_msg,
63 ping_handler: ping_handler.map(Arc::new),
64 reconnect_timeout_ms,
65 reconnect_delay_initial_ms,
66 reconnect_delay_max_ms,
67 reconnect_backoff_factor,
68 reconnect_jitter_ms,
69 }
70 }
71}
72
73#[pymethods]
74impl WebSocketClient {
75 #[staticmethod]
81 #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
82 fn py_connect(
83 config: WebSocketConfig,
84 post_connection: Option<PyObject>,
85 post_reconnection: Option<PyObject>,
86 post_disconnection: Option<PyObject>,
87 keyed_quotas: Vec<(String, Quota)>,
88 default_quota: Option<Quota>,
89 py: Python<'_>,
90 ) -> PyResult<Bound<PyAny>> {
91 pyo3_async_runtimes::tokio::future_into_py(py, async move {
92 Self::connect(
93 config,
94 post_connection,
95 post_reconnection,
96 post_disconnection,
97 keyed_quotas,
98 default_quota,
99 )
100 .await
101 .map_err(to_websocket_pyerr)
102 })
103 }
104
105 #[pyo3(name = "disconnect")]
115 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116 let connection_mode = slf.connection_mode.clone();
117 let mode = ConnectionMode::from_atomic(&connection_mode);
118 tracing::debug!("Close from mode {mode}");
119
120 pyo3_async_runtimes::tokio::future_into_py(py, async move {
121 match ConnectionMode::from_atomic(&connection_mode) {
122 ConnectionMode::Closed => {
123 tracing::warn!("WebSocket already closed");
124 }
125 ConnectionMode::Disconnect => {
126 tracing::warn!("WebSocket already disconnecting");
127 }
128 _ => {
129 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
130 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
131 tokio::time::sleep(Duration::from_millis(10)).await;
132 }
133 }
134 }
135
136 Ok(())
137 })
138 }
139
140 #[pyo3(name = "is_active")]
150 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151 !slf.controller_task.is_finished()
152 }
153
154 #[pyo3(name = "is_reconnecting")]
155 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156 slf.is_reconnecting()
157 }
158
159 #[pyo3(name = "is_disconnecting")]
160 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161 slf.is_disconnecting()
162 }
163
164 #[pyo3(name = "is_closed")]
165 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166 slf.is_closed()
167 }
168
169 #[pyo3(name = "send")]
175 #[pyo3(signature = (data, keys=None))]
176 fn py_send<'py>(
177 slf: PyRef<'_, Self>,
178 data: Vec<u8>,
179 py: Python<'py>,
180 keys: Option<Vec<String>>,
181 ) -> PyResult<Bound<'py, PyAny>> {
182 let writer = slf.writer.clone();
183 let rate_limiter = slf.rate_limiter.clone();
184
185 pyo3_async_runtimes::tokio::future_into_py(py, async move {
186 rate_limiter.await_keys_ready(keys).await;
187 tracing::trace!("Sending binary: {data:?}");
188
189 let mut guard = writer.lock().await;
190 guard
191 .send(Message::Binary(data.into()))
192 .await
193 .map_err(to_websocket_pyerr)
194 })
195 }
196
197 #[pyo3(name = "send_text")]
211 #[pyo3(signature = (data, keys=None))]
212 fn py_send_text<'py>(
213 slf: PyRef<'_, Self>,
214 data: Vec<u8>,
215 py: Python<'py>,
216 keys: Option<Vec<String>>,
217 ) -> PyResult<Bound<'py, PyAny>> {
218 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
219 let data = Utf8Bytes::from(data_str);
220 let writer = slf.writer.clone();
221 let rate_limiter = slf.rate_limiter.clone();
222
223 pyo3_async_runtimes::tokio::future_into_py(py, async move {
224 rate_limiter.await_keys_ready(keys).await;
225 tracing::trace!("Sending text: {data}");
226
227 let mut guard = writer.lock().await;
228 guard
229 .send(Message::Text(data))
230 .await
231 .map_err(to_websocket_pyerr)
232 })
233 }
234
235 #[pyo3(name = "send_pong")]
241 fn py_send_pong<'py>(
242 slf: PyRef<'_, Self>,
243 data: Vec<u8>,
244 py: Python<'py>,
245 ) -> PyResult<Bound<'py, PyAny>> {
246 let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
247 let writer = slf.writer.clone();
248 tracing::trace!("Sending pong: {data_str}");
249
250 pyo3_async_runtimes::tokio::future_into_py(py, async move {
251 let mut guard = writer.lock().await;
252 guard
253 .send(Message::Pong(data.into()))
254 .await
255 .map_err(to_websocket_pyerr)
256 })
257 }
258}
259
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{sleep, Duration},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local WebSocket test server on an ephemeral `127.0.0.1` port.
    ///
    /// Dropping it aborts the accept loop. Per-connection tasks spawned by the loop are not
    /// tracked here.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback that asserts the request carries header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // A single lookup checks both presence and value of the expected header.
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected `{}` header",
                self.key
            );

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds the listener and spawns an accept loop.
        ///
        /// Every handshake must carry header `key: value`. On each accepted connection:
        /// text/binary frames are echoed back, a `"close-now"` text frame makes the server
        /// close the connection, and a client close frame is answered with a close.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Accept connections until this task is aborted (see `Drop`).
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames back to the client.
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Creates a Python `Counter` instance and returns `(counter, counter.handler)`.
    ///
    /// The handler decodes each message. A `'ping'` message increments `count`, and a
    /// `'heartbeat message'` message sets `check`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r#"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
"#;

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().unbind();
            let handler = counter.getattr(py, "handler").unwrap();

            (counter, handler)
        })
    }

    /// Sends `N` pings, closes from the client side, then sends `N` more.
    ///
    /// The echo server bounces every ping back to the handler, so the handler's count must
    /// match the number of successful sends. The second batch is expected to succeed too,
    /// which requires the client to be usable again after the close.
    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // First batch of pings.
        for _ in 0..N {
            if client.send_bytes(b"ping".to_vec(), None).await.is_ok() {
                success_count += 1;
            };
        }

        // Give the echoed messages time to reach the handler.
        sleep(Duration::from_secs(1)).await;
        let count_value: usize = Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert_eq!(count_value, success_count);

        // Close from the client side, then wait before sending again.
        client.send_close_message().await;

        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            if client.send_bytes(b"ping".to_vec(), None).await.is_ok() {
                success_count += 1;
            };
        }

        sleep(Duration::from_secs(1)).await;
        let count_value: usize = Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert_eq!(count_value, success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Configures a heartbeat with message `"heartbeat message"` and checks that the
    /// message, echoed back by the server, reaches the handler.
    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            // NOTE(review): heartbeat interval, presumably in seconds — confirm.
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Wait long enough for at least one heartbeat round-trip.
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert!(check_value);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
539}