nautilus_tardis/python/machine.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
use std::{collections::HashMap, path::Path, sync::Arc};

use futures_util::{Stream, StreamExt, pin_mut};
use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err, to_pyvalue_err};
use nautilus_model::{
    data::{Bar, Data},
    python::data::data_to_pycapsule,
};
use pyo3::{prelude::*, types::PyList};

use crate::{
    machine::{
        Error,
        client::{TardisMachineClient, determine_instrument_info},
        message::WsMessage,
        parse::parse_tardis_ws_message,
        replay_normalized, stream_normalized,
        types::{
            InstrumentMiniInfo, ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions,
            TardisInstrumentKey,
        },
    },
    replay::run_tardis_machine_replay_from_config,
};
40
41#[pymethods]
42impl ReplayNormalizedRequestOptions {
43    #[staticmethod]
44    #[pyo3(name = "from_json")]
45    fn py_from_json(data: Vec<u8>) -> Self {
46        serde_json::from_slice(&data).expect("Failed to parse JSON")
47    }
48
49    #[pyo3(name = "from_json_array")]
50    #[staticmethod]
51    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
52        serde_json::from_slice(&data).expect("Failed to parse JSON array")
53    }
54}
55
56#[pymethods]
57impl StreamNormalizedRequestOptions {
58    #[staticmethod]
59    #[pyo3(name = "from_json")]
60    fn py_from_json(data: Vec<u8>) -> Self {
61        serde_json::from_slice(&data).expect("Failed to parse JSON")
62    }
63
64    #[pyo3(name = "from_json_array")]
65    #[staticmethod]
66    fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
67        serde_json::from_slice(&data).expect("Failed to parse JSON array")
68    }
69}
70
71#[pymethods]
72impl TardisMachineClient {
73    #[new]
74    #[pyo3(signature = (base_url=None, normalize_symbols=true))]
75    fn py_new(base_url: Option<&str>, normalize_symbols: bool) -> PyResult<Self> {
76        Self::new(base_url, normalize_symbols).map_err(to_pyruntime_err)
77    }
78
79    #[pyo3(name = "is_closed")]
80    #[must_use]
81    pub fn py_is_closed(&self) -> bool {
82        self.is_closed()
83    }
84
85    #[pyo3(name = "close")]
86    fn py_close(&mut self) {
87        self.close();
88    }
89
90    #[pyo3(name = "replay")]
91    fn py_replay<'py>(
92        &self,
93        instruments: Vec<InstrumentMiniInfo>,
94        options: Vec<ReplayNormalizedRequestOptions>,
95        callback: PyObject,
96        py: Python<'py>,
97    ) -> PyResult<Bound<'py, PyAny>> {
98        let map = if instruments.is_empty() {
99            self.instruments.clone()
100        } else {
101            let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
102                HashMap::new();
103            for inst in instruments {
104                let key = inst.as_tardis_instrument_key();
105                instrument_map.insert(key, Arc::new(inst.clone()));
106            }
107            instrument_map
108        };
109
110        let base_url = self.base_url.clone();
111        let replay_signal = self.replay_signal.clone();
112
113        pyo3_async_runtimes::tokio::future_into_py(py, async move {
114            let stream = replay_normalized(&base_url, options, replay_signal)
115                .await
116                .map_err(to_pyruntime_err)?;
117
118            // We use Box::pin to heap-allocate the stream and ensure it implements
119            // Unpin for safe async handling across lifetimes.
120            handle_python_stream(Box::pin(stream), callback, None, Some(map)).await;
121            Ok(())
122        })
123    }
124
125    #[pyo3(name = "replay_bars")]
126    fn py_replay_bars<'py>(
127        &self,
128        instruments: Vec<InstrumentMiniInfo>,
129        options: Vec<ReplayNormalizedRequestOptions>,
130        py: Python<'py>,
131    ) -> PyResult<Bound<'py, PyAny>> {
132        let map = if instruments.is_empty() {
133            self.instruments.clone()
134        } else {
135            instruments
136                .into_iter()
137                .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
138                .collect()
139        };
140
141        let base_url = self.base_url.clone();
142        let replay_signal = self.replay_signal.clone();
143
144        pyo3_async_runtimes::tokio::future_into_py(py, async move {
145            let stream = replay_normalized(&base_url, options, replay_signal)
146                .await
147                .map_err(to_pyruntime_err)?;
148
149            // We use Box::pin to heap-allocate the stream and ensure it implements
150            // Unpin for safe async handling across lifetimes.
151            pin_mut!(stream);
152
153            let mut bars: Vec<Bar> = Vec::new();
154
155            while let Some(result) = stream.next().await {
156                match result {
157                    Ok(msg) => {
158                        if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
159                            .and_then(|info| parse_tardis_ws_message(msg, info))
160                        {
161                            bars.push(bar);
162                        }
163                    }
164                    Err(e) => {
165                        tracing::error!("Error in WebSocket stream: {e:?}");
166                        break;
167                    }
168                }
169            }
170
171            Python::with_gil(|py| {
172                let pylist =
173                    PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
174                        .expect("Invalid `ExactSizeIterator`");
175                Ok(pylist.into_py_any_unwrap(py))
176            })
177        })
178    }
179
180    #[pyo3(name = "stream")]
181    fn py_stream<'py>(
182        &self,
183        instruments: Vec<InstrumentMiniInfo>,
184        options: Vec<StreamNormalizedRequestOptions>,
185        callback: PyObject,
186        py: Python<'py>,
187    ) -> PyResult<Bound<'py, PyAny>> {
188        let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
189            HashMap::new();
190        for inst in instruments {
191            let key = inst.as_tardis_instrument_key();
192            instrument_map.insert(key, Arc::new(inst.clone()));
193        }
194
195        let base_url = self.base_url.clone();
196        let replay_signal = self.replay_signal.clone();
197
198        pyo3_async_runtimes::tokio::future_into_py(py, async move {
199            let stream = stream_normalized(&base_url, options, replay_signal)
200                .await
201                .expect("Failed to connect to WebSocket");
202
203            // We use Box::pin to heap-allocate the stream and ensure it implements
204            // Unpin for safe async handling across lifetimes.
205            handle_python_stream(Box::pin(stream), callback, None, Some(instrument_map)).await;
206            Ok(())
207        })
208    }
209}
210
211#[pyfunction]
212#[pyo3(name = "run_tardis_machine_replay")]
213#[pyo3(signature = (config_filepath))]
214pub fn py_run_tardis_machine_replay(
215    py: Python<'_>,
216    config_filepath: String,
217) -> PyResult<Bound<'_, PyAny>> {
218    tracing_subscriber::fmt()
219        .with_max_level(tracing::Level::DEBUG)
220        .init();
221
222    pyo3_async_runtimes::tokio::future_into_py(py, async move {
223        let config_filepath = Path::new(&config_filepath);
224        run_tardis_machine_replay_from_config(config_filepath)
225            .await
226            .map_err(to_pyruntime_err)?;
227        Ok(())
228    })
229}
230
231async fn handle_python_stream<S>(
232    stream: S,
233    callback: PyObject,
234    instrument: Option<Arc<InstrumentMiniInfo>>,
235    instrument_map: Option<HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>>,
236) where
237    S: Stream<Item = Result<WsMessage, Error>> + Unpin,
238{
239    pin_mut!(stream);
240
241    while let Some(result) = stream.next().await {
242        match result {
243            Ok(msg) => {
244                let info = instrument.clone().or_else(|| {
245                    instrument_map
246                        .as_ref()
247                        .and_then(|map| determine_instrument_info(&msg, map))
248                });
249
250                if let Some(info) = info {
251                    if let Some(data) = parse_tardis_ws_message(msg, info) {
252                        Python::with_gil(|py| {
253                            let py_obj = data_to_pycapsule(py, data);
254                            call_python(py, &callback, py_obj);
255                        });
256                    }
257                }
258            }
259            Err(e) => {
260                tracing::error!("Error in WebSocket stream: {e:?}");
261                break;
262            }
263        }
264    }
265}
266
267fn call_python(py: Python, callback: &PyObject, py_obj: PyObject) {
268    if let Err(e) = callback.call1(py, (py_obj,)) {
269        tracing::error!("Error calling Python: {e}");
270    }
271}