nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26    RECONNECTED,
27    mode::ConnectionMode,
28    ratelimiter::quota::Quota,
29    websocket::{MessageHandler, PingHandler, WebSocketClient, WebSocketConfig, WriterCommand},
30};
31
32// Python exception class for websocket errors
33create_exception!(network, WebSocketClientError, PyException);
34
35fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
36    PyErr::new::<WebSocketClientError, _>(e.to_string())
37}
38
39#[pymethods]
40impl WebSocketConfig {
41    #[new]
42    #[allow(clippy::too_many_arguments)]
43    #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
44    fn py_new(
45        url: String,
46        handler: PyObject,
47        headers: Vec<(String, String)>,
48        heartbeat: Option<u64>,
49        heartbeat_msg: Option<String>,
50        ping_handler: Option<PyObject>,
51        reconnect_timeout_ms: Option<u64>,
52        reconnect_delay_initial_ms: Option<u64>,
53        reconnect_delay_max_ms: Option<u64>,
54        reconnect_backoff_factor: Option<f64>,
55        reconnect_jitter_ms: Option<u64>,
56    ) -> Self {
57        // Create function pointer that calls Python handler
58        let handler_clone = clone_py_object(&handler);
59        let message_handler: MessageHandler = Arc::new(move |msg: Message| {
60            Python::with_gil(|py| {
61                let data = match msg {
62                    Message::Binary(data) => data.to_vec(),
63                    Message::Text(text) => {
64                        // Disregard the RECONNECTED sentinel message used for Rust flows
65                        if text == RECONNECTED {
66                            return;
67                        }
68                        text.as_bytes().to_vec()
69                    }
70                    _ => return, // Skip other message types
71                };
72                if let Err(e) = handler_clone.call1(py, (PyBytes::new(py, &data),)) {
73                    tracing::error!("Error calling Python message handler: {e}");
74                }
75            });
76        });
77
78        // Create function pointer for ping handler if provided
79        let ping_handler_fn = ping_handler.map(|ping_handler| {
80            let ping_handler_clone = clone_py_object(&ping_handler);
81            let ping_handler_fn: PingHandler = std::sync::Arc::new(move |data: Vec<u8>| {
82                Python::with_gil(|py| {
83                    if let Err(e) = ping_handler_clone.call1(py, (PyBytes::new(py, &data),)) {
84                        tracing::error!("Error calling Python ping handler: {e}");
85                    }
86                });
87            });
88            ping_handler_fn
89        });
90
91        Self {
92            url,
93            message_handler: Some(message_handler),
94            headers,
95            heartbeat,
96            heartbeat_msg,
97            ping_handler: ping_handler_fn,
98            reconnect_timeout_ms,
99            reconnect_delay_initial_ms,
100            reconnect_delay_max_ms,
101            reconnect_backoff_factor,
102            reconnect_jitter_ms,
103        }
104    }
105}
106
107#[pymethods]
108impl WebSocketClient {
109    /// Create a websocket client.
110    ///
111    /// # Safety
112    ///
113    /// - Throws an Exception if it is unable to make websocket connection.
114    #[staticmethod]
115    #[pyo3(name = "connect", signature = (config, post_reconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
116    fn py_connect(
117        config: WebSocketConfig,
118        post_reconnection: Option<PyObject>,
119        keyed_quotas: Vec<(String, Quota)>,
120        default_quota: Option<Quota>,
121        py: Python<'_>,
122    ) -> PyResult<Bound<'_, PyAny>> {
123        // Convert Python callback to function pointer
124        let post_reconnection_fn = post_reconnection.map(|callback| {
125            let callback_clone = clone_py_object(&callback);
126            Arc::new(move || {
127                Python::with_gil(|py| {
128                    if let Err(e) = callback_clone.call0(py) {
129                        tracing::error!("Error calling post_reconnection handler: {e}");
130                    }
131                });
132            }) as std::sync::Arc<dyn Fn() + Send + Sync>
133        });
134
135        pyo3_async_runtimes::tokio::future_into_py(py, async move {
136            Self::connect(config, post_reconnection_fn, keyed_quotas, default_quota)
137                .await
138                .map_err(to_websocket_pyerr)
139        })
140    }
141
142    /// Closes the client heart beat and reader task.
143    ///
144    /// The connection is not completely closed the till all references
145    /// to the client are gone and the client is dropped.
146    ///
147    /// # Safety
148    ///
149    /// - The client should not be used after closing it.
150    /// - Any auto-reconnect job should be aborted before closing the client.
151    #[pyo3(name = "disconnect")]
152    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
153        let connection_mode = slf.connection_mode.clone();
154        let mode = ConnectionMode::from_atomic(&connection_mode);
155        tracing::debug!("Close from mode {mode}");
156
157        pyo3_async_runtimes::tokio::future_into_py(py, async move {
158            match ConnectionMode::from_atomic(&connection_mode) {
159                ConnectionMode::Closed => {
160                    tracing::debug!("WebSocket already closed");
161                }
162                ConnectionMode::Disconnect => {
163                    tracing::debug!("WebSocket already disconnecting");
164                }
165                _ => {
166                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
167                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
168                        tokio::time::sleep(Duration::from_millis(10)).await;
169                    }
170                }
171            }
172
173            Ok(())
174        })
175    }
176
177    /// Check if the client is still alive.
178    ///
179    /// Even if the connection is disconnected the client will still be alive
180    /// and trying to reconnect.
181    ///
182    /// This is particularly useful for checking why a `send` failed. It could
183    /// be because the connection disconnected and the client is still alive
184    /// and reconnecting. In such cases the send can be retried after some
185    /// delay.
186    #[pyo3(name = "is_active")]
187    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
188        !slf.controller_task.is_finished()
189    }
190
191    #[pyo3(name = "is_reconnecting")]
192    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
193        slf.is_reconnecting()
194    }
195
196    #[pyo3(name = "is_disconnecting")]
197    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
198        slf.is_disconnecting()
199    }
200
201    #[pyo3(name = "is_closed")]
202    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
203        slf.is_closed()
204    }
205
206    /// Send bytes data to the server.
207    ///
208    /// # Errors
209    ///
210    /// - Raises `PyRuntimeError` if not able to send data.
211    #[pyo3(name = "send")]
212    #[pyo3(signature = (data, keys=None))]
213    fn py_send<'py>(
214        slf: PyRef<'_, Self>,
215        data: Vec<u8>,
216        py: Python<'py>,
217        keys: Option<Vec<String>>,
218    ) -> PyResult<Bound<'py, PyAny>> {
219        let rate_limiter = slf.rate_limiter.clone();
220        let writer_tx = slf.writer_tx.clone();
221        let mode = slf.connection_mode.clone();
222
223        pyo3_async_runtimes::tokio::future_into_py(py, async move {
224            if !ConnectionMode::from_atomic(&mode).is_active() {
225                let msg = "Cannot send data: connection not active".to_string();
226                tracing::error!("{msg}");
227                return Err(to_pyruntime_err(std::io::Error::new(
228                    std::io::ErrorKind::NotConnected,
229                    msg,
230                )));
231            }
232            rate_limiter.await_keys_ready(keys).await;
233            tracing::trace!("Sending binary: {data:?}");
234
235            let msg = Message::Binary(data.into());
236            writer_tx
237                .send(WriterCommand::Send(msg))
238                .map_err(to_pyruntime_err)
239        })
240    }
241
242    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
243    ///
244    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
245    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
246    ///
247    /// # Errors
248    /// - Raises `PyRuntimeError` if unable to send the data.
249    ///
250    /// # Example
251    ///
252    /// When a request is made the URL should be split into all relevant keys within it.
253    ///
254    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
255    #[pyo3(name = "send_text")]
256    #[pyo3(signature = (data, keys=None))]
257    fn py_send_text<'py>(
258        slf: PyRef<'_, Self>,
259        data: Vec<u8>,
260        py: Python<'py>,
261        keys: Option<Vec<String>>,
262    ) -> PyResult<Bound<'py, PyAny>> {
263        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
264        let data = Utf8Bytes::from(data_str);
265        let rate_limiter = slf.rate_limiter.clone();
266        let writer_tx = slf.writer_tx.clone();
267        let mode = slf.connection_mode.clone();
268
269        pyo3_async_runtimes::tokio::future_into_py(py, async move {
270            if !ConnectionMode::from_atomic(&mode).is_active() {
271                let err = std::io::Error::new(
272                    std::io::ErrorKind::NotConnected,
273                    "Cannot send text: connection not active",
274                );
275                return Err(to_pyruntime_err(err));
276            }
277            rate_limiter.await_keys_ready(keys).await;
278            tracing::trace!("Sending text: {data}");
279
280            let msg = Message::Text(data);
281            writer_tx
282                .send(WriterCommand::Send(msg))
283                .map_err(to_pyruntime_err)
284        })
285    }
286
287    /// Send pong bytes data to the server.
288    ///
289    /// # Errors
290    ///
291    /// - Raises `PyRuntimeError` if not able to send data.
292    #[pyo3(name = "send_pong")]
293    fn py_send_pong<'py>(
294        slf: PyRef<'_, Self>,
295        data: Vec<u8>,
296        py: Python<'py>,
297    ) -> PyResult<Bound<'py, PyAny>> {
298        let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
299        let writer_tx = slf.writer_tx.clone();
300        let mode = slf.connection_mode.clone();
301
302        pyo3_async_runtimes::tokio::future_into_py(py, async move {
303            if !ConnectionMode::from_atomic(&mode).is_active() {
304                let err = std::io::Error::new(
305                    std::io::ErrorKind::NotConnected,
306                    "Cannot send pong: connection not active",
307                );
308                return Err(to_pyruntime_err(err));
309            }
310            tracing::trace!("Sending pong: {data_str}");
311
312            let msg = Message::Pong(data.into());
313            writer_tx
314                .send(WriterCommand::Send(msg))
315                .map_err(to_pyruntime_err)
316        })
317    }
318}
319
320////////////////////////////////////////////////////////////////////////////////
321// Tests
322////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server used by the tests; its accept loop is aborted on drop.
    struct TestServer {
        /// Port the listener was bound to.
        port: u16,
        /// Handle of the spawned accept loop.
        task: JoinHandle<()>,
    }

    /// Handshake callback asserting that a given header is present with a given value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        /// Panics unless the request carries `self.key` with value `self.value`;
        /// otherwise passes the response through untouched.
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // Single lookup reused for both the presence and the equality check
            let value = request.headers().get(&self.key);
            assert!(value.is_some());

            if let Some(found) = value {
                assert_eq!(found, self.value);
            }

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an OS-assigned port on localhost and spawns the accept loop.
        ///
        /// Every accepted connection must present header `key: value` during the
        /// handshake. Text and binary frames are echoed back; the text `close-now`
        /// makes the server close the connection.
        async fn setup(key: String, value: String) -> Self {
            use tokio_tungstenite::tungstenite::protocol::Message;

            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let header_check = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (stream, _) = listener.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(stream, header_check.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        /// Aborts the accept loop when the server goes out of scope.
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` object and returns it with its bound `handler` method.
    ///
    /// The handler increments `count` for each `ping` message and sets `check` when it
    /// sees `heartbeat message`.
    fn create_test_handler() -> (PyObject, PyObject) {
        const COUNTER_SOURCE: &str = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let source = CString::new(COUNTER_SOURCE).unwrap();
        let file_name = CString::new("test").unwrap();
        let module_name = CString::new("test").unwrap();

        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &source, &file_name, &module_name).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let bound_handler = counter.getattr(py, "handler").unwrap();
            let handler = bound_handler.into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Sends pings, checks they are all echoed to the Python handler, forces a
    /// close, then checks that sending works again afterwards.
    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        // Reads the Python-side counter through its getter
        let read_count = || -> usize {
            Python::with_gil(|py| {
                counter
                    .getattr(py, "get_count")
                    .unwrap()
                    .call0(py)
                    .unwrap()
                    .extract(py)
                    .unwrap()
            })
        };

        let url = format!("ws://127.0.0.1:{}", server.port);
        let py_handler = Python::with_gil(|py| handler.clone_ref(py));
        let config = WebSocketConfig::py_new(
            url,
            py_handler,
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Send messages that increment the count
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value = read_count();
        assert_eq!(count_value, success_count);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await.unwrap();

        // Send messages that increment the count
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value = read_count();
        assert_eq!(count_value, success_count);
        assert_eq!(success_count, N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    /// Configures a 1-second heartbeat with a custom message and checks that the
    /// echoed message reaches the Python handler.
    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        // Reads the Python-side flag through its getter
        let read_check = || -> bool {
            Python::with_gil(|py| {
                checker
                    .getattr(py, "get_check")
                    .unwrap()
                    .call0(py)
                    .unwrap()
                    .extract(py)
                    .unwrap()
            })
        };

        // Initialize test server and config
        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let url = format!("ws://127.0.0.1:{}", server.port);
        let py_handler = Python::with_gil(|py| handler.clone_ref(py));
        let config = WebSocketConfig::py_new(
            url,
            py_handler,
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Check if ping message has the correct message
        sleep(Duration::from_secs(2)).await;
        let check_value = read_check();
        assert!(check_value);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
601}