nautilus_dydx/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Python bindings for the dYdX WebSocket client.
17
18use std::{
19    sync::atomic::Ordering,
20    time::{Duration, Instant},
21};
22
23use nautilus_core::python::to_pyvalue_err;
24use nautilus_model::{
25    data::BarType,
26    identifiers::{AccountId, InstrumentId},
27    python::instruments::pyobject_to_instrument_any,
28};
29use nautilus_network::mode::ConnectionMode;
30use pyo3::prelude::*;
31
32use crate::{
33    common::{credential::DydxCredential, parse::extract_raw_symbol},
34    websocket::{client::DydxWebSocketClient, error::DydxWsError, handler::HandlerCommand},
35};
36
37fn to_pyvalue_err_dydx(e: DydxWsError) -> PyErr {
38    pyo3::exceptions::PyValueError::new_err(e.to_string())
39}
40
41#[pymethods]
42impl DydxWebSocketClient {
43    /// Creates a new public WebSocket client for market data.
44    #[staticmethod]
45    #[pyo3(name = "new_public")]
46    fn py_new_public(url: String, heartbeat: Option<u64>) -> Self {
47        Self::new_public(url, heartbeat)
48    }
49
50    /// Creates a new private WebSocket client for account updates.
51    #[staticmethod]
52    #[pyo3(name = "new_private")]
53    fn py_new_private(
54        url: String,
55        mnemonic: String,
56        account_index: u32,
57        authenticator_ids: Vec<u64>,
58        account_id: AccountId,
59        heartbeat: Option<u64>,
60    ) -> PyResult<Self> {
61        let credential = DydxCredential::from_mnemonic(&mnemonic, account_index, authenticator_ids)
62            .map_err(to_pyvalue_err)?;
63        Ok(Self::new_private(url, credential, account_id, heartbeat))
64    }
65
66    /// Returns whether the client is currently connected.
67    #[pyo3(name = "is_connected")]
68    fn py_is_connected(&self) -> bool {
69        self.is_connected()
70    }
71
72    /// Sets the account ID for account message parsing.
73    #[pyo3(name = "set_account_id")]
74    fn py_set_account_id(&mut self, account_id: AccountId) {
75        self.set_account_id(account_id);
76    }
77
78    /// Returns the current account ID if set.
79    #[pyo3(name = "account_id")]
80    fn py_account_id(&self) -> Option<AccountId> {
81        self.account_id()
82    }
83
84    /// Returns the WebSocket URL.
85    #[getter]
86    fn py_url(&self) -> String {
87        self.url().to_string()
88    }
89
90    /// Connects the WebSocket client.
91    #[pyo3(name = "connect")]
92    fn py_connect<'py>(
93        &mut self,
94        py: Python<'py>,
95        instruments: Vec<Py<PyAny>>,
96        callback: Py<PyAny>,
97    ) -> PyResult<Bound<'py, PyAny>> {
98        // Convert Python instruments to Rust InstrumentAny
99        let mut instruments_any = Vec::new();
100        for inst in instruments {
101            let inst_any = pyobject_to_instrument_any(py, inst)?;
102            instruments_any.push(inst_any);
103        }
104
105        // Cache instruments first
106        self.cache_instruments(instruments_any);
107
108        let mut client = self.clone();
109
110        pyo3_async_runtimes::tokio::future_into_py(py, async move {
111            // Connect the WebSocket client
112            client.connect().await.map_err(to_pyvalue_err_dydx)?;
113
114            // Take the receiver for messages
115            if let Some(mut rx) = client.take_receiver() {
116                // Spawn task to process messages and call Python callback
117                tokio::spawn(async move {
118                    let _client = client; // Keep client alive in spawned task
119
120                    while let Some(msg) = rx.recv().await {
121                        match msg {
122                            crate::websocket::enums::NautilusWsMessage::Data(items) => {
123                                Python::attach(|py| {
124                                    for data in items {
125                                        use nautilus_model::python::data::data_to_pycapsule;
126                                        let py_obj = data_to_pycapsule(py, data);
127                                        if let Err(e) = callback.call1(py, (py_obj,)) {
128                                            tracing::error!("Error calling Python callback: {e}");
129                                        }
130                                    }
131                                });
132                            }
133                            crate::websocket::enums::NautilusWsMessage::Deltas(deltas) => {
134                                Python::attach(|py| {
135                                    use nautilus_model::{
136                                        data::{Data, OrderBookDeltas_API},
137                                        python::data::data_to_pycapsule,
138                                    };
139                                    let data = Data::Deltas(OrderBookDeltas_API::new(*deltas));
140                                    let py_obj = data_to_pycapsule(py, data);
141                                    if let Err(e) = callback.call1(py, (py_obj,)) {
142                                        tracing::error!("Error calling Python callback: {e}");
143                                    }
144                                });
145                            }
146                            crate::websocket::enums::NautilusWsMessage::Error(err) => {
147                                tracing::error!("dYdX WebSocket error: {err}");
148                            }
149                            crate::websocket::enums::NautilusWsMessage::Reconnected => {
150                                tracing::info!("dYdX WebSocket reconnected");
151                            }
152                            _ => {
153                                // Handle other message types if needed
154                            }
155                        }
156                    }
157                });
158            }
159
160            Ok(())
161        })
162    }
163
164    /// Disconnects the WebSocket client.
165    #[pyo3(name = "disconnect")]
166    fn py_disconnect<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
167        let mut client = self.clone();
168        pyo3_async_runtimes::tokio::future_into_py(py, async move {
169            client.disconnect().await.map_err(to_pyvalue_err_dydx)?;
170            Ok(())
171        })
172    }
173
174    /// Waits until the client is in an active state.
175    #[pyo3(name = "wait_until_active")]
176    fn py_wait_until_active<'py>(
177        &self,
178        py: Python<'py>,
179        timeout_secs: f64,
180    ) -> PyResult<Bound<'py, PyAny>> {
181        let connection_mode = self.connection_mode_atomic();
182
183        pyo3_async_runtimes::tokio::future_into_py(py, async move {
184            let timeout = Duration::from_secs_f64(timeout_secs);
185            let start = Instant::now();
186
187            loop {
188                let mode = connection_mode.load();
189                let mode_u8 = mode.load(Ordering::Relaxed);
190                let is_connected = matches!(
191                    mode_u8,
192                    x if x == ConnectionMode::Active as u8 || x == ConnectionMode::Reconnect as u8
193                );
194
195                if is_connected {
196                    break;
197                }
198
199                if start.elapsed() > timeout {
200                    return Err(to_pyvalue_err(std::io::Error::new(
201                        std::io::ErrorKind::TimedOut,
202                        format!("Client did not become active within {timeout_secs}s"),
203                    )));
204                }
205                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
206            }
207
208            Ok(())
209        })
210    }
211
212    /// Caches a single instrument.
213    #[pyo3(name = "cache_instrument")]
214    fn py_cache_instrument(&self, instrument: Py<PyAny>, py: Python<'_>) -> PyResult<()> {
215        let inst_any = pyobject_to_instrument_any(py, instrument)?;
216        self.cache_instrument(inst_any);
217        Ok(())
218    }
219
220    /// Caches multiple instruments.
221    #[pyo3(name = "cache_instruments")]
222    fn py_cache_instruments(&self, instruments: Vec<Py<PyAny>>, py: Python<'_>) -> PyResult<()> {
223        let mut instruments_any = Vec::new();
224        for inst in instruments {
225            let inst_any = pyobject_to_instrument_any(py, inst)?;
226            instruments_any.push(inst_any);
227        }
228        self.cache_instruments(instruments_any);
229        Ok(())
230    }
231
232    /// Returns whether the client is closed.
233    #[pyo3(name = "is_closed")]
234    fn py_is_closed(&self) -> bool {
235        !self.is_connected()
236    }
237
238    /// Subscribes to public trade updates for a specific instrument.
239    #[pyo3(name = "subscribe_trades")]
240    fn py_subscribe_trades<'py>(
241        &self,
242        py: Python<'py>,
243        instrument_id: InstrumentId,
244    ) -> PyResult<Bound<'py, PyAny>> {
245        let client = self.clone();
246        pyo3_async_runtimes::tokio::future_into_py(py, async move {
247            client
248                .subscribe_trades(instrument_id)
249                .await
250                .map_err(to_pyvalue_err_dydx)?;
251            Ok(())
252        })
253    }
254
255    /// Unsubscribes from public trade updates for a specific instrument.
256    #[pyo3(name = "unsubscribe_trades")]
257    fn py_unsubscribe_trades<'py>(
258        &self,
259        py: Python<'py>,
260        instrument_id: InstrumentId,
261    ) -> PyResult<Bound<'py, PyAny>> {
262        let client = self.clone();
263        pyo3_async_runtimes::tokio::future_into_py(py, async move {
264            client
265                .unsubscribe_trades(instrument_id)
266                .await
267                .map_err(to_pyvalue_err_dydx)?;
268            Ok(())
269        })
270    }
271
272    /// Subscribes to orderbook updates for a specific instrument.
273    #[pyo3(name = "subscribe_orderbook")]
274    fn py_subscribe_orderbook<'py>(
275        &self,
276        py: Python<'py>,
277        instrument_id: InstrumentId,
278    ) -> PyResult<Bound<'py, PyAny>> {
279        let client = self.clone();
280        pyo3_async_runtimes::tokio::future_into_py(py, async move {
281            client
282                .subscribe_orderbook(instrument_id)
283                .await
284                .map_err(to_pyvalue_err_dydx)?;
285            Ok(())
286        })
287    }
288
289    /// Unsubscribes from orderbook updates for a specific instrument.
290    #[pyo3(name = "unsubscribe_orderbook")]
291    fn py_unsubscribe_orderbook<'py>(
292        &self,
293        py: Python<'py>,
294        instrument_id: InstrumentId,
295    ) -> PyResult<Bound<'py, PyAny>> {
296        let client = self.clone();
297        pyo3_async_runtimes::tokio::future_into_py(py, async move {
298            client
299                .unsubscribe_orderbook(instrument_id)
300                .await
301                .map_err(to_pyvalue_err_dydx)?;
302            Ok(())
303        })
304    }
305
306    /// Subscribes to bar updates for a specific instrument.
307    #[pyo3(name = "subscribe_bars")]
308    fn py_subscribe_bars<'py>(
309        &self,
310        py: Python<'py>,
311        bar_type: BarType,
312        resolution: String,
313    ) -> PyResult<Bound<'py, PyAny>> {
314        let client = self.clone();
315        let instrument_id = bar_type.instrument_id();
316
317        // Build topic for bar type registration (e.g., "ETH-USD/1MIN")
318        let ticker = extract_raw_symbol(instrument_id.symbol.as_str());
319        let topic = format!("{ticker}/{resolution}");
320
321        pyo3_async_runtimes::tokio::future_into_py(py, async move {
322            // Register bar type in handler before subscribing
323            client
324                .send_command(HandlerCommand::RegisterBarType { topic, bar_type })
325                .map_err(to_pyvalue_err_dydx)?;
326
327            // Brief delay to ensure handler processes registration
328            tokio::time::sleep(Duration::from_millis(50)).await;
329
330            client
331                .subscribe_candles(instrument_id, &resolution)
332                .await
333                .map_err(to_pyvalue_err_dydx)?;
334            Ok(())
335        })
336    }
337
338    /// Unsubscribes from bar updates for a specific instrument.
339    #[pyo3(name = "unsubscribe_bars")]
340    fn py_unsubscribe_bars<'py>(
341        &self,
342        py: Python<'py>,
343        bar_type: BarType,
344        resolution: String,
345    ) -> PyResult<Bound<'py, PyAny>> {
346        let client = self.clone();
347        let instrument_id = bar_type.instrument_id();
348
349        // Build topic for unregistration
350        let ticker = extract_raw_symbol(instrument_id.symbol.as_str());
351        let topic = format!("{ticker}/{resolution}");
352
353        pyo3_async_runtimes::tokio::future_into_py(py, async move {
354            client
355                .unsubscribe_candles(instrument_id, &resolution)
356                .await
357                .map_err(to_pyvalue_err_dydx)?;
358
359            // Unregister bar type after unsubscribing
360            client
361                .send_command(HandlerCommand::UnregisterBarType { topic })
362                .map_err(to_pyvalue_err_dydx)?;
363
364            Ok(())
365        })
366    }
367
368    /// Subscribes to all markets updates.
369    #[pyo3(name = "subscribe_markets")]
370    fn py_subscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
371        let client = self.clone();
372        pyo3_async_runtimes::tokio::future_into_py(py, async move {
373            client
374                .subscribe_markets()
375                .await
376                .map_err(to_pyvalue_err_dydx)?;
377            Ok(())
378        })
379    }
380
381    /// Unsubscribes from all markets updates.
382    #[pyo3(name = "unsubscribe_markets")]
383    fn py_unsubscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
384        let client = self.clone();
385        pyo3_async_runtimes::tokio::future_into_py(py, async move {
386            client
387                .unsubscribe_markets()
388                .await
389                .map_err(to_pyvalue_err_dydx)?;
390            Ok(())
391        })
392    }
393
394    /// Subscribes to subaccount updates.
395    #[pyo3(name = "subscribe_subaccount")]
396    fn py_subscribe_subaccount<'py>(
397        &self,
398        py: Python<'py>,
399        address: String,
400        subaccount_number: u32,
401    ) -> PyResult<Bound<'py, PyAny>> {
402        let client = self.clone();
403        pyo3_async_runtimes::tokio::future_into_py(py, async move {
404            client
405                .subscribe_subaccount(&address, subaccount_number)
406                .await
407                .map_err(to_pyvalue_err_dydx)?;
408            Ok(())
409        })
410    }
411
412    /// Unsubscribes from subaccount updates.
413    #[pyo3(name = "unsubscribe_subaccount")]
414    fn py_unsubscribe_subaccount<'py>(
415        &self,
416        py: Python<'py>,
417        address: String,
418        subaccount_number: u32,
419    ) -> PyResult<Bound<'py, PyAny>> {
420        let client = self.clone();
421        pyo3_async_runtimes::tokio::future_into_py(py, async move {
422            client
423                .unsubscribe_subaccount(&address, subaccount_number)
424                .await
425                .map_err(to_pyvalue_err_dydx)?;
426            Ok(())
427        })
428    }
429
430    /// Subscribes to block height updates.
431    #[pyo3(name = "subscribe_block_height")]
432    fn py_subscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
433        let client = self.clone();
434        pyo3_async_runtimes::tokio::future_into_py(py, async move {
435            client
436                .subscribe_block_height()
437                .await
438                .map_err(to_pyvalue_err_dydx)?;
439            Ok(())
440        })
441    }
442
443    /// Unsubscribes from block height updates.
444    #[pyo3(name = "unsubscribe_block_height")]
445    fn py_unsubscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
446        let client = self.clone();
447        pyo3_async_runtimes::tokio::future_into_py(py, async move {
448            client
449                .unsubscribe_block_height()
450                .await
451                .map_err(to_pyvalue_err_dydx)?;
452            Ok(())
453        })
454    }
455}