nautilus_tardis/python/
machine.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, path::Path, sync::Arc};
17
18use ahash::AHashMap;
19use futures_util::{Stream, StreamExt, pin_mut};
20use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err};
21use nautilus_model::{
22    data::{Bar, Data, funding::FundingRateUpdate},
23    identifiers::InstrumentId,
24    python::data::data_to_pycapsule,
25};
26use pyo3::{prelude::*, types::PyList};
27
28use crate::{
29    config::BookSnapshotOutput,
30    machine::{
31        Error,
32        client::{TardisMachineClient, determine_instrument_info},
33        message::WsMessage,
34        parse::{parse_tardis_ws_message, parse_tardis_ws_message_funding_rate},
35        replay_normalized, stream_normalized,
36        types::{
37            ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
38            TardisInstrumentMiniInfo,
39        },
40    },
41    replay::run_tardis_machine_replay_from_config,
42};
43
44#[pymethods]
45impl ReplayNormalizedRequestOptions {
46    #[staticmethod]
47    #[pyo3(name = "from_json")]
48    fn py_from_json(data: Vec<u8>) -> Self {
49        serde_json::from_slice(&data).expect("Failed to parse JSON")
50    }
51
52    #[pyo3(name = "from_json_array")]
53    #[staticmethod]
54    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
55        serde_json::from_slice(&data).expect("Failed to parse JSON array")
56    }
57}
58
59#[pymethods]
60impl StreamNormalizedRequestOptions {
61    #[staticmethod]
62    #[pyo3(name = "from_json")]
63    fn py_from_json(data: Vec<u8>) -> Self {
64        serde_json::from_slice(&data).expect("Failed to parse JSON")
65    }
66
67    #[pyo3(name = "from_json_array")]
68    #[staticmethod]
69    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
70        serde_json::from_slice(&data).expect("Failed to parse JSON array")
71    }
72}
73
74#[pymethods]
75impl TardisMachineClient {
76    #[new]
77    #[pyo3(signature = (base_url=None, normalize_symbols=true, book_snapshot_output="deltas"))]
78    fn py_new(
79        base_url: Option<&str>,
80        normalize_symbols: bool,
81        book_snapshot_output: &str,
82    ) -> PyResult<Self> {
83        let output = match book_snapshot_output {
84            "depth10" => BookSnapshotOutput::Depth10,
85            "deltas" => BookSnapshotOutput::Deltas,
86            _ => {
87                return Err(to_pyruntime_err(anyhow::anyhow!(
88                    "Invalid book_snapshot_output: '{book_snapshot_output}'. Expected 'depth10' or 'deltas'"
89                )));
90            }
91        };
92        Self::new(base_url, normalize_symbols, output).map_err(to_pyruntime_err)
93    }
94
95    #[pyo3(name = "is_closed")]
96    #[must_use]
97    pub fn py_is_closed(&self) -> bool {
98        self.is_closed()
99    }
100
101    #[pyo3(name = "close")]
102    fn py_close(&mut self) {
103        self.close();
104    }
105
106    #[pyo3(name = "replay")]
107    fn py_replay<'py>(
108        &self,
109        instruments: Vec<TardisInstrumentMiniInfo>,
110        options: Vec<ReplayNormalizedRequestOptions>,
111        callback: Py<PyAny>,
112        py: Python<'py>,
113    ) -> PyResult<Bound<'py, PyAny>> {
114        let map = if instruments.is_empty() {
115            self.instruments.clone()
116        } else {
117            let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
118                HashMap::new();
119            for inst in instruments {
120                let key = inst.as_tardis_instrument_key();
121                instrument_map.insert(key, Arc::new(inst.clone()));
122            }
123            instrument_map
124        };
125
126        let base_url = self.base_url.clone();
127        let replay_signal = self.replay_signal.clone();
128        let book_snapshot_output = self.book_snapshot_output.clone();
129
130        pyo3_async_runtimes::tokio::future_into_py(py, async move {
131            let stream = replay_normalized(&base_url, options, replay_signal)
132                .await
133                .map_err(to_pyruntime_err)?;
134
135            // We use Box::pin to heap-allocate the stream and ensure it implements
136            // Unpin for safe async handling across lifetimes.
137            handle_python_stream(
138                Box::pin(stream),
139                callback,
140                None,
141                Some(map),
142                book_snapshot_output,
143            )
144            .await;
145            Ok(())
146        })
147    }
148
149    #[pyo3(name = "replay_bars")]
150    fn py_replay_bars<'py>(
151        &self,
152        instruments: Vec<TardisInstrumentMiniInfo>,
153        options: Vec<ReplayNormalizedRequestOptions>,
154        py: Python<'py>,
155    ) -> PyResult<Bound<'py, PyAny>> {
156        let map = if instruments.is_empty() {
157            self.instruments.clone()
158        } else {
159            instruments
160                .into_iter()
161                .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
162                .collect()
163        };
164
165        let base_url = self.base_url.clone();
166        let replay_signal = self.replay_signal.clone();
167        let book_snapshot_output = self.book_snapshot_output.clone();
168
169        pyo3_async_runtimes::tokio::future_into_py(py, async move {
170            let stream = replay_normalized(&base_url, options, replay_signal)
171                .await
172                .map_err(to_pyruntime_err)?;
173
174            // We use Box::pin to heap-allocate the stream and ensure it implements
175            // Unpin for safe async handling across lifetimes.
176            pin_mut!(stream);
177
178            let mut bars: Vec<Bar> = Vec::new();
179
180            while let Some(result) = stream.next().await {
181                match result {
182                    Ok(msg) => {
183                        if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
184                            .and_then(|info| {
185                                parse_tardis_ws_message(msg, info, &book_snapshot_output)
186                            })
187                        {
188                            bars.push(bar);
189                        }
190                    }
191                    Err(e) => {
192                        tracing::error!("Error in WebSocket stream: {e:?}");
193                        break;
194                    }
195                }
196            }
197
198            Python::attach(|py| {
199                let pylist =
200                    PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
201                        .expect("Invalid `ExactSizeIterator`");
202                Ok(pylist.into_py_any_unwrap(py))
203            })
204        })
205    }
206
207    #[pyo3(name = "stream")]
208    fn py_stream<'py>(
209        &self,
210        instruments: Vec<TardisInstrumentMiniInfo>,
211        options: Vec<StreamNormalizedRequestOptions>,
212        callback: Py<PyAny>,
213        py: Python<'py>,
214    ) -> PyResult<Bound<'py, PyAny>> {
215        let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
216            HashMap::new();
217        for inst in instruments {
218            let key = inst.as_tardis_instrument_key();
219            instrument_map.insert(key, Arc::new(inst.clone()));
220        }
221
222        let base_url = self.base_url.clone();
223        let replay_signal = self.replay_signal.clone();
224        let book_snapshot_output = self.book_snapshot_output.clone();
225
226        pyo3_async_runtimes::tokio::future_into_py(py, async move {
227            let stream = stream_normalized(&base_url, options, replay_signal)
228                .await
229                .map_err(to_pyruntime_err)?;
230
231            // We use Box::pin to heap-allocate the stream and ensure it implements
232            // Unpin for safe async handling across lifetimes.
233            handle_python_stream(
234                Box::pin(stream),
235                callback,
236                None,
237                Some(instrument_map),
238                book_snapshot_output,
239            )
240            .await;
241            Ok(())
242        })
243    }
244}
245
246/// Run the Tardis Machine replay as an async Python future.
247///
248/// # Errors
249///
250/// Returns a `PyErr` if reading the config file or replay execution fails.
251#[pyfunction]
252#[pyo3(name = "run_tardis_machine_replay")]
253#[pyo3(signature = (config_filepath))]
254pub fn py_run_tardis_machine_replay(
255    py: Python<'_>,
256    config_filepath: String,
257) -> PyResult<Bound<'_, PyAny>> {
258    tracing_subscriber::fmt()
259        .with_max_level(tracing::Level::DEBUG)
260        .init();
261
262    pyo3_async_runtimes::tokio::future_into_py(py, async move {
263        let config_filepath = Path::new(&config_filepath);
264        run_tardis_machine_replay_from_config(config_filepath)
265            .await
266            .map_err(to_pyruntime_err)?;
267        Ok(())
268    })
269}
270
271async fn handle_python_stream<S>(
272    stream: S,
273    callback: Py<PyAny>,
274    instrument: Option<Arc<TardisInstrumentMiniInfo>>,
275    instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
276    book_snapshot_output: BookSnapshotOutput,
277) where
278    S: Stream<Item = Result<WsMessage, Error>> + Unpin,
279{
280    pin_mut!(stream);
281
282    // Cache for funding rates to avoid duplicate emissions
283    let mut funding_rate_cache: AHashMap<InstrumentId, FundingRateUpdate> = AHashMap::new();
284
285    while let Some(result) = stream.next().await {
286        match result {
287            Ok(msg) => {
288                let info = instrument.clone().or_else(|| {
289                    instrument_map
290                        .as_ref()
291                        .and_then(|map| determine_instrument_info(&msg, map))
292                });
293
294                if let Some(info) = info.clone() {
295                    if let Some(data) =
296                        parse_tardis_ws_message(msg.clone(), info.clone(), &book_snapshot_output)
297                    {
298                        Python::attach(|py| {
299                            let py_obj = data_to_pycapsule(py, data);
300                            call_python(py, &callback, py_obj);
301                        });
302                    } else if let Some(funding_rate) =
303                        parse_tardis_ws_message_funding_rate(msg, info)
304                    {
305                        // Check if we should emit this funding rate
306                        let should_emit = if let Some(cached_rate) =
307                            funding_rate_cache.get(&funding_rate.instrument_id)
308                        {
309                            // Only emit if changed (uses custom PartialEq comparing rate and next_funding_ns)
310                            if cached_rate == &funding_rate {
311                                false // Skip unchanged rate
312                            } else {
313                                funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
314                                true
315                            }
316                        } else {
317                            // First time seeing this instrument, cache and emit
318                            funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
319                            true
320                        };
321
322                        if should_emit {
323                            Python::attach(|py| {
324                                let py_obj = funding_rate.into_py_any_unwrap(py);
325                                call_python(py, &callback, py_obj);
326                            });
327                        }
328                    }
329                }
330            }
331            Err(e) => {
332                tracing::error!("Error in WebSocket stream: {e:?}");
333                break;
334            }
335        }
336    }
337}
338
339fn call_python(py: Python, callback: &Py<PyAny>, py_obj: Py<PyAny>) {
340    if let Err(e) = callback.call1(py, (py_obj,)) {
341        tracing::error!("Error calling Python: {e}");
342    }
343}