1use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23 mode::ConnectionMode,
24 socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29 #[new]
30 #[allow(clippy::too_many_arguments)]
31 #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, certs_dir=None, reconnect_max_attempts=None))]
32 fn py_new(
33 url: String,
34 ssl: bool,
35 suffix: Vec<u8>,
36 handler: Py<PyAny>,
37 heartbeat: Option<(u64, Vec<u8>)>,
38 reconnect_timeout_ms: Option<u64>,
39 reconnect_delay_initial_ms: Option<u64>,
40 reconnect_delay_max_ms: Option<u64>,
41 reconnect_backoff_factor: Option<f64>,
42 reconnect_jitter_ms: Option<u64>,
43 connection_max_retries: Option<u32>,
44 certs_dir: Option<String>,
45 reconnect_max_attempts: Option<u32>,
46 ) -> Self {
47 let mode = if ssl { Mode::Tls } else { Mode::Plain };
48
49 let handler_clone = clone_py_object(&handler);
51 let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
52 Python::attach(|py| {
53 if let Err(e) = handler_clone.call1(py, (data,)) {
54 tracing::error!("Error calling Python message handler: {e}");
55 }
56 });
57 });
58
59 Self {
60 url,
61 mode,
62 suffix,
63 message_handler: Some(message_handler),
64 heartbeat,
65 reconnect_timeout_ms,
66 reconnect_delay_initial_ms,
67 reconnect_delay_max_ms,
68 reconnect_backoff_factor,
69 reconnect_jitter_ms,
70 connection_max_retries,
71 certs_dir,
72 reconnect_max_attempts,
73 }
74 }
75}
76
77#[pymethods]
78impl SocketClient {
79 #[staticmethod]
85 #[pyo3(name = "connect")]
86 #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
87 fn py_connect(
88 config: SocketConfig,
89 post_connection: Option<Py<PyAny>>,
90 post_reconnection: Option<Py<PyAny>>,
91 post_disconnection: Option<Py<PyAny>>,
92 py: Python<'_>,
93 ) -> PyResult<Bound<'_, PyAny>> {
94 let post_connection_fn = post_connection.map(|callback| {
96 let callback_clone = clone_py_object(&callback);
97 std::sync::Arc::new(move || {
98 Python::attach(|py| {
99 if let Err(e) = callback_clone.call0(py) {
100 tracing::error!("Error calling post_connection handler: {e}");
101 }
102 });
103 }) as std::sync::Arc<dyn Fn() + Send + Sync>
104 });
105
106 let post_reconnection_fn = post_reconnection.map(|callback| {
107 let callback_clone = clone_py_object(&callback);
108 std::sync::Arc::new(move || {
109 Python::attach(|py| {
110 if let Err(e) = callback_clone.call0(py) {
111 tracing::error!("Error calling post_reconnection handler: {e}");
112 }
113 });
114 }) as std::sync::Arc<dyn Fn() + Send + Sync>
115 });
116
117 let post_disconnection_fn = post_disconnection.map(|callback| {
118 let callback_clone = clone_py_object(&callback);
119 std::sync::Arc::new(move || {
120 Python::attach(|py| {
121 if let Err(e) = callback_clone.call0(py) {
122 tracing::error!("Error calling post_disconnection handler: {e}");
123 }
124 });
125 }) as std::sync::Arc<dyn Fn() + Send + Sync>
126 });
127
128 pyo3_async_runtimes::tokio::future_into_py(py, async move {
129 Self::connect(
130 config,
131 post_connection_fn,
132 post_reconnection_fn,
133 post_disconnection_fn,
134 )
135 .await
136 .map_err(to_pyruntime_err)
137 })
138 }
139
140 #[pyo3(name = "is_active")]
150 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
151 slf.is_active()
152 }
153
154 #[pyo3(name = "is_reconnecting")]
155 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
156 slf.is_reconnecting()
157 }
158
159 #[pyo3(name = "is_disconnecting")]
160 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
161 slf.is_disconnecting()
162 }
163
164 #[pyo3(name = "is_closed")]
165 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
166 slf.is_closed()
167 }
168
169 #[pyo3(name = "mode")]
170 fn py_mode(slf: PyRef<'_, Self>) -> String {
171 slf.connection_mode().to_string()
172 }
173
174 #[pyo3(name = "reconnect")]
176 fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
177 let mode = slf.connection_mode.clone();
178 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
179 tracing::debug!("Reconnect from mode {mode_str}");
180
181 pyo3_async_runtimes::tokio::future_into_py(py, async move {
182 match ConnectionMode::from_atomic(&mode) {
183 ConnectionMode::Reconnect => {
184 tracing::warn!("Cannot reconnect - socket already reconnecting");
185 }
186 ConnectionMode::Disconnect => {
187 tracing::warn!("Cannot reconnect - socket disconnecting");
188 }
189 ConnectionMode::Closed => {
190 tracing::warn!("Cannot reconnect - socket closed");
191 }
192 _ => {
193 mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
194 while !ConnectionMode::from_atomic(&mode).is_active() {
195 tokio::time::sleep(Duration::from_millis(10)).await;
196 }
197 }
198 }
199
200 Ok(())
201 })
202 }
203
204 #[pyo3(name = "close")]
214 fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
215 let mode = slf.connection_mode.clone();
216 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
217 tracing::debug!("Close from mode {mode_str}");
218
219 pyo3_async_runtimes::tokio::future_into_py(py, async move {
220 match ConnectionMode::from_atomic(&mode) {
221 ConnectionMode::Closed => {
222 tracing::debug!("Socket already closed");
223 }
224 ConnectionMode::Disconnect => {
225 tracing::debug!("Socket already disconnecting");
226 }
227 _ => {
228 mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
229 while !ConnectionMode::from_atomic(&mode).is_closed() {
230 tokio::time::sleep(Duration::from_millis(10)).await;
231 }
232 }
233 }
234
235 Ok(())
236 })
237 }
238
239 #[pyo3(name = "send")]
245 fn py_send<'py>(
246 slf: PyRef<'_, Self>,
247 data: Vec<u8>,
248 py: Python<'py>,
249 ) -> PyResult<Bound<'py, PyAny>> {
250 tracing::trace!("Sending {}", String::from_utf8_lossy(&data));
251
252 let mode = slf.connection_mode.clone();
253 let writer_tx = slf.writer_tx.clone();
254
255 pyo3_async_runtimes::tokio::future_into_py(py, async move {
256 if ConnectionMode::from_atomic(&mode).is_closed() {
257 let msg = format!(
258 "Cannot send data ({}): socket closed",
259 String::from_utf8_lossy(&data)
260 );
261
262 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
263 return Err(to_pyruntime_err(io_err));
264 }
265
266 let timeout = Duration::from_secs(2);
267 let check_interval = Duration::from_millis(1);
268
269 if !ConnectionMode::from_atomic(&mode).is_active() {
270 tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
271 match tokio::time::timeout(timeout, async {
272 while !ConnectionMode::from_atomic(&mode).is_active() {
273 if matches!(
274 ConnectionMode::from_atomic(&mode),
275 ConnectionMode::Disconnect | ConnectionMode::Closed
276 ) {
277 return Err("Client disconnected waiting to send");
278 }
279
280 tokio::time::sleep(check_interval).await;
281 }
282
283 Ok(())
284 })
285 .await
286 {
287 Ok(Ok(())) => tracing::debug!("Client now active"),
288 Ok(Err(e)) => {
289 let err_msg = format!(
290 "Failed sending data ({}): {e}",
291 String::from_utf8_lossy(&data)
292 );
293
294 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
295 return Err(to_pyruntime_err(io_err));
296 }
297 Err(_) => {
298 let err_msg = format!(
299 "Failed sending data ({}): timeout waiting to become ACTIVE",
300 String::from_utf8_lossy(&data)
301 );
302
303 let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
304 return Err(to_pyruntime_err(io_err));
305 }
306 }
307 }
308
309 let msg = WriterCommand::Send(data.into());
310 writer_tx.send(msg).map_err(to_pyruntime_err)
311 })
312 }
313}