1use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::prelude::*;
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23 mode::ConnectionMode,
24 socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29 #[new]
30 #[allow(clippy::too_many_arguments)]
31 #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, certs_dir=None))]
32 fn py_new(
33 url: String,
34 ssl: bool,
35 suffix: Vec<u8>,
36 handler: PyObject,
37 heartbeat: Option<(u64, Vec<u8>)>,
38 reconnect_timeout_ms: Option<u64>,
39 reconnect_delay_initial_ms: Option<u64>,
40 reconnect_delay_max_ms: Option<u64>,
41 reconnect_backoff_factor: Option<f64>,
42 reconnect_jitter_ms: Option<u64>,
43 certs_dir: Option<String>,
44 ) -> Self {
45 let mode = if ssl { Mode::Tls } else { Mode::Plain };
46
47 let handler_clone = clone_py_object(&handler);
49 let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
50 Python::with_gil(|py| {
51 if let Err(e) = handler_clone.call1(py, (data,)) {
52 tracing::error!("Error calling Python message handler: {e}");
53 }
54 });
55 });
56
57 Self {
58 url,
59 mode,
60 suffix,
61 message_handler: Some(message_handler),
62 heartbeat,
63 reconnect_timeout_ms,
64 reconnect_delay_initial_ms,
65 reconnect_delay_max_ms,
66 reconnect_backoff_factor,
67 reconnect_jitter_ms,
68 certs_dir,
69 }
70 }
71}
72
73#[pymethods]
74impl SocketClient {
75 #[staticmethod]
81 #[pyo3(name = "connect")]
82 #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
83 fn py_connect(
84 config: SocketConfig,
85 post_connection: Option<PyObject>,
86 post_reconnection: Option<PyObject>,
87 post_disconnection: Option<PyObject>,
88 py: Python<'_>,
89 ) -> PyResult<Bound<'_, PyAny>> {
90 let post_connection_fn = post_connection.map(|callback| {
92 let callback_clone = clone_py_object(&callback);
93 std::sync::Arc::new(move || {
94 Python::with_gil(|py| {
95 if let Err(e) = callback_clone.call0(py) {
96 tracing::error!("Error calling post_connection handler: {e}");
97 }
98 });
99 }) as std::sync::Arc<dyn Fn() + Send + Sync>
100 });
101
102 let post_reconnection_fn = post_reconnection.map(|callback| {
103 let callback_clone = clone_py_object(&callback);
104 std::sync::Arc::new(move || {
105 Python::with_gil(|py| {
106 if let Err(e) = callback_clone.call0(py) {
107 tracing::error!("Error calling post_reconnection handler: {e}");
108 }
109 });
110 }) as std::sync::Arc<dyn Fn() + Send + Sync>
111 });
112
113 let post_disconnection_fn = post_disconnection.map(|callback| {
114 let callback_clone = clone_py_object(&callback);
115 std::sync::Arc::new(move || {
116 Python::with_gil(|py| {
117 if let Err(e) = callback_clone.call0(py) {
118 tracing::error!("Error calling post_disconnection handler: {e}");
119 }
120 });
121 }) as std::sync::Arc<dyn Fn() + Send + Sync>
122 });
123
124 pyo3_async_runtimes::tokio::future_into_py(py, async move {
125 Self::connect(
126 config,
127 post_connection_fn,
128 post_reconnection_fn,
129 post_disconnection_fn,
130 )
131 .await
132 .map_err(to_pyruntime_err)
133 })
134 }
135
136 #[pyo3(name = "is_active")]
146 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
147 slf.is_active()
148 }
149
150 #[pyo3(name = "is_reconnecting")]
151 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
152 slf.is_reconnecting()
153 }
154
155 #[pyo3(name = "is_disconnecting")]
156 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
157 slf.is_disconnecting()
158 }
159
160 #[pyo3(name = "is_closed")]
161 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
162 slf.is_closed()
163 }
164
165 #[pyo3(name = "mode")]
166 fn py_mode(slf: PyRef<'_, Self>) -> String {
167 slf.connection_mode().to_string()
168 }
169
170 #[pyo3(name = "reconnect")]
172 fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
173 let mode = slf.connection_mode.clone();
174 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
175 tracing::debug!("Reconnect from mode {mode_str}");
176
177 pyo3_async_runtimes::tokio::future_into_py(py, async move {
178 match ConnectionMode::from_atomic(&mode) {
179 ConnectionMode::Reconnect => {
180 tracing::warn!("Cannot reconnect - socket already reconnecting");
181 }
182 ConnectionMode::Disconnect => {
183 tracing::warn!("Cannot reconnect - socket disconnecting");
184 }
185 ConnectionMode::Closed => {
186 tracing::warn!("Cannot reconnect - socket closed");
187 }
188 _ => {
189 mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
190 while !ConnectionMode::from_atomic(&mode).is_active() {
191 tokio::time::sleep(Duration::from_millis(10)).await;
192 }
193 }
194 }
195
196 Ok(())
197 })
198 }
199
200 #[pyo3(name = "close")]
210 fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
211 let mode = slf.connection_mode.clone();
212 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
213 tracing::debug!("Close from mode {mode_str}");
214
215 pyo3_async_runtimes::tokio::future_into_py(py, async move {
216 match ConnectionMode::from_atomic(&mode) {
217 ConnectionMode::Closed => {
218 tracing::debug!("Socket already closed");
219 }
220 ConnectionMode::Disconnect => {
221 tracing::debug!("Socket already disconnecting");
222 }
223 _ => {
224 mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
225 while !ConnectionMode::from_atomic(&mode).is_closed() {
226 tokio::time::sleep(Duration::from_millis(10)).await;
227 }
228 }
229 }
230
231 Ok(())
232 })
233 }
234
235 #[pyo3(name = "send")]
241 fn py_send<'py>(
242 slf: PyRef<'_, Self>,
243 data: Vec<u8>,
244 py: Python<'py>,
245 ) -> PyResult<Bound<'py, PyAny>> {
246 tracing::trace!("Sending {}", String::from_utf8_lossy(&data));
247
248 let mode = slf.connection_mode.clone();
249 let writer_tx = slf.writer_tx.clone();
250
251 pyo3_async_runtimes::tokio::future_into_py(py, async move {
252 if ConnectionMode::from_atomic(&mode).is_closed() {
253 let msg = format!(
254 "Cannot send data ({}): socket closed",
255 String::from_utf8_lossy(&data)
256 );
257
258 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
259 return Err(to_pyruntime_err(io_err));
260 }
261
262 let timeout = Duration::from_secs(2);
263 let check_interval = Duration::from_millis(1);
264
265 if !ConnectionMode::from_atomic(&mode).is_active() {
266 tracing::debug!("Waiting for client to become ACTIVE before sending (2s)...");
267 match tokio::time::timeout(timeout, async {
268 while !ConnectionMode::from_atomic(&mode).is_active() {
269 if matches!(
270 ConnectionMode::from_atomic(&mode),
271 ConnectionMode::Disconnect | ConnectionMode::Closed
272 ) {
273 return Err("Client disconnected waiting to send");
274 }
275
276 tokio::time::sleep(check_interval).await;
277 }
278
279 Ok(())
280 })
281 .await
282 {
283 Ok(Ok(())) => tracing::debug!("Client now active"),
284 Ok(Err(e)) => {
285 let err_msg = format!(
286 "Failed sending data ({}): {e}",
287 String::from_utf8_lossy(&data)
288 );
289
290 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
291 return Err(to_pyruntime_err(io_err));
292 }
293 Err(_) => {
294 let err_msg = format!(
295 "Failed sending data ({}): timeout waiting to become ACTIVE",
296 String::from_utf8_lossy(&data)
297 );
298
299 let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
300 return Err(to_pyruntime_err(io_err));
301 }
302 }
303 }
304
305 let msg = WriterCommand::Send(data.into());
306 writer_tx.send(msg).map_err(to_pyruntime_err)
307 })
308 }
309}