Skip to main content

nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::{
22    collections::into_ustr_vec,
23    python::{clone_py_object, to_pyruntime_err, to_pyvalue_err},
24};
25use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
26use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
27
28use crate::{
29    RECONNECTED,
30    mode::ConnectionMode,
31    ratelimiter::quota::Quota,
32    websocket::{
33        WebSocketClient, WebSocketConfig,
34        types::{MessageHandler, PingHandler, WriterCommand},
35    },
36};
37
38create_exception!(network, WebSocketClientError, PyException);
39
40fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
41    PyErr::new::<WebSocketClientError, _>(e.to_string())
42}
43
44#[pymethods]
45impl WebSocketConfig {
46    /// Create a new WebSocket configuration.
47    #[new]
48    #[allow(clippy::too_many_arguments)]
49    #[pyo3(signature = (
50        url,
51        headers,
52        heartbeat=None,
53        heartbeat_msg=None,
54        reconnect_timeout_ms=10_000,
55        reconnect_delay_initial_ms=2_000,
56        reconnect_delay_max_ms=30_000,
57        reconnect_backoff_factor=1.5,
58        reconnect_jitter_ms=100,
59        reconnect_max_attempts=None,
60    ))]
61    fn py_new(
62        url: String,
63        headers: Vec<(String, String)>,
64        heartbeat: Option<u64>,
65        heartbeat_msg: Option<String>,
66        reconnect_timeout_ms: Option<u64>,
67        reconnect_delay_initial_ms: Option<u64>,
68        reconnect_delay_max_ms: Option<u64>,
69        reconnect_backoff_factor: Option<f64>,
70        reconnect_jitter_ms: Option<u64>,
71        reconnect_max_attempts: Option<u32>,
72    ) -> Self {
73        Self {
74            url,
75            headers,
76            heartbeat,
77            heartbeat_msg,
78            reconnect_timeout_ms,
79            reconnect_delay_initial_ms,
80            reconnect_delay_max_ms,
81            reconnect_backoff_factor,
82            reconnect_jitter_ms,
83            reconnect_max_attempts,
84        }
85    }
86}
87
88#[pymethods]
89impl WebSocketClient {
90    /// Create a websocket client.
91    ///
92    /// The handler and ping_handler callbacks are scheduled on the provided event loop
93    /// using `call_soon_threadsafe` to ensure they execute on the correct thread.
94    /// This is critical for thread safety since WebSocket messages arrive on
95    /// a Tokio worker thread, but Python callbacks (like those entering the
96    /// kernel via MessageBus) must run on the asyncio event loop thread.
97    ///
98    /// # Safety
99    ///
100    /// - Throws an Exception if it is unable to make websocket connection.
101    #[staticmethod]
102    #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
103    #[allow(clippy::too_many_arguments)]
104    fn py_connect(
105        loop_: Py<PyAny>,
106        config: WebSocketConfig,
107        handler: Py<PyAny>,
108        ping_handler: Option<Py<PyAny>>,
109        post_reconnection: Option<Py<PyAny>>,
110        keyed_quotas: Vec<(String, Quota)>,
111        default_quota: Option<Quota>,
112        py: Python<'_>,
113    ) -> PyResult<Bound<'_, PyAny>> {
114        let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
115        let call_soon_clone = clone_py_object(&call_soon_threadsafe);
116        let handler_clone = clone_py_object(&handler);
117
118        let message_handler: MessageHandler = Arc::new(move |msg: Message| {
119            if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
120                return;
121            }
122
123            Python::attach(|py| {
124                let py_bytes = match &msg {
125                    Message::Binary(data) => PyBytes::new(py, data),
126                    Message::Text(text) => PyBytes::new(py, text.as_bytes()),
127                    _ => return,
128                };
129
130                if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
131                    log::error!("Error scheduling message handler on event loop: {e}");
132                }
133            });
134        });
135
136        let ping_handler_fn = ping_handler.map(|ping_handler| {
137            let ping_handler_clone = clone_py_object(&ping_handler);
138            let call_soon_clone = clone_py_object(&call_soon_threadsafe);
139
140            let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
141                Python::attach(|py| {
142                    let py_bytes = PyBytes::new(py, &data);
143                    if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
144                        log::error!("Error scheduling ping handler on event loop: {e}");
145                    }
146                });
147            });
148            ping_handler_fn
149        });
150
151        let post_reconnection_fn = post_reconnection.map(|callback| {
152            let callback_clone = clone_py_object(&callback);
153            Arc::new(move || {
154                Python::attach(|py| {
155                    if let Err(e) = callback_clone.call0(py) {
156                        log::error!("Error calling post_reconnection handler: {e}");
157                    }
158                });
159            }) as std::sync::Arc<dyn Fn() + Send + Sync>
160        });
161
162        pyo3_async_runtimes::tokio::future_into_py(py, async move {
163            Self::connect(
164                config,
165                Some(message_handler),
166                ping_handler_fn,
167                post_reconnection_fn,
168                keyed_quotas,
169                default_quota,
170            )
171            .await
172            .map_err(to_websocket_pyerr)
173        })
174    }
175
176    /// Closes the client heart beat and reader task.
177    ///
178    /// The connection is not completely closed the till all references
179    /// to the client are gone and the client is dropped.
180    ///
181    /// # Safety
182    ///
183    /// - The client should not be used after closing it.
184    /// - Any auto-reconnect job should be aborted before closing the client.
185    #[pyo3(name = "disconnect")]
186    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
187        let connection_mode = slf.connection_mode.clone();
188        let mode = ConnectionMode::from_atomic(&connection_mode);
189        log::debug!("Close from mode {mode}");
190
191        pyo3_async_runtimes::tokio::future_into_py(py, async move {
192            match ConnectionMode::from_atomic(&connection_mode) {
193                ConnectionMode::Closed => {
194                    log::debug!("WebSocket already closed");
195                }
196                ConnectionMode::Disconnect => {
197                    log::debug!("WebSocket already disconnecting");
198                }
199                _ => {
200                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
201                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
202                        tokio::time::sleep(Duration::from_millis(10)).await;
203                    }
204                }
205            }
206
207            Ok(())
208        })
209    }
210
211    /// Check if the client is still alive.
212    ///
213    /// Even if the connection is disconnected the client will still be alive
214    /// and trying to reconnect.
215    ///
216    /// This is particularly useful for checking why a `send` failed. It could
217    /// be because the connection disconnected and the client is still alive
218    /// and reconnecting. In such cases the send can be retried after some
219    /// delay.
220    #[pyo3(name = "is_active")]
221    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
222        !slf.controller_task.is_finished()
223    }
224
225    #[pyo3(name = "is_reconnecting")]
226    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
227        slf.is_reconnecting()
228    }
229
230    #[pyo3(name = "is_disconnecting")]
231    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
232        slf.is_disconnecting()
233    }
234
235    #[pyo3(name = "is_closed")]
236    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
237        slf.is_closed()
238    }
239
240    /// Send bytes data to the server.
241    ///
242    /// # Errors
243    ///
244    /// - Raises `PyRuntimeError` if not able to send data.
245    #[pyo3(name = "send")]
246    #[pyo3(signature = (data, keys=None))]
247    fn py_send<'py>(
248        slf: PyRef<'_, Self>,
249        data: Vec<u8>,
250        py: Python<'py>,
251        keys: Option<Vec<String>>,
252    ) -> PyResult<Bound<'py, PyAny>> {
253        let rate_limiter = slf.rate_limiter.clone();
254        let writer_tx = slf.writer_tx.clone();
255        let mode = slf.connection_mode.clone();
256        let keys = keys.map(into_ustr_vec);
257
258        pyo3_async_runtimes::tokio::future_into_py(py, async move {
259            if !ConnectionMode::from_atomic(&mode).is_active() {
260                let msg = "Cannot send data: connection not active".to_string();
261                log::error!("{msg}");
262                return Err(to_pyruntime_err(std::io::Error::new(
263                    std::io::ErrorKind::NotConnected,
264                    msg,
265                )));
266            }
267            rate_limiter.await_keys_ready(keys.as_deref()).await;
268            log::trace!("Sending binary: {data:?}");
269
270            let msg = Message::Binary(data.into());
271            writer_tx
272                .send(WriterCommand::Send(msg))
273                .map_err(to_pyruntime_err)
274        })
275    }
276
277    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
278    ///
279    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
280    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
281    ///
282    /// # Errors
283    /// - Raises `PyRuntimeError` if unable to send the data.
284    ///
285    /// # Example
286    ///
287    /// When a request is made the URL should be split into all relevant keys within it.
288    ///
289    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
290    #[pyo3(name = "send_text")]
291    #[pyo3(signature = (data, keys=None))]
292    fn py_send_text<'py>(
293        slf: PyRef<'_, Self>,
294        data: Vec<u8>,
295        py: Python<'py>,
296        keys: Option<Vec<String>>,
297    ) -> PyResult<Bound<'py, PyAny>> {
298        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
299        let data = Utf8Bytes::from(data_str);
300        let rate_limiter = slf.rate_limiter.clone();
301        let writer_tx = slf.writer_tx.clone();
302        let mode = slf.connection_mode.clone();
303        let keys = keys.map(into_ustr_vec);
304
305        pyo3_async_runtimes::tokio::future_into_py(py, async move {
306            if !ConnectionMode::from_atomic(&mode).is_active() {
307                let e = std::io::Error::new(
308                    std::io::ErrorKind::NotConnected,
309                    "Cannot send text: connection not active",
310                );
311                return Err(to_pyruntime_err(e));
312            }
313            rate_limiter.await_keys_ready(keys.as_deref()).await;
314            log::trace!("Sending text: {data}");
315
316            let msg = Message::Text(data);
317            writer_tx
318                .send(WriterCommand::Send(msg))
319                .map_err(to_pyruntime_err)
320        })
321    }
322
323    /// Send pong bytes data to the server.
324    ///
325    /// # Errors
326    ///
327    /// - Raises `PyRuntimeError` if not able to send data.
328    #[pyo3(name = "send_pong")]
329    fn py_send_pong<'py>(
330        slf: PyRef<'_, Self>,
331        data: Vec<u8>,
332        py: Python<'py>,
333    ) -> PyResult<Bound<'py, PyAny>> {
334        let writer_tx = slf.writer_tx.clone();
335        let mode = slf.connection_mode.clone();
336        let data_len = data.len();
337
338        pyo3_async_runtimes::tokio::future_into_py(py, async move {
339            if !ConnectionMode::from_atomic(&mode).is_active() {
340                let e = std::io::Error::new(
341                    std::io::ErrorKind::NotConnected,
342                    "Cannot send pong: connection not active",
343                );
344                return Err(to_pyruntime_err(e));
345            }
346            log::trace!("Sending pong frame ({data_len} bytes)");
347
348            let msg = Message::Pong(data.into());
349            writer_tx
350                .send(WriterCommand::Send(msg))
351                .map_err(to_pyruntime_err)
352        })
353    }
354}
355
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::{prelude::*, types::PyBytes};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};

    /// Custom header the client sends and the test server asserts on.
    const HEADER_KEY: &str = "hello-custom-key";
    const HEADER_VALUE: &str = "hello-custom-value";

    /// Echo server bound to an ephemeral localhost port; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent `key: value` as an HTTP header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        #[allow(clippy::panic_in_result_fn)] // Panicking is how the test flags a bad handshake.
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let received = request.headers().get(&self.key);
            assert_eq!(
                received,
                Some(&self.value),
                "missing or unexpected value for header `{}`",
                self.key
            );

            Ok(response)
        }
    }

    impl TestServer {
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler counts `ping` payloads and flags receipt of `heartbeat message`.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Wraps a Python callable as a `MessageHandler` that forwards text/binary payloads as
    /// `bytes`, calling it directly on the invoking thread; other frame types are ignored.
    fn make_message_handler(handler: &Py<PyAny>) -> MessageHandler {
        let handler = Python::attach(|py| handler.clone_ref(py));

        std::sync::Arc::new(move |msg: Message| {
            Python::attach(|py| {
                let py_bytes = match &msg {
                    Message::Binary(data) => PyBytes::new(py, data),
                    Message::Text(text) => PyBytes::new(py, text.as_bytes()),
                    _ => return,
                };
                if let Err(e) = handler.call1(py, (py_bytes,)) {
                    log::error!("Error calling handler: {e}");
                }
            });
        })
    }

    /// Returns `counter.get_count()`.
    fn read_count(counter: &Py<PyAny>) -> usize {
        Python::attach(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `counter.get_check()`.
    fn read_check(counter: &Py<PyAny>) -> bool {
        Python::attach(|py| {
            counter
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Config pointing at the local test server and sending the expected custom header;
    /// all reconnect settings are left as `None`.
    fn test_config(
        port: u16,
        heartbeat: Option<u64>,
        heartbeat_msg: Option<String>,
    ) -> WebSocketConfig {
        WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{port}"),
            vec![(HEADER_KEY.to_string(), HEADER_VALUE.to_string())],
            heartbeat,
            heartbeat_msg,
            None,
            None,
            None,
            None,
            None,
            None,
        )
    }

    #[tokio::test]
    async fn basic_client_test() {
        const N: usize = 10;

        Python::initialize();

        let server = TestServer::setup(HEADER_KEY.to_string(), HEADER_VALUE.to_string()).await;
        let (counter, handler) = create_test_handler();

        let config = test_config(server.port, None, None);
        let client = WebSocketClient::connect(
            config,
            Some(make_message_handler(&handler)),
            None,
            None,
            vec![],
            None,
        )
        .await
        .unwrap();

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), N);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await.unwrap();

        // Send messages over the reconnected client; the count keeps increasing
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(read_count(&counter), N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn message_ping_test() {
        Python::initialize();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(HEADER_KEY.to_string(), HEADER_VALUE.to_string()).await;
        let config = test_config(server.port, Some(1), Some("heartbeat message".to_string()));

        let client = WebSocketClient::connect(
            config,
            Some(make_message_handler(&handler)),
            None,
            None,
            vec![],
            None,
        )
        .await
        .unwrap();

        // The echo server returns the heartbeat text, which sets the Python-side flag.
        sleep(Duration::from_secs(2)).await;
        assert!(read_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}