nautilus_dydx/python/
websocket.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! Python bindings for the dYdX WebSocket client.

18use std::{
19    sync::atomic::Ordering,
20    time::{Duration, Instant},
21};
22
23use nautilus_core::python::to_pyvalue_err;
24use nautilus_model::{
25    identifiers::{AccountId, InstrumentId},
26    python::instruments::pyobject_to_instrument_any,
27};
28use nautilus_network::mode::ConnectionMode;
29use pyo3::prelude::*;
30
31use crate::{
32    common::credential::DydxCredential,
33    websocket::{client::DydxWebSocketClient, error::DydxWsError},
34};
35
36fn to_pyvalue_err_dydx(e: DydxWsError) -> PyErr {
37    pyo3::exceptions::PyValueError::new_err(e.to_string())
38}
39
40#[pymethods]
41impl DydxWebSocketClient {
42    /// Creates a new public WebSocket client for market data.
43    #[staticmethod]
44    #[pyo3(name = "new_public")]
45    fn py_new_public(url: String, heartbeat: Option<u64>) -> Self {
46        Self::new_public(url, heartbeat)
47    }
48
49    /// Creates a new private WebSocket client for account updates.
50    #[staticmethod]
51    #[pyo3(name = "new_private")]
52    fn py_new_private(
53        url: String,
54        mnemonic: String,
55        account_index: u32,
56        authenticator_ids: Vec<u64>,
57        account_id: AccountId,
58        heartbeat: Option<u64>,
59    ) -> PyResult<Self> {
60        let credential = DydxCredential::from_mnemonic(&mnemonic, account_index, authenticator_ids)
61            .map_err(to_pyvalue_err)?;
62        Ok(Self::new_private(url, credential, account_id, heartbeat))
63    }
64
65    /// Returns whether the client is currently connected.
66    #[pyo3(name = "is_connected")]
67    fn py_is_connected(&self) -> bool {
68        self.is_connected()
69    }
70
71    /// Sets the account ID for account message parsing.
72    #[pyo3(name = "set_account_id")]
73    fn py_set_account_id(&mut self, account_id: AccountId) {
74        self.set_account_id(account_id);
75    }
76
77    /// Returns the current account ID if set.
78    #[pyo3(name = "account_id")]
79    fn py_account_id(&self) -> Option<AccountId> {
80        self.account_id()
81    }
82
83    /// Returns the WebSocket URL.
84    #[getter]
85    fn py_url(&self) -> String {
86        self.url().to_string()
87    }
88
89    /// Connects the WebSocket client.
90    #[pyo3(name = "connect")]
91    fn py_connect<'py>(
92        &mut self,
93        py: Python<'py>,
94        instruments: Vec<Py<PyAny>>,
95        callback: Py<PyAny>,
96    ) -> PyResult<Bound<'py, PyAny>> {
97        // Convert Python instruments to Rust InstrumentAny
98        let mut instruments_any = Vec::new();
99        for inst in instruments {
100            let inst_any = pyobject_to_instrument_any(py, inst)?;
101            instruments_any.push(inst_any);
102        }
103
104        // Cache instruments first
105        self.cache_instruments(instruments_any);
106
107        let mut client = self.clone();
108
109        pyo3_async_runtimes::tokio::future_into_py(py, async move {
110            // Connect the WebSocket client
111            client.connect().await.map_err(to_pyvalue_err_dydx)?;
112
113            // Take the receiver for messages
114            if let Some(mut rx) = client.take_receiver() {
115                // Spawn task to process messages and call Python callback
116                tokio::spawn(async move {
117                    let _client = client; // Keep client alive in spawned task
118
119                    while let Some(msg) = rx.recv().await {
120                        match msg {
121                            crate::websocket::messages::NautilusWsMessage::Data(items) => {
122                                Python::attach(|py| {
123                                    for data in items {
124                                        use nautilus_model::python::data::data_to_pycapsule;
125                                        let py_obj = data_to_pycapsule(py, data);
126                                        if let Err(e) = callback.call1(py, (py_obj,)) {
127                                            tracing::error!("Error calling Python callback: {e}");
128                                        }
129                                    }
130                                });
131                            }
132                            crate::websocket::messages::NautilusWsMessage::Deltas(deltas) => {
133                                Python::attach(|py| {
134                                    use nautilus_model::{
135                                        data::{Data, OrderBookDeltas_API},
136                                        python::data::data_to_pycapsule,
137                                    };
138                                    let data = Data::Deltas(OrderBookDeltas_API::new(*deltas));
139                                    let py_obj = data_to_pycapsule(py, data);
140                                    if let Err(e) = callback.call1(py, (py_obj,)) {
141                                        tracing::error!("Error calling Python callback: {e}");
142                                    }
143                                });
144                            }
145                            crate::websocket::messages::NautilusWsMessage::Error(err) => {
146                                tracing::error!("dYdX WebSocket error: {err}");
147                            }
148                            crate::websocket::messages::NautilusWsMessage::Reconnected => {
149                                tracing::info!("dYdX WebSocket reconnected");
150                            }
151                            _ => {
152                                // Handle other message types if needed
153                            }
154                        }
155                    }
156                });
157            }
158
159            Ok(())
160        })
161    }
162
163    /// Disconnects the WebSocket client.
164    #[pyo3(name = "disconnect")]
165    fn py_disconnect<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
166        let mut client = self.clone();
167        pyo3_async_runtimes::tokio::future_into_py(py, async move {
168            client.disconnect().await.map_err(to_pyvalue_err_dydx)?;
169            Ok(())
170        })
171    }
172
173    /// Waits until the client is in an active state.
174    #[pyo3(name = "wait_until_active")]
175    fn py_wait_until_active<'py>(
176        &self,
177        py: Python<'py>,
178        timeout_secs: f64,
179    ) -> PyResult<Bound<'py, PyAny>> {
180        let connection_mode = self.connection_mode_atomic();
181
182        pyo3_async_runtimes::tokio::future_into_py(py, async move {
183            let timeout = Duration::from_secs_f64(timeout_secs);
184            let start = Instant::now();
185
186            loop {
187                let mode = connection_mode.load();
188                let mode_u8 = mode.load(Ordering::Relaxed);
189                let is_connected = matches!(
190                    mode_u8,
191                    x if x == ConnectionMode::Active as u8 || x == ConnectionMode::Reconnect as u8
192                );
193
194                if is_connected {
195                    break;
196                }
197
198                if start.elapsed() > timeout {
199                    return Err(to_pyvalue_err(std::io::Error::new(
200                        std::io::ErrorKind::TimedOut,
201                        format!("Client did not become active within {timeout_secs}s"),
202                    )));
203                }
204                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
205            }
206
207            Ok(())
208        })
209    }
210
211    /// Caches a single instrument.
212    #[pyo3(name = "cache_instrument")]
213    fn py_cache_instrument(&self, instrument: Py<PyAny>, py: Python<'_>) -> PyResult<()> {
214        let inst_any = pyobject_to_instrument_any(py, instrument)?;
215        self.cache_instrument(inst_any);
216        Ok(())
217    }
218
219    /// Caches multiple instruments.
220    #[pyo3(name = "cache_instruments")]
221    fn py_cache_instruments(&self, instruments: Vec<Py<PyAny>>, py: Python<'_>) -> PyResult<()> {
222        let mut instruments_any = Vec::new();
223        for inst in instruments {
224            let inst_any = pyobject_to_instrument_any(py, inst)?;
225            instruments_any.push(inst_any);
226        }
227        self.cache_instruments(instruments_any);
228        Ok(())
229    }
230
231    /// Returns whether the client is closed.
232    #[pyo3(name = "is_closed")]
233    fn py_is_closed(&self) -> bool {
234        !self.is_connected()
235    }
236
237    /// Subscribes to public trade updates for a specific instrument.
238    #[pyo3(name = "subscribe_trades")]
239    fn py_subscribe_trades<'py>(
240        &self,
241        py: Python<'py>,
242        instrument_id: InstrumentId,
243    ) -> PyResult<Bound<'py, PyAny>> {
244        let client = self.clone();
245        pyo3_async_runtimes::tokio::future_into_py(py, async move {
246            client
247                .subscribe_trades(instrument_id)
248                .await
249                .map_err(to_pyvalue_err_dydx)?;
250            Ok(())
251        })
252    }
253
254    /// Unsubscribes from public trade updates for a specific instrument.
255    #[pyo3(name = "unsubscribe_trades")]
256    fn py_unsubscribe_trades<'py>(
257        &self,
258        py: Python<'py>,
259        instrument_id: InstrumentId,
260    ) -> PyResult<Bound<'py, PyAny>> {
261        let client = self.clone();
262        pyo3_async_runtimes::tokio::future_into_py(py, async move {
263            client
264                .unsubscribe_trades(instrument_id)
265                .await
266                .map_err(to_pyvalue_err_dydx)?;
267            Ok(())
268        })
269    }
270
271    /// Subscribes to orderbook updates for a specific instrument.
272    #[pyo3(name = "subscribe_orderbook")]
273    fn py_subscribe_orderbook<'py>(
274        &self,
275        py: Python<'py>,
276        instrument_id: InstrumentId,
277    ) -> PyResult<Bound<'py, PyAny>> {
278        let client = self.clone();
279        pyo3_async_runtimes::tokio::future_into_py(py, async move {
280            client
281                .subscribe_orderbook(instrument_id)
282                .await
283                .map_err(to_pyvalue_err_dydx)?;
284            Ok(())
285        })
286    }
287
288    /// Unsubscribes from orderbook updates for a specific instrument.
289    #[pyo3(name = "unsubscribe_orderbook")]
290    fn py_unsubscribe_orderbook<'py>(
291        &self,
292        py: Python<'py>,
293        instrument_id: InstrumentId,
294    ) -> PyResult<Bound<'py, PyAny>> {
295        let client = self.clone();
296        pyo3_async_runtimes::tokio::future_into_py(py, async move {
297            client
298                .unsubscribe_orderbook(instrument_id)
299                .await
300                .map_err(to_pyvalue_err_dydx)?;
301            Ok(())
302        })
303    }
304
305    /// Subscribes to bar updates for a specific instrument.
306    #[pyo3(name = "subscribe_bars")]
307    fn py_subscribe_bars<'py>(
308        &self,
309        py: Python<'py>,
310        instrument_id: InstrumentId,
311        resolution: String,
312    ) -> PyResult<Bound<'py, PyAny>> {
313        let client = self.clone();
314        pyo3_async_runtimes::tokio::future_into_py(py, async move {
315            client
316                .subscribe_candles(instrument_id, &resolution)
317                .await
318                .map_err(to_pyvalue_err_dydx)?;
319            Ok(())
320        })
321    }
322
323    /// Unsubscribes from bar updates for a specific instrument.
324    #[pyo3(name = "unsubscribe_bars")]
325    fn py_unsubscribe_bars<'py>(
326        &self,
327        py: Python<'py>,
328        instrument_id: InstrumentId,
329        resolution: String,
330    ) -> PyResult<Bound<'py, PyAny>> {
331        let client = self.clone();
332        pyo3_async_runtimes::tokio::future_into_py(py, async move {
333            client
334                .unsubscribe_candles(instrument_id, &resolution)
335                .await
336                .map_err(to_pyvalue_err_dydx)?;
337            Ok(())
338        })
339    }
340
341    /// Subscribes to all markets updates.
342    #[pyo3(name = "subscribe_markets")]
343    fn py_subscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
344        let client = self.clone();
345        pyo3_async_runtimes::tokio::future_into_py(py, async move {
346            client
347                .subscribe_markets()
348                .await
349                .map_err(to_pyvalue_err_dydx)?;
350            Ok(())
351        })
352    }
353
354    /// Unsubscribes from all markets updates.
355    #[pyo3(name = "unsubscribe_markets")]
356    fn py_unsubscribe_markets<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
357        let client = self.clone();
358        pyo3_async_runtimes::tokio::future_into_py(py, async move {
359            client
360                .unsubscribe_markets()
361                .await
362                .map_err(to_pyvalue_err_dydx)?;
363            Ok(())
364        })
365    }
366
367    /// Subscribes to subaccount updates.
368    #[pyo3(name = "subscribe_subaccount")]
369    fn py_subscribe_subaccount<'py>(
370        &self,
371        py: Python<'py>,
372        address: String,
373        subaccount_number: u32,
374    ) -> PyResult<Bound<'py, PyAny>> {
375        let client = self.clone();
376        pyo3_async_runtimes::tokio::future_into_py(py, async move {
377            client
378                .subscribe_subaccount(&address, subaccount_number)
379                .await
380                .map_err(to_pyvalue_err_dydx)?;
381            Ok(())
382        })
383    }
384
385    /// Unsubscribes from subaccount updates.
386    #[pyo3(name = "unsubscribe_subaccount")]
387    fn py_unsubscribe_subaccount<'py>(
388        &self,
389        py: Python<'py>,
390        address: String,
391        subaccount_number: u32,
392    ) -> PyResult<Bound<'py, PyAny>> {
393        let client = self.clone();
394        pyo3_async_runtimes::tokio::future_into_py(py, async move {
395            client
396                .unsubscribe_subaccount(&address, subaccount_number)
397                .await
398                .map_err(to_pyvalue_err_dydx)?;
399            Ok(())
400        })
401    }
402
403    /// Subscribes to block height updates.
404    #[pyo3(name = "subscribe_block_height")]
405    fn py_subscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
406        let client = self.clone();
407        pyo3_async_runtimes::tokio::future_into_py(py, async move {
408            client
409                .subscribe_block_height()
410                .await
411                .map_err(to_pyvalue_err_dydx)?;
412            Ok(())
413        })
414    }
415
416    /// Unsubscribes from block height updates.
417    #[pyo3(name = "unsubscribe_block_height")]
418    fn py_unsubscribe_block_height<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
419        let client = self.clone();
420        pyo3_async_runtimes::tokio::future_into_py(py, async move {
421            client
422                .unsubscribe_block_height()
423                .await
424                .map_err(to_pyvalue_err_dydx)?;
425            Ok(())
426        })
427    }
428}