nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::python::to_pyvalue_err;
22use pyo3::{create_exception, exceptions::PyException, prelude::*};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26    mode::ConnectionMode,
27    ratelimiter::quota::Quota,
28    websocket::{Consumer, WebSocketClient, WebSocketConfig, WriterCommand},
29};
30
31// Python exception class for websocket errors
32create_exception!(network, WebSocketClientError, PyException);
33
34fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
35    PyErr::new::<WebSocketClientError, _>(e.to_string())
36}
37
38#[pymethods]
39impl WebSocketConfig {
40    #[new]
41    #[allow(clippy::too_many_arguments)]
42    #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43    fn py_new(
44        url: String,
45        handler: PyObject,
46        headers: Vec<(String, String)>,
47        heartbeat: Option<u64>,
48        heartbeat_msg: Option<String>,
49        ping_handler: Option<PyObject>,
50        reconnect_timeout_ms: Option<u64>,
51        reconnect_delay_initial_ms: Option<u64>,
52        reconnect_delay_max_ms: Option<u64>,
53        reconnect_backoff_factor: Option<f64>,
54        reconnect_jitter_ms: Option<u64>,
55    ) -> Self {
56        Self {
57            url,
58            handler: Consumer::Python(Some(Arc::new(handler))),
59            headers,
60            heartbeat,
61            heartbeat_msg,
62            ping_handler: ping_handler.map(Arc::new),
63            reconnect_timeout_ms,
64            reconnect_delay_initial_ms,
65            reconnect_delay_max_ms,
66            reconnect_backoff_factor,
67            reconnect_jitter_ms,
68        }
69    }
70}
71
72#[pymethods]
73impl WebSocketClient {
74    /// Create a websocket client.
75    ///
76    /// # Safety
77    ///
78    /// - Throws an Exception if it is unable to make websocket connection.
79    #[staticmethod]
80    #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
81    fn py_connect(
82        config: WebSocketConfig,
83        post_connection: Option<PyObject>,
84        post_reconnection: Option<PyObject>,
85        post_disconnection: Option<PyObject>,
86        keyed_quotas: Vec<(String, Quota)>,
87        default_quota: Option<Quota>,
88        py: Python<'_>,
89    ) -> PyResult<Bound<PyAny>> {
90        pyo3_async_runtimes::tokio::future_into_py(py, async move {
91            Self::connect(
92                config,
93                post_connection,
94                post_reconnection,
95                post_disconnection,
96                keyed_quotas,
97                default_quota,
98            )
99            .await
100            .map_err(to_websocket_pyerr)
101        })
102    }
103
104    /// Closes the client heart beat and reader task.
105    ///
106    /// The connection is not completely closed the till all references
107    /// to the client are gone and the client is dropped.
108    ///
109    /// # Safety
110    ///
111    /// - The client should not be used after closing it.
112    /// - Any auto-reconnect job should be aborted before closing the client.
113    #[pyo3(name = "disconnect")]
114    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
115        let connection_mode = slf.connection_mode.clone();
116        let mode = ConnectionMode::from_atomic(&connection_mode);
117        tracing::debug!("Close from mode {mode}");
118
119        pyo3_async_runtimes::tokio::future_into_py(py, async move {
120            match ConnectionMode::from_atomic(&connection_mode) {
121                ConnectionMode::Closed => {
122                    tracing::warn!("WebSocket already closed");
123                }
124                ConnectionMode::Disconnect => {
125                    tracing::warn!("WebSocket already disconnecting");
126                }
127                _ => {
128                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
129                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
130                        tokio::time::sleep(Duration::from_millis(10)).await;
131                    }
132                }
133            }
134
135            Ok(())
136        })
137    }
138
139    /// Check if the client is still alive.
140    ///
141    /// Even if the connection is disconnected the client will still be alive
142    /// and trying to reconnect.
143    ///
144    /// This is particularly useful for checking why a `send` failed. It could
145    /// be because the connection disconnected and the client is still alive
146    /// and reconnecting. In such cases the send can be retried after some
147    /// delay.
148    #[pyo3(name = "is_active")]
149    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
150        !slf.controller_task.is_finished()
151    }
152
153    #[pyo3(name = "is_reconnecting")]
154    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
155        slf.is_reconnecting()
156    }
157
158    #[pyo3(name = "is_disconnecting")]
159    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
160        slf.is_disconnecting()
161    }
162
163    #[pyo3(name = "is_closed")]
164    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
165        slf.is_closed()
166    }
167
168    /// Send bytes data to the server.
169    ///
170    /// # Errors
171    ///
172    /// - Raises PyRuntimeError if not able to send data.
173    #[pyo3(name = "send")]
174    #[pyo3(signature = (data, keys=None))]
175    fn py_send<'py>(
176        slf: PyRef<'_, Self>,
177        data: Vec<u8>,
178        py: Python<'py>,
179        keys: Option<Vec<String>>,
180    ) -> PyResult<Bound<'py, PyAny>> {
181        let rate_limiter = slf.rate_limiter.clone();
182        let writer_tx = slf.writer_tx.clone();
183
184        pyo3_async_runtimes::tokio::future_into_py(py, async move {
185            rate_limiter.await_keys_ready(keys).await;
186            tracing::trace!("Sending binary: {data:?}");
187
188            let msg = Message::Binary(data.into());
189            if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
190                tracing::error!("{e}");
191            }
192            Ok(())
193        })
194    }
195
196    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
197    ///
198    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
199    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
200    ///
201    /// # Errors
202    /// - Raises `PyRuntimeError` if unable to send the data.
203    ///
204    /// # Example
205    ///
206    /// When a request is made the URL should be split into all relevant keys within it.
207    ///
208    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
209    #[pyo3(name = "send_text")]
210    #[pyo3(signature = (data, keys=None))]
211    fn py_send_text<'py>(
212        slf: PyRef<'_, Self>,
213        data: Vec<u8>,
214        py: Python<'py>,
215        keys: Option<Vec<String>>,
216    ) -> PyResult<Bound<'py, PyAny>> {
217        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
218        let data = Utf8Bytes::from(data_str);
219        let rate_limiter = slf.rate_limiter.clone();
220        let writer_tx = slf.writer_tx.clone();
221
222        pyo3_async_runtimes::tokio::future_into_py(py, async move {
223            rate_limiter.await_keys_ready(keys).await;
224            tracing::trace!("Sending text: {data}");
225
226            let msg = Message::Text(data);
227            if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
228                tracing::error!("{e}");
229            }
230            Ok(())
231        })
232    }
233
234    /// Send pong bytes data to the server.
235    ///
236    /// # Errors
237    ///
238    /// - Raises PyRuntimeError if not able to send data.
239    #[pyo3(name = "send_pong")]
240    fn py_send_pong<'py>(
241        slf: PyRef<'_, Self>,
242        data: Vec<u8>,
243        py: Python<'py>,
244    ) -> PyResult<Bound<'py, PyAny>> {
245        let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
246        let writer_tx = slf.writer_tx.clone();
247        tracing::trace!("Sending pong: {data_str}");
248
249        pyo3_async_runtimes::tokio::future_into_py(py, async move {
250            let msg = Message::Pong(data.into());
251            if let Err(e) = writer_tx.send(WriterCommand::Send(msg)) {
252                tracing::error!("{e}");
253            }
254            Ok(())
255        })
256    }
257}
258
259////////////////////////////////////////////////////////////////////////////////
260// Tests
261////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server used by the tests.
    ///
    /// Each accepted connection must carry the expected custom header in its handshake.
    /// Text and binary frames are echoed back, and the text frame `close-now` makes the server
    /// close that connection. The accept-loop task is aborted when the server is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` equal to `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // The client must forward the custom header configured in `WebSocketConfig`.
            let value = request
                .headers()
                .get(&self.key)
                .expect("custom header missing from handshake request");
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` object and its bound `handler` method.
    ///
    /// The handler decodes each message; `ping` increments `count` and
    /// `heartbeat message` sets `check`. Returns `(counter, handler)`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Send messages that increment the count
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value: usize = Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert_eq!(count_value, success_count);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await;

        // Send messages that increment the count
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await;
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value: usize = Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert_eq!(count_value, success_count);
        assert_eq!(success_count, N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        // Initialize test server and config
        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            Python::with_gil(|py| handler.clone_ref(py)),
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // The configured heartbeat text should be echoed back by the server and reach the
        // handler, which then sets `check`.
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = Python::with_gil(|py| {
            checker
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        });
        assert!(check_value);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}