1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::to_pyvalue_err;
22use pyo3::{create_exception, exceptions::PyException, prelude::*};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 mode::ConnectionMode,
27 ratelimiter::quota::Quota,
28 websocket::{Consumer, WebSocketClient, WebSocketConfig, WriterCommand},
29};
30
31create_exception!(network, WebSocketClientError, PyException);
33
34fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
35 PyErr::new::<WebSocketClientError, _>(e.to_string())
36}
37
38#[pymethods]
39impl WebSocketConfig {
40 #[new]
41 #[allow(clippy::too_many_arguments)]
42 #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43 fn py_new(
44 url: String,
45 handler: PyObject,
46 headers: Vec<(String, String)>,
47 heartbeat: Option<u64>,
48 heartbeat_msg: Option<String>,
49 ping_handler: Option<PyObject>,
50 reconnect_timeout_ms: Option<u64>,
51 reconnect_delay_initial_ms: Option<u64>,
52 reconnect_delay_max_ms: Option<u64>,
53 reconnect_backoff_factor: Option<f64>,
54 reconnect_jitter_ms: Option<u64>,
55 ) -> Self {
56 Self {
57 url,
58 handler: Consumer::Python(Some(Arc::new(handler))),
59 headers,
60 heartbeat,
61 heartbeat_msg,
62 ping_handler: ping_handler.map(Arc::new),
63 reconnect_timeout_ms,
64 reconnect_delay_initial_ms,
65 reconnect_delay_max_ms,
66 reconnect_backoff_factor,
67 reconnect_jitter_ms,
68 }
69 }
70}
71
72#[pymethods]
73impl WebSocketClient {
74 #[staticmethod]
80 #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
81 fn py_connect(
82 config: WebSocketConfig,
83 post_connection: Option<PyObject>,
84 post_reconnection: Option<PyObject>,
85 post_disconnection: Option<PyObject>,
86 keyed_quotas: Vec<(String, Quota)>,
87 default_quota: Option<Quota>,
88 py: Python<'_>,
89 ) -> PyResult<Bound<PyAny>> {
90 pyo3_async_runtimes::tokio::future_into_py(py, async move {
91 Self::connect(
92 config,
93 post_connection,
94 post_reconnection,
95 post_disconnection,
96 keyed_quotas,
97 default_quota,
98 )
99 .await
100 .map_err(to_websocket_pyerr)
101 })
102 }
103
104 #[pyo3(name = "disconnect")]
114 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
115 let connection_mode = slf.connection_mode.clone();
116 let mode = ConnectionMode::from_atomic(&connection_mode);
117 tracing::debug!("Close from mode {mode}");
118
119 pyo3_async_runtimes::tokio::future_into_py(py, async move {
120 match ConnectionMode::from_atomic(&connection_mode) {
121 ConnectionMode::Closed => {
122 tracing::warn!("WebSocket already closed");
123 }
124 ConnectionMode::Disconnect => {
125 tracing::warn!("WebSocket already disconnecting");
126 }
127 _ => {
128 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
129 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
130 tokio::time::sleep(Duration::from_millis(10)).await;
131 }
132 }
133 }
134
135 Ok(())
136 })
137 }
138
139 #[pyo3(name = "is_active")]
149 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
150 !slf.controller_task.is_finished()
151 }
152
153 #[pyo3(name = "is_reconnecting")]
154 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
155 slf.is_reconnecting()
156 }
157
158 #[pyo3(name = "is_disconnecting")]
159 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
160 slf.is_disconnecting()
161 }
162
163 #[pyo3(name = "is_closed")]
164 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
165 slf.is_closed()
166 }
167
168 #[pyo3(name = "send")]
174 #[pyo3(signature = (data, keys=None))]
175 fn py_send<'py>(
176 slf: PyRef<'_, Self>,
177 data: Vec<u8>,
178 py: Python<'py>,
179 keys: Option<Vec<String>>,
180 ) -> PyResult<Bound<'py, PyAny>> {
181 let rate_limiter = slf.rate_limiter.clone();
182 let writer_tx = slf.writer_tx.clone();
183
184 pyo3_async_runtimes::tokio::future_into_py(py, async move {
185 rate_limiter.await_keys_ready(keys).await;
186 tracing::trace!("Sending binary: {data:?}");
187
188 let msg = Message::Binary(data.into());
189 if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
190 tracing::error!("{e}");
191 }
192 Ok(())
193 })
194 }
195
196 #[pyo3(name = "send_text")]
210 #[pyo3(signature = (data, keys=None))]
211 fn py_send_text<'py>(
212 slf: PyRef<'_, Self>,
213 data: Vec<u8>,
214 py: Python<'py>,
215 keys: Option<Vec<String>>,
216 ) -> PyResult<Bound<'py, PyAny>> {
217 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
218 let data = Utf8Bytes::from(data_str);
219 let rate_limiter = slf.rate_limiter.clone();
220 let writer_tx = slf.writer_tx.clone();
221
222 pyo3_async_runtimes::tokio::future_into_py(py, async move {
223 rate_limiter.await_keys_ready(keys).await;
224 tracing::trace!("Sending text: {data}");
225
226 let msg = Message::Text(data);
227 if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
228 tracing::error!("{e}");
229 }
230 Ok(())
231 })
232 }
233
234 #[pyo3(name = "send_pong")]
240 fn py_send_pong<'py>(
241 slf: PyRef<'_, Self>,
242 data: Vec<u8>,
243 py: Python<'py>,
244 ) -> PyResult<Bound<'py, PyAny>> {
245 let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
246 let writer_tx = slf.writer_tx.clone();
247 tracing::trace!("Sending pong: {data_str}");
248
249 pyo3_async_runtimes::tokio::future_into_py(py, async move {
250 let msg = Message::Pong(data.into());
251 if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
252 tracing::error!("{e}");
253 }
254 Ok(())
255 })
256 }
257}
258
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local WebSocket server bound to an ephemeral port. Its accept loop is aborted
    /// when the value is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback that asserts the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // A single lookup: the header must be present AND carry the expected value.
            let value = request
                .headers()
                .get(&self.key)
                .unwrap_or_else(|| panic!("missing expected header `{}`", self.key));
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Starts a server that accepts connections, checks the handshake header via
        /// [`TestCallback`], and handles each connection in its own task.
        ///
        /// Per connection:
        /// - text `"close-now"` closes the socket from the server side;
        /// - any other text or binary message is echoed back;
        /// - a close frame is answered with a close;
        /// - everything else is ignored.
        async fn setup(key: String, value: String) -> Self {
            use tokio_tungstenite::tungstenite::protocol::Message;

            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    // Echo; stop serving this connection once the peer is gone.
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler decodes incoming bytes. It increments `count` on `'ping'` and sets
    /// `check` on `'heartbeat message'`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Calls `counter.get_count()` and extracts the result.
    fn read_count(counter: &PyObject) -> usize {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Calls `checker.get_check()` and extracts the result.
    fn read_check(checker: &PyObject) -> bool {
        Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Each "ping" is echoed by the server and counted by the Python handler.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), success_count);

        client.send_close_message().await;

        // After the close and a pause, the next batch must still be delivered. This
        // relies on the client having recovered its connection.
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // The heartbeat text is echoed by the server. The handler flips `check` once it
        // sees it.
        sleep(Duration::from_secs(2)).await;
        assert!(read_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}