nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use futures::SinkExt;
22use nautilus_core::python::to_pyvalue_err;
23use pyo3::{create_exception, exceptions::PyException, prelude::*};
24use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
25
26use crate::{
27    mode::ConnectionMode,
28    ratelimiter::quota::Quota,
29    websocket::{WebSocketClient, WebSocketConfig},
30};
31
// Python exception type raised by the websocket bindings when a websocket operation fails.
// Per PyO3's `create_exception!`, it subclasses Python's `Exception` and is named under the
// `network` module (`network.WebSocketClientError`).
create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36    PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41    #[new]
42    #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43    #[allow(clippy::too_many_arguments)]
44    fn py_new(
45        url: String,
46        handler: PyObject,
47        headers: Vec<(String, String)>,
48        heartbeat: Option<u64>,
49        heartbeat_msg: Option<String>,
50        ping_handler: Option<PyObject>,
51        reconnect_timeout_ms: Option<u64>,
52        reconnect_delay_initial_ms: Option<u64>,
53        reconnect_delay_max_ms: Option<u64>,
54        reconnect_backoff_factor: Option<f64>,
55        reconnect_jitter_ms: Option<u64>,
56    ) -> Self {
57        Self {
58            url,
59            handler: Some(Arc::new(handler)),
60            headers,
61            heartbeat,
62            heartbeat_msg,
63            ping_handler: ping_handler.map(Arc::new),
64            reconnect_timeout_ms,
65            reconnect_delay_initial_ms,
66            reconnect_delay_max_ms,
67            reconnect_backoff_factor,
68            reconnect_jitter_ms,
69        }
70    }
71}
72
73#[pymethods]
74impl WebSocketClient {
75    /// Create a websocket client.
76    ///
77    /// # Safety
78    ///
79    /// - Throws an Exception if it is unable to make websocket connection.
80    #[staticmethod]
81    #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
82    fn py_connect(
83        config: WebSocketConfig,
84        post_connection: Option<PyObject>,
85        post_reconnection: Option<PyObject>,
86        post_disconnection: Option<PyObject>,
87        keyed_quotas: Vec<(String, Quota)>,
88        default_quota: Option<Quota>,
89        py: Python<'_>,
90    ) -> PyResult<Bound<PyAny>> {
91        pyo3_async_runtimes::tokio::future_into_py(py, async move {
92            Self::connect(
93                config,
94                post_connection,
95                post_reconnection,
96                post_disconnection,
97                keyed_quotas,
98                default_quota,
99            )
100            .await
101            .map_err(to_websocket_pyerr)
102        })
103    }
104
105    /// Closes the client heart beat and reader task.
106    ///
107    /// The connection is not completely closed the till all references
108    /// to the client are gone and the client is dropped.
109    ///
110    /// # Safety
111    ///
112    /// - The client should not be used after closing it.
113    /// - Any auto-reconnect job should be aborted before closing the client.
114    #[pyo3(name = "disconnect")]
115    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116        let connection_mode = slf.connection_mode.clone();
117        let mode = ConnectionMode::from_atomic(&connection_mode);
118        tracing::debug!("Close from mode {mode}");
119
120        pyo3_async_runtimes::tokio::future_into_py(py, async move {
121            match ConnectionMode::from_atomic(&connection_mode) {
122                ConnectionMode::Closed => {
123                    tracing::warn!("WebSocket already closed");
124                }
125                ConnectionMode::Disconnect => {
126                    tracing::warn!("WebSocket already disconnecting");
127                }
128                _ => {
129                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
130                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
131                        tokio::time::sleep(Duration::from_millis(10)).await;
132                    }
133                }
134            }
135
136            Ok(())
137        })
138    }
139
140    /// Check if the client is still alive.
141    ///
142    /// Even if the connection is disconnected the client will still be alive
143    /// and trying to reconnect.
144    ///
145    /// This is particularly useful for checking why a `send` failed. It could
146    /// be because the connection disconnected and the client is still alive
147    /// and reconnecting. In such cases the send can be retried after some
148    /// delay.
149    #[pyo3(name = "is_active")]
150    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151        !slf.controller_task.is_finished()
152    }
153
154    #[pyo3(name = "is_reconnecting")]
155    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156        slf.is_reconnecting()
157    }
158
159    #[pyo3(name = "is_disconnecting")]
160    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161        slf.is_disconnecting()
162    }
163
164    #[pyo3(name = "is_closed")]
165    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166        slf.is_closed()
167    }
168
169    /// Send bytes data to the server.
170    ///
171    /// # Errors
172    ///
173    /// - Raises PyRuntimeError if not able to send data.
174    #[pyo3(name = "send")]
175    #[pyo3(signature = (data, keys=None))]
176    fn py_send<'py>(
177        slf: PyRef<'_, Self>,
178        data: Vec<u8>,
179        py: Python<'py>,
180        keys: Option<Vec<String>>,
181    ) -> PyResult<Bound<'py, PyAny>> {
182        let writer = slf.writer.clone();
183        let rate_limiter = slf.rate_limiter.clone();
184
185        pyo3_async_runtimes::tokio::future_into_py(py, async move {
186            rate_limiter.await_keys_ready(keys).await;
187            tracing::trace!("Sending binary: {data:?}");
188
189            let mut guard = writer.lock().await;
190            guard
191                .send(Message::Binary(data.into()))
192                .await
193                .map_err(to_websocket_pyerr)
194        })
195    }
196
197    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
198    ///
199    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
200    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
201    ///
202    /// # Errors
203    /// - Raises `PyRuntimeError` if unable to send the data.
204    ///
205    /// # Example
206    ///
207    /// When a request is made the URL should be split into all relevant keys within it.
208    ///
209    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
210    #[pyo3(name = "send_text")]
211    #[pyo3(signature = (data, keys=None))]
212    fn py_send_text<'py>(
213        slf: PyRef<'_, Self>,
214        data: Vec<u8>,
215        py: Python<'py>,
216        keys: Option<Vec<String>>,
217    ) -> PyResult<Bound<'py, PyAny>> {
218        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
219        let data = Utf8Bytes::from(data_str);
220        let writer = slf.writer.clone();
221        let rate_limiter = slf.rate_limiter.clone();
222
223        pyo3_async_runtimes::tokio::future_into_py(py, async move {
224            rate_limiter.await_keys_ready(keys).await;
225            tracing::trace!("Sending text: {data}");
226
227            let mut guard = writer.lock().await;
228            guard
229                .send(Message::Text(data))
230                .await
231                .map_err(to_websocket_pyerr)
232        })
233    }
234
235    /// Send pong bytes data to the server.
236    ///
237    /// # Errors
238    ///
239    /// - Raises PyRuntimeError if not able to send data.
240    #[pyo3(name = "send_pong")]
241    fn py_send_pong<'py>(
242        slf: PyRef<'_, Self>,
243        data: Vec<u8>,
244        py: Python<'py>,
245    ) -> PyResult<Bound<'py, PyAny>> {
246        let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
247        let writer = slf.writer.clone();
248        tracing::trace!("Sending pong: {data_str}");
249
250        pyo3_async_runtimes::tokio::future_into_py(py, async move {
251            let mut guard = writer.lock().await;
252            guard
253                .send(Message::Pong(data.into()))
254                .await
255                .map_err(to_websocket_pyerr)
256        })
257    }
258}
259
260////////////////////////////////////////////////////////////////////////////////
261// Tests
262////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Handle to a local websocket echo server bound to an ephemeral port on 127.0.0.1.
    ///
    /// Dropping the handle aborts the accept loop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with exactly `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // Look the header up once; a missing header and a mismatched value both fail.
            let Some(value) = request.headers().get(&self.key) else {
                panic!("expected header `{}` to be present", self.key);
            };
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Starts the echo server and requires the header `key: value` on every handshake.
        ///
        /// Each accepted connection behaves as follows:
        /// - Text and binary frames are echoed back.
        /// - Other frames are ignored.
        /// - The connection closes when the client sends a close frame or the text `close-now`.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler behaves as follows:
    /// - It increments `count` for each `ping` message.
    /// - It sets `check` when it sees `heartbeat message`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Reads the Python counter's current `ping` count via its `get_count` method.
    fn get_count(counter: &PyObject) -> usize {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Verifies that the Python handler sees every echoed `ping`, across a forced reconnect.
    ///
    /// The test sends `N` pings, forces a close, then sends `N` more pings.
    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Send messages that increment the count
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await;

        // Send messages that increment the count
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Verifies that the configured heartbeat message reaches the Python handler.
    ///
    /// The heartbeat is echoed back by the test server.
    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        // Initialize test server and config
        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Check if ping message has the correct message
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert!(check_value);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}