nautilus_tardis/python/machine.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, path::Path, sync::Arc};
17
18use ahash::AHashMap;
19use futures_util::{Stream, StreamExt, pin_mut};
20use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err};
21use nautilus_model::{
22    data::{Bar, Data, funding::FundingRateUpdate},
23    identifiers::InstrumentId,
24    python::data::data_to_pycapsule,
25};
26use pyo3::{prelude::*, types::PyList};
27
28use crate::{
29    machine::{
30        Error,
31        client::{TardisMachineClient, determine_instrument_info},
32        message::WsMessage,
33        parse::{parse_tardis_ws_message, parse_tardis_ws_message_funding_rate},
34        replay_normalized, stream_normalized,
35        types::{
36            ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
37            TardisInstrumentMiniInfo,
38        },
39    },
40    replay::run_tardis_machine_replay_from_config,
41};
42
43#[pymethods]
44impl ReplayNormalizedRequestOptions {
45    #[staticmethod]
46    #[pyo3(name = "from_json")]
47    fn py_from_json(data: Vec<u8>) -> Self {
48        serde_json::from_slice(&data).expect("Failed to parse JSON")
49    }
50
51    #[pyo3(name = "from_json_array")]
52    #[staticmethod]
53    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
54        serde_json::from_slice(&data).expect("Failed to parse JSON array")
55    }
56}
57
58#[pymethods]
59impl StreamNormalizedRequestOptions {
60    #[staticmethod]
61    #[pyo3(name = "from_json")]
62    fn py_from_json(data: Vec<u8>) -> Self {
63        serde_json::from_slice(&data).expect("Failed to parse JSON")
64    }
65
66    #[pyo3(name = "from_json_array")]
67    #[staticmethod]
68    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
69        serde_json::from_slice(&data).expect("Failed to parse JSON array")
70    }
71}
72
73#[pymethods]
74impl TardisMachineClient {
75    #[new]
76    #[pyo3(signature = (base_url=None, normalize_symbols=true))]
77    fn py_new(base_url: Option<&str>, normalize_symbols: bool) -> PyResult<Self> {
78        Self::new(base_url, normalize_symbols).map_err(to_pyruntime_err)
79    }
80
81    #[pyo3(name = "is_closed")]
82    #[must_use]
83    pub fn py_is_closed(&self) -> bool {
84        self.is_closed()
85    }
86
87    #[pyo3(name = "close")]
88    fn py_close(&mut self) {
89        self.close();
90    }
91
92    #[pyo3(name = "replay")]
93    fn py_replay<'py>(
94        &self,
95        instruments: Vec<TardisInstrumentMiniInfo>,
96        options: Vec<ReplayNormalizedRequestOptions>,
97        callback: PyObject,
98        py: Python<'py>,
99    ) -> PyResult<Bound<'py, PyAny>> {
100        let map = if instruments.is_empty() {
101            self.instruments.clone()
102        } else {
103            let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
104                HashMap::new();
105            for inst in instruments {
106                let key = inst.as_tardis_instrument_key();
107                instrument_map.insert(key, Arc::new(inst.clone()));
108            }
109            instrument_map
110        };
111
112        let base_url = self.base_url.clone();
113        let replay_signal = self.replay_signal.clone();
114
115        pyo3_async_runtimes::tokio::future_into_py(py, async move {
116            let stream = replay_normalized(&base_url, options, replay_signal)
117                .await
118                .map_err(to_pyruntime_err)?;
119
120            // We use Box::pin to heap-allocate the stream and ensure it implements
121            // Unpin for safe async handling across lifetimes.
122            handle_python_stream(Box::pin(stream), callback, None, Some(map)).await;
123            Ok(())
124        })
125    }
126
127    #[pyo3(name = "replay_bars")]
128    fn py_replay_bars<'py>(
129        &self,
130        instruments: Vec<TardisInstrumentMiniInfo>,
131        options: Vec<ReplayNormalizedRequestOptions>,
132        py: Python<'py>,
133    ) -> PyResult<Bound<'py, PyAny>> {
134        let map = if instruments.is_empty() {
135            self.instruments.clone()
136        } else {
137            instruments
138                .into_iter()
139                .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
140                .collect()
141        };
142
143        let base_url = self.base_url.clone();
144        let replay_signal = self.replay_signal.clone();
145
146        pyo3_async_runtimes::tokio::future_into_py(py, async move {
147            let stream = replay_normalized(&base_url, options, replay_signal)
148                .await
149                .map_err(to_pyruntime_err)?;
150
151            // We use Box::pin to heap-allocate the stream and ensure it implements
152            // Unpin for safe async handling across lifetimes.
153            pin_mut!(stream);
154
155            let mut bars: Vec<Bar> = Vec::new();
156
157            while let Some(result) = stream.next().await {
158                match result {
159                    Ok(msg) => {
160                        if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
161                            .and_then(|info| parse_tardis_ws_message(msg, info))
162                        {
163                            bars.push(bar);
164                        }
165                    }
166                    Err(e) => {
167                        tracing::error!("Error in WebSocket stream: {e:?}");
168                        break;
169                    }
170                }
171            }
172
173            Python::with_gil(|py| {
174                let pylist =
175                    PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
176                        .expect("Invalid `ExactSizeIterator`");
177                Ok(pylist.into_py_any_unwrap(py))
178            })
179        })
180    }
181
182    #[pyo3(name = "stream")]
183    fn py_stream<'py>(
184        &self,
185        instruments: Vec<TardisInstrumentMiniInfo>,
186        options: Vec<StreamNormalizedRequestOptions>,
187        callback: PyObject,
188        py: Python<'py>,
189    ) -> PyResult<Bound<'py, PyAny>> {
190        let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
191            HashMap::new();
192        for inst in instruments {
193            let key = inst.as_tardis_instrument_key();
194            instrument_map.insert(key, Arc::new(inst.clone()));
195        }
196
197        let base_url = self.base_url.clone();
198        let replay_signal = self.replay_signal.clone();
199
200        pyo3_async_runtimes::tokio::future_into_py(py, async move {
201            let stream = stream_normalized(&base_url, options, replay_signal)
202                .await
203                .map_err(to_pyruntime_err)?;
204
205            // We use Box::pin to heap-allocate the stream and ensure it implements
206            // Unpin for safe async handling across lifetimes.
207            handle_python_stream(Box::pin(stream), callback, None, Some(instrument_map)).await;
208            Ok(())
209        })
210    }
211}
212
213/// Run the Tardis Machine replay as an async Python future.
214///
215/// # Errors
216///
217/// Returns a `PyErr` if reading the config file or replay execution fails.
218#[pyfunction]
219#[pyo3(name = "run_tardis_machine_replay")]
220#[pyo3(signature = (config_filepath))]
221pub fn py_run_tardis_machine_replay(
222    py: Python<'_>,
223    config_filepath: String,
224) -> PyResult<Bound<'_, PyAny>> {
225    tracing_subscriber::fmt()
226        .with_max_level(tracing::Level::DEBUG)
227        .init();
228
229    pyo3_async_runtimes::tokio::future_into_py(py, async move {
230        let config_filepath = Path::new(&config_filepath);
231        run_tardis_machine_replay_from_config(config_filepath)
232            .await
233            .map_err(to_pyruntime_err)?;
234        Ok(())
235    })
236}
237
238async fn handle_python_stream<S>(
239    stream: S,
240    callback: PyObject,
241    instrument: Option<Arc<TardisInstrumentMiniInfo>>,
242    instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
243) where
244    S: Stream<Item = Result<WsMessage, Error>> + Unpin,
245{
246    pin_mut!(stream);
247
248    // Cache for funding rates to avoid duplicate emissions
249    let mut funding_rate_cache: AHashMap<InstrumentId, FundingRateUpdate> = AHashMap::new();
250
251    while let Some(result) = stream.next().await {
252        match result {
253            Ok(msg) => {
254                let info = instrument.clone().or_else(|| {
255                    instrument_map
256                        .as_ref()
257                        .and_then(|map| determine_instrument_info(&msg, map))
258                });
259
260                if let Some(info) = info.clone() {
261                    if let Some(data) = parse_tardis_ws_message(msg.clone(), info.clone()) {
262                        Python::with_gil(|py| {
263                            let py_obj = data_to_pycapsule(py, data);
264                            call_python(py, &callback, py_obj);
265                        });
266                    } else if let Some(funding_rate) =
267                        parse_tardis_ws_message_funding_rate(msg, info)
268                    {
269                        // Check if we should emit this funding rate
270                        let should_emit = if let Some(cached_rate) =
271                            funding_rate_cache.get(&funding_rate.instrument_id)
272                        {
273                            // Only emit if changed (uses custom PartialEq comparing rate and next_funding_ns)
274                            if cached_rate == &funding_rate {
275                                false // Skip unchanged rate
276                            } else {
277                                funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
278                                true
279                            }
280                        } else {
281                            // First time seeing this instrument, cache and emit
282                            funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
283                            true
284                        };
285
286                        if should_emit {
287                            Python::with_gil(|py| {
288                                let py_obj = funding_rate.into_py_any_unwrap(py);
289                                call_python(py, &callback, py_obj);
290                            });
291                        }
292                    }
293                }
294            }
295            Err(e) => {
296                tracing::error!("Error in WebSocket stream: {e:?}");
297                break;
298            }
299        }
300    }
301}
302
303fn call_python(py: Python, callback: &PyObject, py_obj: PyObject) {
304    if let Err(e) = callback.call1(py, (py_obj,)) {
305        tracing::error!("Error calling Python: {e}");
306    }
307}