nautilus_network/python/websocket.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26    RECONNECTED,
27    mode::ConnectionMode,
28    ratelimiter::quota::Quota,
29    websocket::{
30        WebSocketClient, WebSocketConfig,
31        types::{MessageHandler, PingHandler, WriterCommand},
32    },
33};
34
35// Python exception class for websocket errors
36create_exception!(network, WebSocketClientError, PyException);
37
38fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
39    PyErr::new::<WebSocketClientError, _>(e.to_string())
40}
41
42#[pymethods]
43impl WebSocketConfig {
44    #[new]
45    #[allow(clippy::too_many_arguments)]
46    #[pyo3(signature = (
47        url,
48        handler,
49        headers,
50        heartbeat=None,
51        heartbeat_msg=None,
52        ping_handler=None,
53        reconnect_timeout_ms=10_000,
54        reconnect_delay_initial_ms=2_000,
55        reconnect_delay_max_ms=30_000,
56        reconnect_backoff_factor=1.5,
57        reconnect_jitter_ms=100,
58        reconnect_max_attempts=None,
59    ))]
60    fn py_new(
61        url: String,
62        handler: Py<PyAny>,
63        headers: Vec<(String, String)>,
64        heartbeat: Option<u64>,
65        heartbeat_msg: Option<String>,
66        ping_handler: Option<Py<PyAny>>,
67        reconnect_timeout_ms: Option<u64>,
68        reconnect_delay_initial_ms: Option<u64>,
69        reconnect_delay_max_ms: Option<u64>,
70        reconnect_backoff_factor: Option<f64>,
71        reconnect_jitter_ms: Option<u64>,
72        reconnect_max_attempts: Option<u32>,
73    ) -> Self {
74        // Create function pointer that calls Python handler
75        let handler_clone = clone_py_object(&handler);
76        let message_handler: MessageHandler = Arc::new(move |msg: Message| {
77            Python::attach(|py| {
78                let data = match msg {
79                    Message::Binary(data) => data.to_vec(),
80                    Message::Text(text) => {
81                        if text == RECONNECTED {
82                            return;
83                        }
84                        text.as_bytes().to_vec()
85                    }
86                    _ => return,
87                };
88                if let Err(e) = handler_clone.call1(py, (PyBytes::new(py, &data),)) {
89                    tracing::error!("Error calling Python message handler: {e}");
90                }
91            });
92        });
93
94        // Create function pointer for ping handler if provided
95        let ping_handler_fn = ping_handler.map(|ping_handler| {
96            let ping_handler_clone = clone_py_object(&ping_handler);
97            let ping_handler_fn: PingHandler = std::sync::Arc::new(move |data: Vec<u8>| {
98                Python::attach(|py| {
99                    if let Err(e) = ping_handler_clone.call1(py, (PyBytes::new(py, &data),)) {
100                        tracing::error!("Error calling Python ping handler: {e}");
101                    }
102                });
103            });
104            ping_handler_fn
105        });
106
107        Self {
108            url,
109            message_handler: Some(message_handler),
110            headers,
111            heartbeat,
112            heartbeat_msg,
113            ping_handler: ping_handler_fn,
114            reconnect_timeout_ms,
115            reconnect_delay_initial_ms,
116            reconnect_delay_max_ms,
117            reconnect_backoff_factor,
118            reconnect_jitter_ms,
119            reconnect_max_attempts,
120        }
121    }
122}
123
124#[pymethods]
125impl WebSocketClient {
126    /// Create a websocket client.
127    ///
128    /// # Safety
129    ///
130    /// - Throws an Exception if it is unable to make websocket connection.
131    #[staticmethod]
132    #[pyo3(name = "connect", signature = (config, post_reconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
133    fn py_connect(
134        config: WebSocketConfig,
135        post_reconnection: Option<Py<PyAny>>,
136        keyed_quotas: Vec<(String, Quota)>,
137        default_quota: Option<Quota>,
138        py: Python<'_>,
139    ) -> PyResult<Bound<'_, PyAny>> {
140        let post_reconnection_fn = post_reconnection.map(|callback| {
141            let callback_clone = clone_py_object(&callback);
142            Arc::new(move || {
143                Python::attach(|py| {
144                    if let Err(e) = callback_clone.call0(py) {
145                        tracing::error!("Error calling post_reconnection handler: {e}");
146                    }
147                });
148            }) as std::sync::Arc<dyn Fn() + Send + Sync>
149        });
150
151        pyo3_async_runtimes::tokio::future_into_py(py, async move {
152            Self::connect(config, post_reconnection_fn, keyed_quotas, default_quota)
153                .await
154                .map_err(to_websocket_pyerr)
155        })
156    }
157
158    /// Closes the client heart beat and reader task.
159    ///
160    /// The connection is not completely closed the till all references
161    /// to the client are gone and the client is dropped.
162    ///
163    /// # Safety
164    ///
165    /// - The client should not be used after closing it.
166    /// - Any auto-reconnect job should be aborted before closing the client.
167    #[pyo3(name = "disconnect")]
168    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
169        let connection_mode = slf.connection_mode.clone();
170        let mode = ConnectionMode::from_atomic(&connection_mode);
171        tracing::debug!("Close from mode {mode}");
172
173        pyo3_async_runtimes::tokio::future_into_py(py, async move {
174            match ConnectionMode::from_atomic(&connection_mode) {
175                ConnectionMode::Closed => {
176                    tracing::debug!("WebSocket already closed");
177                }
178                ConnectionMode::Disconnect => {
179                    tracing::debug!("WebSocket already disconnecting");
180                }
181                _ => {
182                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
183                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
184                        tokio::time::sleep(Duration::from_millis(10)).await;
185                    }
186                }
187            }
188
189            Ok(())
190        })
191    }
192
193    /// Check if the client is still alive.
194    ///
195    /// Even if the connection is disconnected the client will still be alive
196    /// and trying to reconnect.
197    ///
198    /// This is particularly useful for checking why a `send` failed. It could
199    /// be because the connection disconnected and the client is still alive
200    /// and reconnecting. In such cases the send can be retried after some
201    /// delay.
202    #[pyo3(name = "is_active")]
203    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
204        !slf.controller_task.is_finished()
205    }
206
207    #[pyo3(name = "is_reconnecting")]
208    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
209        slf.is_reconnecting()
210    }
211
212    #[pyo3(name = "is_disconnecting")]
213    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
214        slf.is_disconnecting()
215    }
216
217    #[pyo3(name = "is_closed")]
218    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
219        slf.is_closed()
220    }
221
222    /// Send bytes data to the server.
223    ///
224    /// # Errors
225    ///
226    /// - Raises `PyRuntimeError` if not able to send data.
227    #[pyo3(name = "send")]
228    #[pyo3(signature = (data, keys=None))]
229    fn py_send<'py>(
230        slf: PyRef<'_, Self>,
231        data: Vec<u8>,
232        py: Python<'py>,
233        keys: Option<Vec<String>>,
234    ) -> PyResult<Bound<'py, PyAny>> {
235        let rate_limiter = slf.rate_limiter.clone();
236        let writer_tx = slf.writer_tx.clone();
237        let mode = slf.connection_mode.clone();
238
239        pyo3_async_runtimes::tokio::future_into_py(py, async move {
240            if !ConnectionMode::from_atomic(&mode).is_active() {
241                let msg = "Cannot send data: connection not active".to_string();
242                tracing::error!("{msg}");
243                return Err(to_pyruntime_err(std::io::Error::new(
244                    std::io::ErrorKind::NotConnected,
245                    msg,
246                )));
247            }
248            rate_limiter.await_keys_ready(keys).await;
249            tracing::trace!("Sending binary: {data:?}");
250
251            let msg = Message::Binary(data.into());
252            writer_tx
253                .send(WriterCommand::Send(msg))
254                .map_err(to_pyruntime_err)
255        })
256    }
257
258    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
259    ///
260    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
261    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
262    ///
263    /// # Errors
264    /// - Raises `PyRuntimeError` if unable to send the data.
265    ///
266    /// # Example
267    ///
268    /// When a request is made the URL should be split into all relevant keys within it.
269    ///
270    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
271    #[pyo3(name = "send_text")]
272    #[pyo3(signature = (data, keys=None))]
273    fn py_send_text<'py>(
274        slf: PyRef<'_, Self>,
275        data: Vec<u8>,
276        py: Python<'py>,
277        keys: Option<Vec<String>>,
278    ) -> PyResult<Bound<'py, PyAny>> {
279        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
280        let data = Utf8Bytes::from(data_str);
281        let rate_limiter = slf.rate_limiter.clone();
282        let writer_tx = slf.writer_tx.clone();
283        let mode = slf.connection_mode.clone();
284
285        pyo3_async_runtimes::tokio::future_into_py(py, async move {
286            if !ConnectionMode::from_atomic(&mode).is_active() {
287                let e = std::io::Error::new(
288                    std::io::ErrorKind::NotConnected,
289                    "Cannot send text: connection not active",
290                );
291                return Err(to_pyruntime_err(e));
292            }
293            rate_limiter.await_keys_ready(keys).await;
294            tracing::trace!("Sending text: {data}");
295
296            let msg = Message::Text(data);
297            writer_tx
298                .send(WriterCommand::Send(msg))
299                .map_err(to_pyruntime_err)
300        })
301    }
302
303    /// Send pong bytes data to the server.
304    ///
305    /// # Errors
306    ///
307    /// - Raises `PyRuntimeError` if not able to send data.
308    #[pyo3(name = "send_pong")]
309    fn py_send_pong<'py>(
310        slf: PyRef<'_, Self>,
311        data: Vec<u8>,
312        py: Python<'py>,
313    ) -> PyResult<Bound<'py, PyAny>> {
314        let writer_tx = slf.writer_tx.clone();
315        let mode = slf.connection_mode.clone();
316        let data_len = data.len();
317
318        pyo3_async_runtimes::tokio::future_into_py(py, async move {
319            if !ConnectionMode::from_atomic(&mode).is_active() {
320                let e = std::io::Error::new(
321                    std::io::ErrorKind::NotConnected,
322                    "Cannot send pong: connection not active",
323                );
324                return Err(to_pyruntime_err(e));
325            }
326            tracing::trace!("Sending pong frame ({data_len} bytes)");
327
328            let msg = Message::Pong(data.into());
329            writer_tx
330                .send(WriterCommand::Send(msg))
331                .map_err(to_pyruntime_err)
332        })
333    }
334}
335
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::prelude::*;
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server used by the tests; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking is intentional: a missing or wrong header must abort the handshake loudly
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "handshake header `{}` missing or unexpected",
                self.key,
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral local port and serves websocket connections that require the
        /// handshake header `key: value`, echo text and binary frames, and close when they
        /// receive the text frame `"close-now"`.
        async fn setup(key: String, value: String) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let callback = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections so the client can reconnect
                loop {
                    let (conn, _) = listener.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, callback.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames back to the client
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // The client closed: close our side as well
                                Message::Close(_) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs and any other frame kinds
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Creates a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler increments the count for each `b"ping"` message and sets the check flag
    /// on `b"heartbeat message"`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Builds a config pointing at the local test server on `port` with one custom handshake
    /// header; the ping handler and all reconnect settings are left as `None`.
    fn test_config(
        port: u16,
        handler: &Py<PyAny>,
        header: (String, String),
        heartbeat: Option<u64>,
        heartbeat_msg: Option<String>,
    ) -> WebSocketConfig {
        WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{port}"),
            Python::attach(|py| handler.clone_ref(py)),
            vec![header],
            heartbeat,
            heartbeat_msg,
            None, // ping_handler
            None, // reconnect_timeout_ms
            None, // reconnect_delay_initial_ms
            None, // reconnect_delay_max_ms
            None, // reconnect_backoff_factor
            None, // reconnect_jitter_ms
            None, // reconnect_max_attempts
        )
    }

    /// Returns `counter.get_count()` from the Python test handler.
    fn get_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `counter.get_check()` from the Python test handler.
    fn get_check(counter: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            counter
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        Python::initialize();

        const N: usize = 10;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = test_config(server.port, &handler, (header_key, header_value), None, None);
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // Each echoed "ping" increments the Python-side counter
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), N);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await.unwrap();
        sleep(Duration::from_secs(2)).await;

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        Python::initialize();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = test_config(
            server.port,
            &handler,
            (header_key, header_value),
            Some(1),
            Some("heartbeat message".to_string()),
        );
        let client = WebSocketClient::connect(config, None, Vec::new(), None)
            .await
            .unwrap();

        // The echoed heartbeat message should have set the Python-side check flag
        sleep(Duration::from_secs(2)).await;
        assert!(get_check(&checker));

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}