1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use futures::SinkExt;
22use nautilus_core::python::to_pyvalue_err;
23use pyo3::{create_exception, exceptions::PyException, prelude::*};
24use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
25
26use crate::{
27 mode::ConnectionMode,
28 ratelimiter::quota::Quota,
29 websocket::{WebSocketClient, WebSocketConfig},
30};
31
32create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36 PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41 #[new]
42 #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43 #[allow(clippy::too_many_arguments)]
44 fn py_new(
45 url: String,
46 handler: PyObject,
47 headers: Vec<(String, String)>,
48 heartbeat: Option<u64>,
49 heartbeat_msg: Option<String>,
50 ping_handler: Option<PyObject>,
51 reconnect_timeout_ms: Option<u64>,
52 reconnect_delay_initial_ms: Option<u64>,
53 reconnect_delay_max_ms: Option<u64>,
54 reconnect_backoff_factor: Option<f64>,
55 reconnect_jitter_ms: Option<u64>,
56 ) -> Self {
57 Self {
58 url,
59 handler: Some(Arc::new(handler)),
60 headers,
61 heartbeat,
62 heartbeat_msg,
63 ping_handler: ping_handler.map(Arc::new),
64 reconnect_timeout_ms,
65 reconnect_delay_initial_ms,
66 reconnect_delay_max_ms,
67 reconnect_backoff_factor,
68 reconnect_jitter_ms,
69 }
70 }
71}
72
73#[pymethods]
74impl WebSocketClient {
75 #[staticmethod]
81 #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
82 fn py_connect(
83 config: WebSocketConfig,
84 post_connection: Option<PyObject>,
85 post_reconnection: Option<PyObject>,
86 post_disconnection: Option<PyObject>,
87 keyed_quotas: Vec<(String, Quota)>,
88 default_quota: Option<Quota>,
89 py: Python<'_>,
90 ) -> PyResult<Bound<PyAny>> {
91 pyo3_async_runtimes::tokio::future_into_py(py, async move {
92 Self::connect(
93 config,
94 post_connection,
95 post_reconnection,
96 post_disconnection,
97 keyed_quotas,
98 default_quota,
99 )
100 .await
101 .map_err(to_websocket_pyerr)
102 })
103 }
104
105 #[pyo3(name = "disconnect")]
115 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116 let connection_mode = slf.connection_mode.clone();
117 let mode = ConnectionMode::from_atomic(&connection_mode);
118 tracing::debug!("Close from mode {mode}");
119
120 pyo3_async_runtimes::tokio::future_into_py(py, async move {
121 match ConnectionMode::from_atomic(&connection_mode) {
122 ConnectionMode::Closed => {
123 tracing::warn!("WebSocket already closed");
124 }
125 ConnectionMode::Disconnect => {
126 tracing::warn!("WebSocket already disconnecting");
127 }
128 _ => {
129 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
130 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
131 tokio::time::sleep(Duration::from_millis(10)).await;
132 }
133 }
134 }
135
136 Ok(())
137 })
138 }
139
140 #[pyo3(name = "is_active")]
150 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151 !slf.controller_task.is_finished()
152 }
153
154 #[pyo3(name = "is_reconnecting")]
155 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156 slf.is_reconnecting()
157 }
158
159 #[pyo3(name = "is_disconnecting")]
160 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161 slf.is_disconnecting()
162 }
163
164 #[pyo3(name = "is_closed")]
165 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166 slf.is_closed()
167 }
168
169 #[pyo3(name = "send")]
175 #[pyo3(signature = (data, keys=None))]
176 fn py_send<'py>(
177 slf: PyRef<'_, Self>,
178 data: Vec<u8>,
179 py: Python<'py>,
180 keys: Option<Vec<String>>,
181 ) -> PyResult<Bound<'py, PyAny>> {
182 let writer = slf.writer.clone();
183 let rate_limiter = slf.rate_limiter.clone();
184
185 pyo3_async_runtimes::tokio::future_into_py(py, async move {
186 rate_limiter.await_keys_ready(keys).await;
187 tracing::trace!("Sending binary: {data:?}");
188
189 let mut guard = writer.lock().await;
190 guard
191 .send(Message::Binary(data.into()))
192 .await
193 .map_err(to_websocket_pyerr)
194 })
195 }
196
197 #[pyo3(name = "send_text")]
211 #[pyo3(signature = (data, keys=None))]
212 fn py_send_text<'py>(
213 slf: PyRef<'_, Self>,
214 data: Vec<u8>,
215 py: Python<'py>,
216 keys: Option<Vec<String>>,
217 ) -> PyResult<Bound<'py, PyAny>> {
218 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
219 let data = Utf8Bytes::from(data_str);
220 let writer = slf.writer.clone();
221 let rate_limiter = slf.rate_limiter.clone();
222
223 pyo3_async_runtimes::tokio::future_into_py(py, async move {
224 rate_limiter.await_keys_ready(keys).await;
225 tracing::trace!("Sending text: {data}");
226
227 let mut guard = writer.lock().await;
228 guard
229 .send(Message::Text(data))
230 .await
231 .map_err(to_websocket_pyerr)
232 })
233 }
234
235 #[pyo3(name = "send_pong")]
241 fn py_send_pong<'py>(
242 slf: PyRef<'_, Self>,
243 data: Vec<u8>,
244 py: Python<'py>,
245 ) -> PyResult<Bound<'py, PyAny>> {
246 let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
247 let writer = slf.writer.clone();
248 tracing::trace!("Sending pong: {data_str}");
249
250 pyo3_async_runtimes::tokio::future_into_py(py, async move {
251 let mut guard = writer.lock().await;
252 guard
253 .send(Message::Pong(data.into()))
254 .await
255 .map_err(to_websocket_pyerr)
256 })
257 }
258}
259
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server on an ephemeral port; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting the client sent header `key` with exactly `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // A single lookup covers both "header present" and "value matches".
            assert_eq!(
                request.headers().get(&self.key),
                Some(&self.value),
                "missing or unexpected value for header `{}`",
                self.key,
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds to `127.0.0.1:0` and spawns an accept loop that echoes text and
        /// binary frames, and closes the connection on `"close-now"` or a Close frame.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting so a reconnecting client finds the server again.
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    // Best-effort close; the loop ends right after.
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames straight back to the client.
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Creates a Python `Counter` instance and returns `(counter, counter.handler)`.
    ///
    /// The handler decodes each message and counts `'ping'` messages; seeing
    /// `'heartbeat message'` sets the `check` flag.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Returns `counter.get_count()` from the Python side.
    fn call_get_count(counter: &PyObject) -> usize {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `checker.get_check()` from the Python side.
    fn call_get_check(checker: &PyObject) -> bool {
        Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        // Give the echoes time to reach the Python handler.
        sleep(Duration::from_secs(1)).await;
        assert_eq!(call_get_count(&counter), success_count);

        client.send_close_message().await;

        // The sends below are only counted if the client has reconnected by then.
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(call_get_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // With a 1s heartbeat, at least one echoed heartbeat should arrive within 2s.
        sleep(Duration::from_secs(2)).await;
        assert!(call_get_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}