nautilus_dydx/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Python bindings for the dYdX WebSocket client.
17
18use std::{
19    sync::atomic::Ordering,
20    time::{Duration, Instant},
21};
22
23use nautilus_common::live::get_runtime;
24use nautilus_core::python::to_pyvalue_err;
25use nautilus_model::{
26    data::BarType,
27    identifiers::{AccountId, InstrumentId},
28    python::instruments::pyobject_to_instrument_any,
29};
30use nautilus_network::mode::ConnectionMode;
31use pyo3::prelude::*;
32
33use crate::{
34    common::{credential::DydxCredential, parse::extract_raw_symbol},
35    websocket::{client::DydxWebSocketClient, error::DydxWsError, handler::HandlerCommand},
36};
37
38fn to_pyvalue_err_dydx(e: DydxWsError) -> PyErr {
39    pyo3::exceptions::PyValueError::new_err(e.to_string())
40}
41
42#[pymethods]
43impl DydxWebSocketClient {
44    /// Creates a new public WebSocket client for market data.
45    #[staticmethod]
46    #[pyo3(name = "new_public")]
47    fn py_new_public(url: String, heartbeat: Option<u64>) -> Self {
48        Self::new_public(url, heartbeat)
49    }
50
51    /// Creates a new private WebSocket client for account updates.
52    #[staticmethod]
53    #[pyo3(name = "new_private")]
54    fn py_new_private(
55        url: String,
56        mnemonic: String,
57        account_index: u32,
58        authenticator_ids: Vec<u64>,
59        account_id: AccountId,
60        heartbeat: Option<u64>,
61    ) -> PyResult<Self> {
62        let credential = DydxCredential::from_mnemonic(&mnemonic, account_index, authenticator_ids)
63            .map_err(to_pyvalue_err)?;
64        Ok(Self::new_private(url, credential, account_id, heartbeat))
65    }
66
67    /// Returns whether the client is currently connected.
68    #[pyo3(name = "is_connected")]
69    fn py_is_connected(&self) -> bool {
70        self.is_connected()
71    }
72
73    /// Sets the account ID for account message parsing.
74    #[pyo3(name = "set_account_id")]
75    fn py_set_account_id(&mut self, account_id: AccountId) {
76        self.set_account_id(account_id);
77    }
78
79    /// Returns the current account ID if set.
80    #[pyo3(name = "account_id")]
81    fn py_account_id(&self) -> Option<AccountId> {
82        self.account_id()
83    }
84
85    /// Returns the WebSocket URL.
86    #[getter]
87    fn py_url(&self) -> String {
88        self.url().to_string()
89    }
90
91    /// Connects the WebSocket client.
92    #[pyo3(name = "connect")]
93    fn py_connect<'py>(
94        &mut self,
95        py: Python<'py>,
96        instruments: Vec<Py<PyAny>>,
97        callback: Py<PyAny>,
98    ) -> PyResult<Bound<'py, PyAny>> {
99        // Convert Python instruments to Rust InstrumentAny
100        let mut instruments_any = Vec::new();
101        for inst in instruments {
102            let inst_any = pyobject_to_instrument_any(py, inst)?;
103            instruments_any.push(inst_any);
104        }
105
106        // Cache instruments first
107        self.cache_instruments(instruments_any);
108
109        let mut client = self.clone();
110
111        pyo3_async_runtimes::tokio::future_into_py(py, async move {
112            // Connect the WebSocket client
113            client.connect().await.map_err(to_pyvalue_err_dydx)?;
114
115            // Take the receiver for messages
116            if let Some(mut rx) = client.take_receiver() {
117                // Spawn task to process messages and call Python callback
118                get_runtime().spawn(async move {
119                    let _client = client; // Keep client alive in spawned task
120
121                    while let Some(msg) = rx.recv().await {
122                        match msg {
123                            crate::websocket::enums::NautilusWsMessage::Data(items) => {
124                                Python::attach(|py| {
125                                    for data in items {
126                                        use nautilus_model::python::data::data_to_pycapsule;
127                                        let py_obj = data_to_pycapsule(py, data);
128                                        if let Err(e) = callback.call1(py, (py_obj,)) {
129                                            tracing::error!("Error calling Python callback: {e}");
130                                        }
131                                    }
132                                });
133                            }
134                            crate::websocket::enums::NautilusWsMessage::Deltas(deltas) => {
135                                Python::attach(|py| {
136                                    use nautilus_model::{
137                                        data::{Data, OrderBookDeltas_API},
138                                        python::data::data_to_pycapsule,
139                                    };
140                                    let data = Data::Deltas(OrderBookDeltas_API::new(*deltas));
141                                    let py_obj = data_to_pycapsule(py, data);
142                                    if let Err(e) = callback.call1(py, (py_obj,)) {
143                                        tracing::error!("Error calling Python callback: {e}");
144                                    }
145                                });
146                            }
147                            crate::websocket::enums::NautilusWsMessage::Error(err) => {
148                                tracing::error!("dYdX WebSocket error: {err}");
149                            }
150                            crate::websocket::enums::NautilusWsMessage::Reconnected => {
151                                tracing::info!("dYdX WebSocket reconnected");
152                            }
153                            _ => {
154                                // Handle other message types if needed
155                            }
156                        }
157                    }
158                });
159            }
160
161            Ok(())
162        })
163    }
164
165    /// Disconnects the WebSocket client.
166    #[pyo3(name = "disconnect")]
167    fn py_disconnect<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
168        let mut client = self.clone();
169        pyo3_async_runtimes::tokio::future_into_py(py, async move {
170            client.disconnect().await.map_err(to_pyvalue_err_dydx)?;
171            Ok(())
172        })
173    }
174
175    /// Waits until the client is in an active state.
176    #[pyo3(name = "wait_until_active")]
177    fn py_wait_until_active<'py>(
178        &self,
179        py: Python<'py>,
180        timeout_secs: f64,
181    ) -> PyResult<Bound<'py, PyAny>> {
182        let connection_mode = self.connection_mode_atomic();
183
184        pyo3_async_runtimes::tokio::future_into_py(py, async move {
185            let timeout = Duration::from_secs_f64(timeout_secs);
186            let start = Instant::now();
187
188            loop {
189                let mode = connection_mode.load();
190                let mode_u8 = mode.load(Ordering::Relaxed);
191                let is_connected = matches!(
192                    mode_u8,
193                    x if x == ConnectionMode::Active as u8 || x == ConnectionMode::Reconnect as u8
194                );
195
196                if is_connected {
197                    break;
198                }
199
200                if start.elapsed() > timeout {
201                    return Err(to_pyvalue_err(std::io::Error::new(
202                        std::io::ErrorKind::TimedOut,
203                        format!("Client did not become active within {timeout_secs}s"),
204                    )));
205                }
206                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
207            }
208
209            Ok(())
210        })
211    }
212
213    /// Caches a single instrument.
214    #[pyo3(name = "cache_instrument")]
215    fn py_cache_instrument(&self, instrument: Py<PyAny>, py: Python<'_>) -> PyResult<()> {
216        let inst_any = pyobject_to_instrument_any(py, instrument)?;
217        self.cache_instrument(inst_any);
218        Ok(())
219    }
220
221    /// Caches multiple instruments.
222    #[pyo3(name = "cache_instruments")]
223    fn py_cache_instruments(&self, instruments: Vec<Py<PyAny>>, py: Python<'_>) -> PyResult<()> {
224        let mut instruments_any = Vec::new();
225        for inst in instruments {
226            let inst_any = pyobject_to_instrument_any(py, inst)?;
227            instruments_any.push(inst_any);
228        }
229        self.cache_instruments(instruments_any);
230        Ok(())
231    }
232
233    /// Returns whether the client is closed.
234    #[pyo3(name = "is_closed")]
235    fn py_is_closed(&self) -> bool {
236        !self.is_connected()
237    }
238
239    /// Subscribes to public trade updates for a specific instrument.
240    #[pyo3(name = "subscribe_trades")]
241    fn py_subscribe_trades<'py>(
242        &self,
243        py: Python<'py>,
244        instrument_id: InstrumentId,
245    ) -> PyResult<Bound<'py, PyAny>> {
246        let client = self.clone();
247        pyo3_async_runtimes::tokio::future_into_py(py, async move {
248            client
249                .subscribe_trades(instrument_id)
250                .await
251                .map_err(to_pyvalue_err_dydx)?;
252            Ok(())
253        })
254    }
255
256    /// Unsubscribes from public trade updates for a specific instrument.
257    #[pyo3(name = "unsubscribe_trades")]
258    fn py_unsubscribe_trades<'py>(
259        &self,
260        py: Python<'py>,
261        instrument_id: InstrumentId,
262    ) -> PyResult<Bound<'py, PyAny>> {
263        let client = self.clone();
264        pyo3_async_runtimes::tokio::future_into_py(py, async move {
265            client
266                .unsubscribe_trades(instrument_id)
267                .await
268                .map_err(to_pyvalue_err_dydx)?;
269            Ok(())
270        })
271    }
272
273    /// Subscribes to orderbook updates for a specific instrument.
274    #[pyo3(name = "subscribe_orderbook")]
275    fn py_subscribe_orderbook<'py>(
276        &self,
277        py: Python<'py>,
278        instrument_id: InstrumentId,
279    ) -> PyResult<Bound<'py, PyAny>> {
280        let client = self.clone();
281        pyo3_async_runtimes::tokio::future_into_py(py, async move {
282            client
283                .subscribe_orderbook(instrument_id)
284                .await
285                .map_err(to_pyvalue_err_dydx)?;
286            Ok(())
287        })
288    }
289
290    /// Unsubscribes from orderbook updates for a specific instrument.
291    #[pyo3(name = "unsubscribe_orderbook")]
292    fn py_unsubscribe_orderbook<'py>(
293        &self,
294        py: Python<'py>,
295        instrument_id: InstrumentId,
296    ) -> PyResult<Bound<'py, PyAny>> {
297        let client = self.clone();
298        pyo3_async_runtimes::tokio::future_into_py(py, async move {
299            client
300                .unsubscribe_orderbook(instrument_id)
301                .await
302                .map_err(to_pyvalue_err_dydx)?;
303            Ok(())
304        })
305    }
306
307    /// Subscribes to bar updates for a specific instrument.
308    #[pyo3(name = "subscribe_bars")]
309    fn py_subscribe_bars<'py>(
310        &self,
311        py: Python<'py>,
312        bar_type: BarType,
313        resolution: String,
314    ) -> PyResult<Bound<'py, PyAny>> {
315        let client = self.clone();
316        let instrument_id = bar_type.instrument_id();
317
318        // Build topic for bar type registration (e.g., "ETH-USD/1MIN")
319        let ticker = extract_raw_symbol(instrument_id.symbol.as_str());
320        let topic = format!("{ticker}/{resolution}");
321
322        pyo3_async_runtimes::tokio::future_into_py(py, async move {
323            // Register bar type in handler before subscribing
324            client
325                .send_command(HandlerCommand::RegisterBarType { topic, bar_type })
326                .map_err(to_pyvalue_err_dydx)?;
327
328            // Brief delay to ensure handler processes registration
329            tokio::time::sleep(Duration::from_millis(50)).await;
330
331            client
332                .subscribe_candles(instrument_id, &resolution)
333                .await
334                .map_err(to_pyvalue_err_dydx)?;
335            Ok(())
336        })
337    }
338
339    /// Unsubscribes from bar updates for a specific instrument.
340    #[pyo3(name = "unsubscribe_bars")]
341    fn py_unsubscribe_bars<'py>(
342        &self,
343        py: Python<'py>,
344        bar_type: BarType,
345        resolution: String,
346    ) -> PyResult<Bound<'py, PyAny>> {
347        let client = self.clone();
348        let instrument_id = bar_type.instrument_id();
349
350        // Build topic for unregistration
351        let ticker = extract_raw_symbol(instrument_id.symbol.as_str());
352        let topic = format!("{ticker}/{resolution}");
353
354        pyo3_async_runtimes::tokio::future_into_py(py, async move {
355            client
356                .unsubscribe_candles(instrument_id, &resolution)
357                .await
358                .map_err(to_pyvalue_err_dydx)?;
359
360            // Unregister bar type after unsubscribing
361            client
362                .send_command(HandlerCommand::UnregisterBarType { topic })
363                .map_err(to_pyvalue_err_dydx)?;
364
365            Ok(())
366        })
367    }
368
369    /// Subscribes to all markets updates.
370    #[pyo3(name = "subscribe_markets")]
371    fn py_subscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
372        let client = self.clone();
373        pyo3_async_runtimes::tokio::future_into_py(py, async move {
374            client
375                .subscribe_markets()
376                .await
377                .map_err(to_pyvalue_err_dydx)?;
378            Ok(())
379        })
380    }
381
382    /// Unsubscribes from all markets updates.
383    #[pyo3(name = "unsubscribe_markets")]
384    fn py_unsubscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
385        let client = self.clone();
386        pyo3_async_runtimes::tokio::future_into_py(py, async move {
387            client
388                .unsubscribe_markets()
389                .await
390                .map_err(to_pyvalue_err_dydx)?;
391            Ok(())
392        })
393    }
394
395    /// Subscribes to subaccount updates.
396    #[pyo3(name = "subscribe_subaccount")]
397    fn py_subscribe_subaccount<'py>(
398        &self,
399        py: Python<'py>,
400        address: String,
401        subaccount_number: u32,
402    ) -> PyResult<Bound<'py, PyAny>> {
403        let client = self.clone();
404        pyo3_async_runtimes::tokio::future_into_py(py, async move {
405            client
406                .subscribe_subaccount(&address, subaccount_number)
407                .await
408                .map_err(to_pyvalue_err_dydx)?;
409            Ok(())
410        })
411    }
412
413    /// Unsubscribes from subaccount updates.
414    #[pyo3(name = "unsubscribe_subaccount")]
415    fn py_unsubscribe_subaccount<'py>(
416        &self,
417        py: Python<'py>,
418        address: String,
419        subaccount_number: u32,
420    ) -> PyResult<Bound<'py, PyAny>> {
421        let client = self.clone();
422        pyo3_async_runtimes::tokio::future_into_py(py, async move {
423            client
424                .unsubscribe_subaccount(&address, subaccount_number)
425                .await
426                .map_err(to_pyvalue_err_dydx)?;
427            Ok(())
428        })
429    }
430
431    /// Subscribes to block height updates.
432    #[pyo3(name = "subscribe_block_height")]
433    fn py_subscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
434        let client = self.clone();
435        pyo3_async_runtimes::tokio::future_into_py(py, async move {
436            client
437                .subscribe_block_height()
438                .await
439                .map_err(to_pyvalue_err_dydx)?;
440            Ok(())
441        })
442    }
443
444    /// Unsubscribes from block height updates.
445    #[pyo3(name = "unsubscribe_block_height")]
446    fn py_unsubscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
447        let client = self.clone();
448        pyo3_async_runtimes::tokio::future_into_py(py, async move {
449            client
450                .unsubscribe_block_height()
451                .await
452                .map_err(to_pyvalue_err_dydx)?;
453            Ok(())
454        })
455    }
456}