nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{atomic::Ordering, Arc},
18    time::Duration,
19};
20
21use futures::SinkExt;
22use nautilus_core::python::to_pyvalue_err;
23use pyo3::{create_exception, exceptions::PyException, prelude::*};
24use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
25
26use crate::{
27    mode::ConnectionMode,
28    ratelimiter::quota::Quota,
29    websocket::{WebSocketClient, WebSocketConfig},
30};
31
32// Python exception class for websocket errors
33create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36    PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41    #[new]
42    #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43    #[allow(clippy::too_many_arguments)]
44    fn py_new(
45        url: String,
46        handler: PyObject,
47        headers: Vec<(String, String)>,
48        heartbeat: Option<u64>,
49        heartbeat_msg: Option<String>,
50        ping_handler: Option<PyObject>,
51        reconnect_timeout_ms: Option<u64>,
52        reconnect_delay_initial_ms: Option<u64>,
53        reconnect_delay_max_ms: Option<u64>,
54        reconnect_backoff_factor: Option<f64>,
55        reconnect_jitter_ms: Option<u64>,
56    ) -> Self {
57        Self {
58            url,
59            handler: Some(Arc::new(handler)),
60            headers,
61            heartbeat,
62            heartbeat_msg,
63            ping_handler: ping_handler.map(Arc::new),
64            reconnect_timeout_ms,
65            reconnect_delay_initial_ms,
66            reconnect_delay_max_ms,
67            reconnect_backoff_factor,
68            reconnect_jitter_ms,
69        }
70    }
71}
72
73#[pymethods]
74impl WebSocketClient {
75    /// Create a websocket client.
76    ///
77    /// # Safety
78    ///
79    /// - Throws an Exception if it is unable to make websocket connection.
80    #[staticmethod]
81    #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
82    fn py_connect(
83        config: WebSocketConfig,
84        post_connection: Option<PyObject>,
85        post_reconnection: Option<PyObject>,
86        post_disconnection: Option<PyObject>,
87        keyed_quotas: Vec<(String, Quota)>,
88        default_quota: Option<Quota>,
89        py: Python<'_>,
90    ) -> PyResult<Bound<PyAny>> {
91        pyo3_async_runtimes::tokio::future_into_py(py, async move {
92            Self::connect(
93                config,
94                post_connection,
95                post_reconnection,
96                post_disconnection,
97                keyed_quotas,
98                default_quota,
99            )
100            .await
101            .map_err(to_websocket_pyerr)
102        })
103    }
104
105    /// Closes the client heart beat and reader task.
106    ///
107    /// The connection is not completely closed the till all references
108    /// to the client are gone and the client is dropped.
109    ///
110    /// # Safety
111    ///
112    /// - The client should not be used after closing it.
113    /// - Any auto-reconnect job should be aborted before closing the client.
114    #[pyo3(name = "disconnect")]
115    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116        let connection_mode = slf.connection_mode.clone();
117        let mode = ConnectionMode::from_atomic(&connection_mode);
118        tracing::debug!("Close from mode {mode}");
119
120        pyo3_async_runtimes::tokio::future_into_py(py, async move {
121            match ConnectionMode::from_atomic(&connection_mode) {
122                ConnectionMode::Closed => {
123                    tracing::warn!("WebSocket already closed");
124                }
125                ConnectionMode::Disconnect => {
126                    tracing::warn!("WebSocket already disconnecting");
127                }
128                _ => {
129                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
130                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
131                        tokio::time::sleep(Duration::from_millis(10)).await;
132                    }
133                }
134            }
135
136            Ok(())
137        })
138    }
139
140    /// Check if the client is still alive.
141    ///
142    /// Even if the connection is disconnected the client will still be alive
143    /// and trying to reconnect.
144    ///
145    /// This is particularly useful for checking why a `send` failed. It could
146    /// be because the connection disconnected and the client is still alive
147    /// and reconnecting. In such cases the send can be retried after some
148    /// delay.
149    #[pyo3(name = "is_active")]
150    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151        !slf.controller_task.is_finished()
152    }
153
154    #[pyo3(name = "is_reconnecting")]
155    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156        slf.is_reconnecting()
157    }
158
159    #[pyo3(name = "is_disconnecting")]
160    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161        slf.is_disconnecting()
162    }
163
164    #[pyo3(name = "is_closed")]
165    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166        slf.is_closed()
167    }
168
169    /// Send bytes data to the server.
170    ///
171    /// # Errors
172    ///
173    /// - Raises PyRuntimeError if not able to send data.
174    #[pyo3(name = "send")]
175    #[pyo3(signature = (data, keys=None))]
176    fn py_send<'py>(
177        slf: PyRef<'_, Self>,
178        data: Vec<u8>,
179        py: Python<'py>,
180        keys: Option<Vec<String>>,
181    ) -> PyResult<Bound<'py, PyAny>> {
182        let writer = slf.writer.clone();
183        let rate_limiter = slf.rate_limiter.clone();
184
185        pyo3_async_runtimes::tokio::future_into_py(py, async move {
186            rate_limiter.await_keys_ready(keys).await;
187            tracing::trace!("Sending binary: {data:?}");
188
189            let mut guard = writer.lock().await;
190            guard
191                .send(Message::Binary(data.into()))
192                .await
193                .map_err(to_websocket_pyerr)
194        })
195    }
196
197    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
198    ///
199    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
200    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
201    ///
202    /// # Errors
203    /// - Raises `PyRuntimeError` if unable to send the data.
204    ///
205    /// # Example
206    ///
207    /// When a request is made the URL should be split into all relevant keys within it.
208    ///
209    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
210    #[pyo3(name = "send_text")]
211    #[pyo3(signature = (data, keys=None))]
212    fn py_send_text<'py>(
213        slf: PyRef<'_, Self>,
214        data: Vec<u8>,
215        py: Python<'py>,
216        keys: Option<Vec<String>>,
217    ) -> PyResult<Bound<'py, PyAny>> {
218        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
219        let data = Utf8Bytes::from(data_str);
220        let writer = slf.writer.clone();
221        let rate_limiter = slf.rate_limiter.clone();
222
223        pyo3_async_runtimes::tokio::future_into_py(py, async move {
224            rate_limiter.await_keys_ready(keys).await;
225            tracing::trace!("Sending text: {data}");
226
227            let mut guard = writer.lock().await;
228            guard
229                .send(Message::Text(data))
230                .await
231                .map_err(to_websocket_pyerr)
232        })
233    }
234
235    /// Send pong bytes data to the server.
236    ///
237    /// # Errors
238    ///
239    /// - Raises PyRuntimeError if not able to send data.
240    #[pyo3(name = "send_pong")]
241    fn py_send_pong<'py>(
242        slf: PyRef<'_, Self>,
243        data: Vec<u8>,
244        py: Python<'py>,
245    ) -> PyResult<Bound<'py, PyAny>> {
246        let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
247        let writer = slf.writer.clone();
248        tracing::trace!("Sending pong: {data_str}");
249
250        pyo3_async_runtimes::tokio::future_into_py(py, async move {
251            let mut guard = writer.lock().await;
252            guard
253                .send(Message::Pong(data.into()))
254                .await
255                .map_err(to_websocket_pyerr)
256        })
257    }
258}
259
260////////////////////////////////////////////////////////////////////////////////
261// Tests
262////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{sleep, Duration},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server bound to an ephemeral port; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // A failure here panics inside the server task. It will presumably surface
            // in the test as a client connection error rather than as this assertion.
            let value = request
                .headers()
                .get(&self.key)
                .unwrap_or_else(|| panic!("missing expected header `{}`", self.key));
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler increments `count` on a `ping` payload and sets `check` on a
    /// `heartbeat message` payload.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r#"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
"#;

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py(py);
            // `Py::getattr` already yields an owned `PyObject`.
            let handler = counter.getattr(py, "handler").unwrap();

            (counter, handler)
        })
    }

    /// Reads the number of `ping` messages the Python counter has observed.
    fn get_count(counter: &PyObject) -> usize {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Send messages that increment the count
        for _ in 0..N {
            if client.send_bytes(b"ping".to_vec(), None).await.is_ok() {
                success_count += 1;
            };
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await;

        // Send messages that increment the count
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            if client.send_bytes(b"ping".to_vec(), None).await.is_ok() {
                success_count += 1;
            };
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        // Initialize test server and config
        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Check if ping message has the correct message
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert!(check_value);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
539}