nautilus_network/python/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::prelude::*;
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23    mode::ConnectionMode,
24    socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29    #[new]
30    #[allow(clippy::too_many_arguments)]
31    #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, certs_dir=None))]
32    fn py_new(
33        url: String,
34        ssl: bool,
35        suffix: Vec<u8>,
36        handler: PyObject,
37        heartbeat: Option<(u64, Vec<u8>)>,
38        reconnect_timeout_ms: Option<u64>,
39        reconnect_delay_initial_ms: Option<u64>,
40        reconnect_delay_max_ms: Option<u64>,
41        reconnect_backoff_factor: Option<f64>,
42        reconnect_jitter_ms: Option<u64>,
43        certs_dir: Option<String>,
44    ) -> Self {
45        let mode = if ssl { Mode::Tls } else { Mode::Plain };
46
47        // Create function pointer that calls Python handler
48        let handler_clone = clone_py_object(&handler);
49        let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
50            Python::with_gil(|py| {
51                if let Err(e) = handler_clone.call1(py, (data,)) {
52                    tracing::error!("Error calling Python message handler: {e}");
53                }
54            });
55        });
56
57        Self {
58            url,
59            mode,
60            suffix,
61            message_handler: Some(message_handler),
62            heartbeat,
63            reconnect_timeout_ms,
64            reconnect_delay_initial_ms,
65            reconnect_delay_max_ms,
66            reconnect_backoff_factor,
67            reconnect_jitter_ms,
68            certs_dir,
69        }
70    }
71}
72
73#[pymethods]
74impl SocketClient {
75    /// Create a socket client.
76    ///
77    /// # Errors
78    ///
79    /// - Throws an Exception if it is unable to make socket connection.
80    #[staticmethod]
81    #[pyo3(name = "connect")]
82    #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
83    fn py_connect(
84        config: SocketConfig,
85        post_connection: Option<PyObject>,
86        post_reconnection: Option<PyObject>,
87        post_disconnection: Option<PyObject>,
88        py: Python<'_>,
89    ) -> PyResult<Bound<'_, PyAny>> {
90        // Convert Python callbacks to function pointers
91        let post_connection_fn = post_connection.map(|callback| {
92            let callback_clone = clone_py_object(&callback);
93            std::sync::Arc::new(move || {
94                Python::with_gil(|py| {
95                    if let Err(e) = callback_clone.call0(py) {
96                        tracing::error!("Error calling post_connection handler: {e}");
97                    }
98                });
99            }) as std::sync::Arc<dyn Fn() + Send + Sync>
100        });
101
102        let post_reconnection_fn = post_reconnection.map(|callback| {
103            let callback_clone = clone_py_object(&callback);
104            std::sync::Arc::new(move || {
105                Python::with_gil(|py| {
106                    if let Err(e) = callback_clone.call0(py) {
107                        tracing::error!("Error calling post_reconnection handler: {e}");
108                    }
109                });
110            }) as std::sync::Arc<dyn Fn() + Send + Sync>
111        });
112
113        let post_disconnection_fn = post_disconnection.map(|callback| {
114            let callback_clone = clone_py_object(&callback);
115            std::sync::Arc::new(move || {
116                Python::with_gil(|py| {
117                    if let Err(e) = callback_clone.call0(py) {
118                        tracing::error!("Error calling post_disconnection handler: {e}");
119                    }
120                });
121            }) as std::sync::Arc<dyn Fn() + Send + Sync>
122        });
123
124        pyo3_async_runtimes::tokio::future_into_py(py, async move {
125            Self::connect(
126                config,
127                post_connection_fn,
128                post_reconnection_fn,
129                post_disconnection_fn,
130            )
131            .await
132            .map_err(to_pyruntime_err)
133        })
134    }
135
136    /// Check if the client is still alive.
137    ///
138    /// Even if the connection is disconnected the client will still be alive
139    /// and trying to reconnect.
140    ///
141    /// This is particularly useful for check why a `send` failed. It could
142    /// be because the connection disconnected and the client is still alive
143    /// and reconnecting. In such cases the send can be retried after some
144    /// delay
145    #[pyo3(name = "is_active")]
146    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
147        slf.is_active()
148    }
149
150    #[pyo3(name = "is_reconnecting")]
151    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
152        slf.is_reconnecting()
153    }
154
155    #[pyo3(name = "is_disconnecting")]
156    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
157        slf.is_disconnecting()
158    }
159
160    #[pyo3(name = "is_closed")]
161    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
162        slf.is_closed()
163    }
164
165    #[pyo3(name = "mode")]
166    fn py_mode(slf: PyRef<'_, Self>) -> String {
167        slf.connection_mode().to_string()
168    }
169
170    /// Reconnect the client.
171    #[pyo3(name = "reconnect")]
172    fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
173        let mode = slf.connection_mode.clone();
174        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
175        tracing::debug!("Reconnect from mode {mode_str}");
176
177        pyo3_async_runtimes::tokio::future_into_py(py, async move {
178            match ConnectionMode::from_atomic(&mode) {
179                ConnectionMode::Reconnect => {
180                    tracing::warn!("Cannot reconnect - socket already reconnecting");
181                }
182                ConnectionMode::Disconnect => {
183                    tracing::warn!("Cannot reconnect - socket disconnecting");
184                }
185                ConnectionMode::Closed => {
186                    tracing::warn!("Cannot reconnect - socket closed");
187                }
188                _ => {
189                    mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
190                    while !ConnectionMode::from_atomic(&mode).is_active() {
191                        tokio::time::sleep(Duration::from_millis(10)).await;
192                    }
193                }
194            }
195
196            Ok(())
197        })
198    }
199
200    /// Close the client.
201    ///
202    /// The connection is not completely closed until all references
203    /// to the client are gone and the client is dropped.
204    ///
205    /// # Safety
206    ///
207    /// - The client should not be used after closing it
208    /// - Any auto-reconnect job should be aborted before closing the client
209    #[pyo3(name = "close")]
210    fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
211        let mode = slf.connection_mode.clone();
212        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
213        tracing::debug!("Close from mode {mode_str}");
214
215        pyo3_async_runtimes::tokio::future_into_py(py, async move {
216            match ConnectionMode::from_atomic(&mode) {
217                ConnectionMode::Closed => {
218                    tracing::debug!("Socket already closed");
219                }
220                ConnectionMode::Disconnect => {
221                    tracing::debug!("Socket already disconnecting");
222                }
223                _ => {
224                    mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
225                    while !ConnectionMode::from_atomic(&mode).is_closed() {
226                        tokio::time::sleep(Duration::from_millis(10)).await;
227                    }
228                }
229            }
230
231            Ok(())
232        })
233    }
234
235    /// Send bytes data to the connection.
236    ///
237    /// # Errors
238    ///
239    /// - Throws an Exception if it is not able to send data.
240    #[pyo3(name = "send")]
241    fn py_send<'py>(
242        slf: PyRef<'_, Self>,
243        data: Vec<u8>,
244        py: Python<'py>,
245    ) -> PyResult<Bound<'py, PyAny>> {
246        tracing::trace!("Sending {}", String::from_utf8_lossy(&data));
247
248        let mode = slf.connection_mode.clone();
249        let writer_tx = slf.writer_tx.clone();
250
251        pyo3_async_runtimes::tokio::future_into_py(py, async move {
252            if ConnectionMode::from_atomic(&mode).is_closed() {
253                let msg = format!(
254                    "Cannot send data ({}): socket closed",
255                    String::from_utf8_lossy(&data)
256                );
257
258                let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
259                return Err(to_pyruntime_err(io_err));
260            }
261
262            let timeout = Duration::from_secs(2);
263            let check_interval = Duration::from_millis(1);
264
265            if !ConnectionMode::from_atomic(&mode).is_active() {
266                tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
267                match tokio::time::timeout(timeout, async {
268                    while !ConnectionMode::from_atomic(&mode).is_active() {
269                        if matches!(
270                            ConnectionMode::from_atomic(&mode),
271                            ConnectionMode::Disconnect | ConnectionMode::Closed
272                        ) {
273                            return Err("Client disconnected waiting to send");
274                        }
275
276                        tokio::time::sleep(check_interval).await;
277                    }
278
279                    Ok(())
280                })
281                .await
282                {
283                    Ok(Ok(())) => tracing::debug!("Client now active"),
284                    Ok(Err(e)) => {
285                        let err_msg = format!(
286                            "Failed sending data ({}): {e}",
287                            String::from_utf8_lossy(&data)
288                        );
289
290                        let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
291                        return Err(to_pyruntime_err(io_err));
292                    }
293                    Err(_) => {
294                        let err_msg = format!(
295                            "Failed sending data ({}): timeout waiting to become ACTIVE",
296                            String::from_utf8_lossy(&data)
297                        );
298
299                        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
300                        return Err(to_pyruntime_err(io_err));
301                    }
302                }
303            }
304
305            let msg = WriterCommand::Send(data.into());
306            writer_tx.send(msg).map_err(to_pyruntime_err)
307        })
308    }
309}