nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::python::{clone_py_object, to_pyruntime_err, to_pyvalue_err};
22use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26    RECONNECTED,
27    mode::ConnectionMode,
28    ratelimiter::quota::Quota,
29    websocket::{
30        WebSocketClient, WebSocketConfig,
31        types::{MessageHandler, PingHandler, WriterCommand},
32    },
33};
34
// Declares the Python exception type `WebSocketClientError`, registered under
// the `network` module name and derived from Python's built-in `Exception`.
create_exception!(network, WebSocketClientError, PyException);
36
37fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
38    PyErr::new::<WebSocketClientError, _>(e.to_string())
39}
40
41#[pymethods]
42impl WebSocketConfig {
43    /// Create a new WebSocket configuration.
44    #[new]
45    #[allow(clippy::too_many_arguments)]
46    #[pyo3(signature = (
47        url,
48        headers,
49        heartbeat=None,
50        heartbeat_msg=None,
51        reconnect_timeout_ms=10_000,
52        reconnect_delay_initial_ms=2_000,
53        reconnect_delay_max_ms=30_000,
54        reconnect_backoff_factor=1.5,
55        reconnect_jitter_ms=100,
56        reconnect_max_attempts=None,
57    ))]
58    fn py_new(
59        url: String,
60        headers: Vec<(String, String)>,
61        heartbeat: Option<u64>,
62        heartbeat_msg: Option<String>,
63        reconnect_timeout_ms: Option<u64>,
64        reconnect_delay_initial_ms: Option<u64>,
65        reconnect_delay_max_ms: Option<u64>,
66        reconnect_backoff_factor: Option<f64>,
67        reconnect_jitter_ms: Option<u64>,
68        reconnect_max_attempts: Option<u32>,
69    ) -> Self {
70        Self {
71            url,
72            headers,
73            heartbeat,
74            heartbeat_msg,
75            reconnect_timeout_ms,
76            reconnect_delay_initial_ms,
77            reconnect_delay_max_ms,
78            reconnect_backoff_factor,
79            reconnect_jitter_ms,
80            reconnect_max_attempts,
81        }
82    }
83}
84
85#[pymethods]
86impl WebSocketClient {
87    /// Create a websocket client.
88    ///
89    /// The handler and ping_handler callbacks are scheduled on the provided event loop
90    /// using `call_soon_threadsafe` to ensure they execute on the correct thread.
91    /// This is critical for thread safety since WebSocket messages arrive on
92    /// a Tokio worker thread, but Python callbacks (like those entering the
93    /// kernel via MessageBus) must run on the asyncio event loop thread.
94    ///
95    /// # Safety
96    ///
97    /// - Throws an Exception if it is unable to make websocket connection.
98    #[staticmethod]
99    #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
100    #[allow(clippy::too_many_arguments)]
101    fn py_connect(
102        loop_: Py<PyAny>,
103        config: WebSocketConfig,
104        handler: Py<PyAny>,
105        ping_handler: Option<Py<PyAny>>,
106        post_reconnection: Option<Py<PyAny>>,
107        keyed_quotas: Vec<(String, Quota)>,
108        default_quota: Option<Quota>,
109        py: Python<'_>,
110    ) -> PyResult<Bound<'_, PyAny>> {
111        let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
112        let call_soon_clone = clone_py_object(&call_soon_threadsafe);
113        let handler_clone = clone_py_object(&handler);
114
115        let message_handler: MessageHandler = Arc::new(move |msg: Message| {
116            if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
117                return;
118            }
119
120            Python::attach(|py| {
121                let py_bytes = match &msg {
122                    Message::Binary(data) => PyBytes::new(py, data),
123                    Message::Text(text) => PyBytes::new(py, text.as_bytes()),
124                    _ => return,
125                };
126
127                if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
128                    log::error!("Error scheduling message handler on event loop: {e}");
129                }
130            });
131        });
132
133        let ping_handler_fn = ping_handler.map(|ping_handler| {
134            let ping_handler_clone = clone_py_object(&ping_handler);
135            let call_soon_clone = clone_py_object(&call_soon_threadsafe);
136
137            let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
138                Python::attach(|py| {
139                    let py_bytes = PyBytes::new(py, &data);
140                    if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
141                        log::error!("Error scheduling ping handler on event loop: {e}");
142                    }
143                });
144            });
145            ping_handler_fn
146        });
147
148        let post_reconnection_fn = post_reconnection.map(|callback| {
149            let callback_clone = clone_py_object(&callback);
150            Arc::new(move || {
151                Python::attach(|py| {
152                    if let Err(e) = callback_clone.call0(py) {
153                        log::error!("Error calling post_reconnection handler: {e}");
154                    }
155                });
156            }) as std::sync::Arc<dyn Fn() + Send + Sync>
157        });
158
159        pyo3_async_runtimes::tokio::future_into_py(py, async move {
160            Self::connect(
161                config,
162                Some(message_handler),
163                ping_handler_fn,
164                post_reconnection_fn,
165                keyed_quotas,
166                default_quota,
167            )
168            .await
169            .map_err(to_websocket_pyerr)
170        })
171    }
172
173    /// Closes the client heart beat and reader task.
174    ///
175    /// The connection is not completely closed the till all references
176    /// to the client are gone and the client is dropped.
177    ///
178    /// # Safety
179    ///
180    /// - The client should not be used after closing it.
181    /// - Any auto-reconnect job should be aborted before closing the client.
182    #[pyo3(name = "disconnect")]
183    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184        let connection_mode = slf.connection_mode.clone();
185        let mode = ConnectionMode::from_atomic(&connection_mode);
186        log::debug!("Close from mode {mode}");
187
188        pyo3_async_runtimes::tokio::future_into_py(py, async move {
189            match ConnectionMode::from_atomic(&connection_mode) {
190                ConnectionMode::Closed => {
191                    log::debug!("WebSocket already closed");
192                }
193                ConnectionMode::Disconnect => {
194                    log::debug!("WebSocket already disconnecting");
195                }
196                _ => {
197                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
198                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
199                        tokio::time::sleep(Duration::from_millis(10)).await;
200                    }
201                }
202            }
203
204            Ok(())
205        })
206    }
207
208    /// Check if the client is still alive.
209    ///
210    /// Even if the connection is disconnected the client will still be alive
211    /// and trying to reconnect.
212    ///
213    /// This is particularly useful for checking why a `send` failed. It could
214    /// be because the connection disconnected and the client is still alive
215    /// and reconnecting. In such cases the send can be retried after some
216    /// delay.
217    #[pyo3(name = "is_active")]
218    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
219        !slf.controller_task.is_finished()
220    }
221
222    #[pyo3(name = "is_reconnecting")]
223    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
224        slf.is_reconnecting()
225    }
226
227    #[pyo3(name = "is_disconnecting")]
228    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
229        slf.is_disconnecting()
230    }
231
232    #[pyo3(name = "is_closed")]
233    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
234        slf.is_closed()
235    }
236
237    /// Send bytes data to the server.
238    ///
239    /// # Errors
240    ///
241    /// - Raises `PyRuntimeError` if not able to send data.
242    #[pyo3(name = "send")]
243    #[pyo3(signature = (data, keys=None))]
244    fn py_send<'py>(
245        slf: PyRef<'_, Self>,
246        data: Vec<u8>,
247        py: Python<'py>,
248        keys: Option<Vec<String>>,
249    ) -> PyResult<Bound<'py, PyAny>> {
250        let rate_limiter = slf.rate_limiter.clone();
251        let writer_tx = slf.writer_tx.clone();
252        let mode = slf.connection_mode.clone();
253
254        pyo3_async_runtimes::tokio::future_into_py(py, async move {
255            if !ConnectionMode::from_atomic(&mode).is_active() {
256                let msg = "Cannot send data: connection not active".to_string();
257                log::error!("{msg}");
258                return Err(to_pyruntime_err(std::io::Error::new(
259                    std::io::ErrorKind::NotConnected,
260                    msg,
261                )));
262            }
263            rate_limiter.await_keys_ready(keys).await;
264            log::trace!("Sending binary: {data:?}");
265
266            let msg = Message::Binary(data.into());
267            writer_tx
268                .send(WriterCommand::Send(msg))
269                .map_err(to_pyruntime_err)
270        })
271    }
272
273    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
274    ///
275    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
276    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
277    ///
278    /// # Errors
279    /// - Raises `PyRuntimeError` if unable to send the data.
280    ///
281    /// # Example
282    ///
283    /// When a request is made the URL should be split into all relevant keys within it.
284    ///
285    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
286    #[pyo3(name = "send_text")]
287    #[pyo3(signature = (data, keys=None))]
288    fn py_send_text<'py>(
289        slf: PyRef<'_, Self>,
290        data: Vec<u8>,
291        py: Python<'py>,
292        keys: Option<Vec<String>>,
293    ) -> PyResult<Bound<'py, PyAny>> {
294        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
295        let data = Utf8Bytes::from(data_str);
296        let rate_limiter = slf.rate_limiter.clone();
297        let writer_tx = slf.writer_tx.clone();
298        let mode = slf.connection_mode.clone();
299
300        pyo3_async_runtimes::tokio::future_into_py(py, async move {
301            if !ConnectionMode::from_atomic(&mode).is_active() {
302                let e = std::io::Error::new(
303                    std::io::ErrorKind::NotConnected,
304                    "Cannot send text: connection not active",
305                );
306                return Err(to_pyruntime_err(e));
307            }
308            rate_limiter.await_keys_ready(keys).await;
309            log::trace!("Sending text: {data}");
310
311            let msg = Message::Text(data);
312            writer_tx
313                .send(WriterCommand::Send(msg))
314                .map_err(to_pyruntime_err)
315        })
316    }
317
318    /// Send pong bytes data to the server.
319    ///
320    /// # Errors
321    ///
322    /// - Raises `PyRuntimeError` if not able to send data.
323    #[pyo3(name = "send_pong")]
324    fn py_send_pong<'py>(
325        slf: PyRef<'_, Self>,
326        data: Vec<u8>,
327        py: Python<'py>,
328    ) -> PyResult<Bound<'py, PyAny>> {
329        let writer_tx = slf.writer_tx.clone();
330        let mode = slf.connection_mode.clone();
331        let data_len = data.len();
332
333        pyo3_async_runtimes::tokio::future_into_py(py, async move {
334            if !ConnectionMode::from_atomic(&mode).is_active() {
335                let e = std::io::Error::new(
336                    std::io::ErrorKind::NotConnected,
337                    "Cannot send pong: connection not active",
338                );
339                return Err(to_pyruntime_err(e));
340            }
341            log::trace!("Sending pong frame ({data_len} bytes)");
342
343            let msg = Message::Pong(data.into());
344            writer_tx
345                .send(WriterCommand::Send(msg))
346                .map_err(to_pyruntime_err)
347        })
348    }
349}
350
351#[cfg(test)]
352#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
353mod tests {
354    use std::ffi::CString;
355
356    use futures_util::{SinkExt, StreamExt};
357    use nautilus_core::python::IntoPyObjectNautilusExt;
358    use pyo3::{prelude::*, types::PyBytes};
359    use tokio::{
360        net::TcpListener,
361        task::{self, JoinHandle},
362        time::{Duration, sleep},
363    };
364    use tokio_tungstenite::{
365        accept_hdr_async,
366        tungstenite::{
367            Message,
368            handshake::server::{self, Callback},
369            http::HeaderValue,
370        },
371    };
372
373    use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};
374
    /// Local WebSocket echo server used by the tests in this module.
    struct TestServer {
        /// Handle of the spawned accept loop; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Local TCP port the listener is bound to.
        port: u16,
    }
379
    /// Handshake callback that checks the upgrade request for one expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header is expected to carry.
        value: HeaderValue,
    }
385
386    impl Callback for TestCallback {
387        #[allow(clippy::panic_in_result_fn)]
388        fn on_request(
389            self,
390            request: &server::Request,
391            response: server::Response,
392        ) -> Result<server::Response, server::ErrorResponse> {
393            let _ = response;
394            let value = request.headers().get(&self.key);
395            assert!(value.is_some());
396
397            if let Some(value) = request.headers().get(&self.key) {
398                assert_eq!(value, self.value);
399            }
400
401            Ok(response)
402        }
403    }
404
405    impl TestServer {
406        async fn setup(key: String, value: String) -> Self {
407            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
408            let port = TcpListener::local_addr(&server).unwrap().port();
409
410            let test_call_back = TestCallback {
411                key,
412                value: HeaderValue::from_str(&value).unwrap(),
413            };
414
415            // Set up test server
416            let task = task::spawn(async move {
417                // Keep accepting connections
418                loop {
419                    let (conn, _) = server.accept().await.unwrap();
420                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
421                        .await
422                        .unwrap();
423
424                    task::spawn(async move {
425                        while let Some(Ok(msg)) = websocket.next().await {
426                            match msg {
427                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
428                                    if txt == "close-now" =>
429                                {
430                                    log::debug!("Forcibly closing from server side");
431                                    // This sends a close frame, then stops reading
432                                    let _ = websocket.close(None).await;
433                                    break;
434                                }
435                                // Echo text/binary frames
436                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
437                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
438                                    if websocket.send(msg).await.is_err() {
439                                        break;
440                                    }
441                                }
442                                // If the client closes, we also break
443                                tokio_tungstenite::tungstenite::protocol::Message::Close(
444                                    _frame,
445                                ) => {
446                                    let _ = websocket.close(None).await;
447                                    break;
448                                }
449                                // Ignore pings/pongs
450                                _ => {}
451                            }
452                        }
453                    });
454                }
455            });
456
457            Self { task, port }
458        }
459    }
460
461    impl Drop for TestServer {
462        fn drop(&mut self) {
463            self.task.abort();
464        }
465    }
466
467    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
468        let code_raw = r"
469class Counter:
470    def __init__(self):
471        self.count = 0
472        self.check = False
473
474    def handler(self, bytes):
475        msg = bytes.decode()
476        if msg == 'ping':
477            self.count += 1
478        elif msg == 'heartbeat message':
479            self.check = True
480
481    def get_check(self):
482        return self.check
483
484    def get_count(self):
485        return self.count
486
487counter = Counter()
488";
489
490        let code = CString::new(code_raw).unwrap();
491        let filename = CString::new("test".to_string()).unwrap();
492        let module = CString::new("test".to_string()).unwrap();
493        Python::attach(|py| {
494            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
495
496            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
497            let handler = counter
498                .getattr(py, "handler")
499                .unwrap()
500                .into_py_any_unwrap(py);
501
502            (counter, handler)
503        })
504    }
505
506    #[tokio::test]
507    async fn basic_client_test() {
508        const N: usize = 10;
509
510        Python::initialize();
511
512        let mut success_count = 0;
513        let header_key = "hello-custom-key".to_string();
514        let header_value = "hello-custom-value".to_string();
515
516        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
517        let (counter, handler) = create_test_handler();
518
519        let config = WebSocketConfig::py_new(
520            format!("ws://127.0.0.1:{}", server.port),
521            vec![(header_key, header_value)],
522            None,
523            None,
524            None,
525            None,
526            None,
527            None,
528            None,
529            None,
530        );
531
532        let handler_clone = Python::attach(|py| handler.clone_ref(py));
533
534        let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
535            Python::attach(|py| {
536                let data = match msg {
537                    Message::Binary(data) => data.to_vec(),
538                    Message::Text(text) => text.as_bytes().to_vec(),
539                    _ => return,
540                };
541                let py_bytes = PyBytes::new(py, &data);
542                if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
543                    log::error!("Error calling handler: {e}");
544                }
545            });
546        });
547
548        let client =
549            WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
550                .await
551                .unwrap();
552
553        for _ in 0..N {
554            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
555            success_count += 1;
556        }
557
558        sleep(Duration::from_secs(1)).await;
559        let count_value: usize = Python::attach(|py| {
560            counter
561                .getattr(py, "get_count")
562                .unwrap()
563                .call0(py)
564                .unwrap()
565                .extract(py)
566                .unwrap()
567        });
568        assert_eq!(count_value, success_count);
569
570        // Close the connection => client should reconnect automatically
571        client.send_close_message().await.unwrap();
572
573        // Send messages that increment the count
574        sleep(Duration::from_secs(2)).await;
575        for _ in 0..N {
576            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
577            success_count += 1;
578        }
579
580        sleep(Duration::from_secs(1)).await;
581        let count_value: usize = Python::attach(|py| {
582            counter
583                .getattr(py, "get_count")
584                .unwrap()
585                .call0(py)
586                .unwrap()
587                .extract(py)
588                .unwrap()
589        });
590        assert_eq!(count_value, success_count);
591        assert_eq!(success_count, N + N);
592
593        client.disconnect().await;
594        assert!(client.is_disconnected());
595    }
596
597    #[tokio::test]
598    async fn message_ping_test() {
599        Python::initialize();
600
601        let header_key = "hello-custom-key".to_string();
602        let header_value = "hello-custom-value".to_string();
603
604        let (checker, handler) = create_test_handler();
605
606        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
607        let config = WebSocketConfig::py_new(
608            format!("ws://127.0.0.1:{}", server.port),
609            vec![(header_key, header_value)],
610            Some(1),
611            Some("heartbeat message".to_string()),
612            None,
613            None,
614            None,
615            None,
616            None,
617            None,
618        );
619
620        let handler_clone = Python::attach(|py| handler.clone_ref(py));
621
622        let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
623            Python::attach(|py| {
624                let data = match msg {
625                    Message::Binary(data) => data.to_vec(),
626                    Message::Text(text) => text.as_bytes().to_vec(),
627                    _ => return,
628                };
629                let py_bytes = PyBytes::new(py, &data);
630                if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
631                    log::error!("Error calling handler: {e}");
632                }
633            });
634        });
635
636        let client =
637            WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
638                .await
639                .unwrap();
640
641        sleep(Duration::from_secs(2)).await;
642        let check_value: bool = Python::attach(|py| {
643            checker
644                .getattr(py, "get_check")
645                .unwrap()
646                .call0(py)
647                .unwrap()
648                .extract(py)
649                .unwrap()
650        });
651        assert!(check_value);
652
653        client.disconnect().await;
654        assert!(client.is_disconnected());
655    }
656}