nautilus_network/python/socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23    mode::ConnectionMode,
24    socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29    #[new]
30    #[allow(clippy::too_many_arguments)]
31    #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, certs_dir=None, reconnect_max_attempts=None))]
32    fn py_new(
33        url: String,
34        ssl: bool,
35        suffix: Vec<u8>,
36        handler: Py<PyAny>,
37        heartbeat: Option<(u64, Vec<u8>)>,
38        reconnect_timeout_ms: Option<u64>,
39        reconnect_delay_initial_ms: Option<u64>,
40        reconnect_delay_max_ms: Option<u64>,
41        reconnect_backoff_factor: Option<f64>,
42        reconnect_jitter_ms: Option<u64>,
43        connection_max_retries: Option<u32>,
44        certs_dir: Option<String>,
45        reconnect_max_attempts: Option<u32>,
46    ) -> Self {
47        let mode = if ssl { Mode::Tls } else { Mode::Plain };
48
49        // Create function pointer that calls Python handler
50        let handler_clone = clone_py_object(&handler);
51        let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
52            Python::attach(|py| {
53                if let Err(e) = handler_clone.call1(py, (data,)) {
54                    tracing::error!("Error calling Python message handler: {e}");
55                }
56            });
57        });
58
59        Self {
60            url,
61            mode,
62            suffix,
63            message_handler: Some(message_handler),
64            heartbeat,
65            reconnect_timeout_ms,
66            reconnect_delay_initial_ms,
67            reconnect_delay_max_ms,
68            reconnect_backoff_factor,
69            reconnect_jitter_ms,
70            connection_max_retries,
71            certs_dir,
72            reconnect_max_attempts,
73        }
74    }
75}
76
77#[pymethods]
78impl SocketClient {
79    /// Create a socket client.
80    ///
81    /// # Errors
82    ///
83    /// - Throws an Exception if it is unable to make socket connection.
84    #[staticmethod]
85    #[pyo3(name = "connect")]
86    #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
87    fn py_connect(
88        config: SocketConfig,
89        post_connection: Option<Py<PyAny>>,
90        post_reconnection: Option<Py<PyAny>>,
91        post_disconnection: Option<Py<PyAny>>,
92        py: Python<'_>,
93    ) -> PyResult<Bound<'_, PyAny>> {
94        // Convert Python callbacks to function pointers
95        let post_connection_fn = post_connection.map(|callback| {
96            let callback_clone = clone_py_object(&callback);
97            std::sync::Arc::new(move || {
98                Python::attach(|py| {
99                    if let Err(e) = callback_clone.call0(py) {
100                        tracing::error!("Error calling post_connection handler: {e}");
101                    }
102                });
103            }) as std::sync::Arc<dyn Fn() + Send + Sync>
104        });
105
106        let post_reconnection_fn = post_reconnection.map(|callback| {
107            let callback_clone = clone_py_object(&callback);
108            std::sync::Arc::new(move || {
109                Python::attach(|py| {
110                    if let Err(e) = callback_clone.call0(py) {
111                        tracing::error!("Error calling post_reconnection handler: {e}");
112                    }
113                });
114            }) as std::sync::Arc<dyn Fn() + Send + Sync>
115        });
116
117        let post_disconnection_fn = post_disconnection.map(|callback| {
118            let callback_clone = clone_py_object(&callback);
119            std::sync::Arc::new(move || {
120                Python::attach(|py| {
121                    if let Err(e) = callback_clone.call0(py) {
122                        tracing::error!("Error calling post_disconnection handler: {e}");
123                    }
124                });
125            }) as std::sync::Arc<dyn Fn() + Send + Sync>
126        });
127
128        pyo3_async_runtimes::tokio::future_into_py(py, async move {
129            Self::connect(
130                config,
131                post_connection_fn,
132                post_reconnection_fn,
133                post_disconnection_fn,
134            )
135            .await
136            .map_err(to_pyruntime_err)
137        })
138    }
139
140    /// Check if the client is still alive.
141    ///
142    /// Even if the connection is disconnected the client will still be alive
143    /// and trying to reconnect.
144    ///
145    /// This is particularly useful for check why a `send` failed. It could
146    /// be because the connection disconnected and the client is still alive
147    /// and reconnecting. In such cases the send can be retried after some
148    /// delay
149    #[pyo3(name = "is_active")]
150    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151        slf.is_active()
152    }
153
154    #[pyo3(name = "is_reconnecting")]
155    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156        slf.is_reconnecting()
157    }
158
159    #[pyo3(name = "is_disconnecting")]
160    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161        slf.is_disconnecting()
162    }
163
164    #[pyo3(name = "is_closed")]
165    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166        slf.is_closed()
167    }
168
169    #[pyo3(name = "mode")]
170    fn py_mode(slf: PyRef<'_, Self>) -> String {
171        slf.connection_mode().to_string()
172    }
173
174    /// Reconnect the client.
175    #[pyo3(name = "reconnect")]
176    fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
177        let mode = slf.connection_mode.clone();
178        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
179        tracing::debug!("Reconnect from mode {mode_str}");
180
181        pyo3_async_runtimes::tokio::future_into_py(py, async move {
182            match ConnectionMode::from_atomic(&mode) {
183                ConnectionMode::Reconnect => {
184                    tracing::warn!("Cannot reconnect - socket already reconnecting");
185                }
186                ConnectionMode::Disconnect => {
187                    tracing::warn!("Cannot reconnect - socket disconnecting");
188                }
189                ConnectionMode::Closed => {
190                    tracing::warn!("Cannot reconnect - socket closed");
191                }
192                _ => {
193                    mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
194                    while !ConnectionMode::from_atomic(&mode).is_active() {
195                        tokio::time::sleep(Duration::from_millis(10)).await;
196                    }
197                }
198            }
199
200            Ok(())
201        })
202    }
203
204    /// Close the client.
205    ///
206    /// The connection is not completely closed until all references
207    /// to the client are gone and the client is dropped.
208    ///
209    /// # Safety
210    ///
211    /// - The client should not be used after closing it
212    /// - Any auto-reconnect job should be aborted before closing the client
213    #[pyo3(name = "close")]
214    fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
215        let mode = slf.connection_mode.clone();
216        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
217        tracing::debug!("Close from mode {mode_str}");
218
219        pyo3_async_runtimes::tokio::future_into_py(py, async move {
220            match ConnectionMode::from_atomic(&mode) {
221                ConnectionMode::Closed => {
222                    tracing::debug!("Socket already closed");
223                }
224                ConnectionMode::Disconnect => {
225                    tracing::debug!("Socket already disconnecting");
226                }
227                _ => {
228                    mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
229                    while !ConnectionMode::from_atomic(&mode).is_closed() {
230                        tokio::time::sleep(Duration::from_millis(10)).await;
231                    }
232                }
233            }
234
235            Ok(())
236        })
237    }
238
239    /// Send bytes data to the connection.
240    ///
241    /// # Errors
242    ///
243    /// - Throws an Exception if it is not able to send data.
244    #[pyo3(name = "send")]
245    fn py_send<'py>(
246        slf: PyRef<'_, Self>,
247        data: Vec<u8>,
248        py: Python<'py>,
249    ) -> PyResult<Bound<'py, PyAny>> {
250        tracing::trace!("Sending {}", String::from_utf8_lossy(&data));
251
252        let mode = slf.connection_mode.clone();
253        let writer_tx = slf.writer_tx.clone();
254
255        pyo3_async_runtimes::tokio::future_into_py(py, async move {
256            if ConnectionMode::from_atomic(&mode).is_closed() {
257                let msg = format!(
258                    "Cannot send data ({}): socket closed",
259                    String::from_utf8_lossy(&data)
260                );
261
262                let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
263                return Err(to_pyruntime_err(io_err));
264            }
265
266            let timeout = Duration::from_secs(2);
267            let check_interval = Duration::from_millis(1);
268
269            if !ConnectionMode::from_atomic(&mode).is_active() {
270                tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
271                match tokio::time::timeout(timeout, async {
272                    while !ConnectionMode::from_atomic(&mode).is_active() {
273                        if matches!(
274                            ConnectionMode::from_atomic(&mode),
275                            ConnectionMode::Disconnect | ConnectionMode::Closed
276                        ) {
277                            return Err("Client disconnected waiting to send");
278                        }
279
280                        tokio::time::sleep(check_interval).await;
281                    }
282
283                    Ok(())
284                })
285                .await
286                {
287                    Ok(Ok(())) => tracing::debug!("Client now active"),
288                    Ok(Err(e)) => {
289                        let err_msg = format!(
290                            "Failed sending data ({}): {e}",
291                            String::from_utf8_lossy(&data)
292                        );
293
294                        let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
295                        return Err(to_pyruntime_err(io_err));
296                    }
297                    Err(_) => {
298                        let err_msg = format!(
299                            "Failed sending data ({}): timeout waiting to become ACTIVE",
300                            String::from_utf8_lossy(&data)
301                        );
302
303                        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
304                        return Err(to_pyruntime_err(io_err));
305                    }
306                }
307            }
308
309            let msg = WriterCommand::Send(data.into());
310            writer_tx.send(msg).map_err(to_pyruntime_err)
311        })
312    }
313}