nautilus_tardis/python/
machine.rs1use std::{collections::HashMap, path::Path, sync::Arc};
17
18use futures_util::{pin_mut, Stream, StreamExt};
19use nautilus_core::python::to_pyruntime_err;
20use nautilus_model::{
21 data::{Bar, Data},
22 python::data::data_to_pycapsule,
23};
24use pyo3::{prelude::*, types::PyList};
25
26use crate::{
27 machine::{
28 client::{determine_instrument_info, TardisMachineClient},
29 message::WsMessage,
30 parse::parse_tardis_ws_message,
31 replay_normalized, stream_normalized,
32 types::{
33 InstrumentMiniInfo, ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions,
34 TardisInstrumentKey,
35 },
36 Error,
37 },
38 replay::run_tardis_machine_replay_from_config,
39};
40
41#[pymethods]
42impl ReplayNormalizedRequestOptions {
43 #[staticmethod]
44 #[pyo3(name = "from_json")]
45 fn py_from_json(data: Vec<u8>) -> Self {
46 serde_json::from_slice(&data).expect("Failed to parse JSON")
47 }
48
49 #[pyo3(name = "from_json_array")]
50 #[staticmethod]
51 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
52 serde_json::from_slice(&data).expect("Failed to parse JSON array")
53 }
54}
55
56#[pymethods]
57impl StreamNormalizedRequestOptions {
58 #[staticmethod]
59 #[pyo3(name = "from_json")]
60 fn py_from_json(data: Vec<u8>) -> Self {
61 serde_json::from_slice(&data).expect("Failed to parse JSON")
62 }
63
64 #[pyo3(name = "from_json_array")]
65 #[staticmethod]
66 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
67 serde_json::from_slice(&data).expect("Failed to parse JSON array")
68 }
69}
70
71#[pymethods]
72impl TardisMachineClient {
73 #[new]
74 #[pyo3(signature = (base_url=None, normalize_symbols=true))]
75 fn py_new(base_url: Option<&str>, normalize_symbols: bool) -> PyResult<Self> {
76 Self::new(base_url, normalize_symbols).map_err(to_pyruntime_err)
77 }
78
79 #[pyo3(name = "is_closed")]
80 #[must_use]
81 pub fn py_is_closed(&self) -> bool {
82 self.is_closed()
83 }
84
85 #[pyo3(name = "close")]
86 fn py_close(&mut self) {
87 self.close();
88 }
89
90 #[pyo3(name = "replay")]
91 fn py_replay<'py>(
92 &self,
93 instruments: Vec<InstrumentMiniInfo>,
94 options: Vec<ReplayNormalizedRequestOptions>,
95 callback: PyObject,
96 py: Python<'py>,
97 ) -> PyResult<Bound<'py, PyAny>> {
98 let map = if instruments.is_empty() {
99 self.instruments.clone()
100 } else {
101 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
102 HashMap::new();
103 for inst in instruments {
104 let key = inst.as_tardis_instrument_key();
105 instrument_map.insert(key, Arc::new(inst.clone()));
106 }
107 instrument_map
108 };
109
110 let base_url = self.base_url.clone();
111 let replay_signal = self.replay_signal.clone();
112
113 pyo3_async_runtimes::tokio::future_into_py(py, async move {
114 let stream = replay_normalized(&base_url, options, replay_signal)
115 .await
116 .map_err(to_pyruntime_err)?;
117
118 handle_python_stream(Box::pin(stream), callback, None, Some(map)).await;
121 Ok(())
122 })
123 }
124
125 #[pyo3(name = "replay_bars")]
126 fn py_replay_bars<'py>(
127 &self,
128 instruments: Vec<InstrumentMiniInfo>,
129 options: Vec<ReplayNormalizedRequestOptions>,
130 py: Python<'py>,
131 ) -> PyResult<Bound<'py, PyAny>> {
132 let map = if instruments.is_empty() {
133 self.instruments.clone()
134 } else {
135 instruments
136 .into_iter()
137 .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
138 .collect()
139 };
140
141 let base_url = self.base_url.clone();
142 let replay_signal = self.replay_signal.clone();
143
144 pyo3_async_runtimes::tokio::future_into_py(py, async move {
145 let stream = replay_normalized(&base_url, options, replay_signal)
146 .await
147 .map_err(to_pyruntime_err)?;
148
149 pin_mut!(stream);
152
153 let mut bars: Vec<Bar> = Vec::new();
154
155 while let Some(result) = stream.next().await {
156 match result {
157 Ok(msg) => {
158 if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
159 .and_then(|info| parse_tardis_ws_message(msg, info))
160 {
161 bars.push(bar);
162 }
163 }
164 Err(e) => {
165 tracing::error!("Error in WebSocket stream: {e:?}");
166 break;
167 }
168 }
169 }
170
171 Python::with_gil(|py| {
172 let pylist = PyList::new(py, bars.into_iter().map(|bar| bar.into_py(py)))
173 .expect("Invalid `ExactSizeIterator`");
174 Ok(pylist.into_py(py))
175 })
176 })
177 }
178
179 #[pyo3(name = "stream")]
180 fn py_stream<'py>(
181 &self,
182 instruments: Vec<InstrumentMiniInfo>,
183 options: Vec<StreamNormalizedRequestOptions>,
184 callback: PyObject,
185 py: Python<'py>,
186 ) -> PyResult<Bound<'py, PyAny>> {
187 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
188 HashMap::new();
189 for inst in instruments {
190 let key = inst.as_tardis_instrument_key();
191 instrument_map.insert(key, Arc::new(inst.clone()));
192 }
193
194 let base_url = self.base_url.clone();
195 let replay_signal = self.replay_signal.clone();
196
197 pyo3_async_runtimes::tokio::future_into_py(py, async move {
198 let stream = stream_normalized(&base_url, options, replay_signal)
199 .await
200 .expect("Failed to connect to WebSocket");
201
202 handle_python_stream(Box::pin(stream), callback, None, Some(instrument_map)).await;
205 Ok(())
206 })
207 }
208}
209
210#[pyfunction]
211#[pyo3(name = "run_tardis_machine_replay")]
212#[pyo3(signature = (config_filepath))]
213pub fn py_run_tardis_machine_replay(
214 py: Python<'_>,
215 config_filepath: String,
216) -> PyResult<Bound<'_, PyAny>> {
217 tracing_subscriber::fmt()
218 .with_max_level(tracing::Level::DEBUG)
219 .init();
220
221 pyo3_async_runtimes::tokio::future_into_py(py, async move {
222 let config_filepath = Path::new(&config_filepath);
223 run_tardis_machine_replay_from_config(config_filepath)
224 .await
225 .map_err(to_pyruntime_err)?;
226 Ok(())
227 })
228}
229
230async fn handle_python_stream<S>(
231 stream: S,
232 callback: PyObject,
233 instrument: Option<Arc<InstrumentMiniInfo>>,
234 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>>,
235) where
236 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
237{
238 pin_mut!(stream);
239
240 while let Some(result) = stream.next().await {
241 match result {
242 Ok(msg) => {
243 let info = instrument.clone().or_else(|| {
244 instrument_map
245 .as_ref()
246 .and_then(|map| determine_instrument_info(&msg, map))
247 });
248
249 if let Some(info) = info {
250 if let Some(data) = parse_tardis_ws_message(msg, info) {
251 Python::with_gil(|py| {
252 let py_obj = data_to_pycapsule(py, data);
253 call_python(py, &callback, py_obj);
254 });
255 }
256 }
257 }
258 Err(e) => {
259 tracing::error!("Error in WebSocket stream: {e:?}");
260 break;
261 }
262 }
263 }
264}
265
266fn call_python(py: Python, callback: &PyObject, py_obj: PyObject) {
267 if let Err(e) = callback.call1(py, (py_obj,)) {
268 tracing::error!("Error calling Python: {e}");
269 }
270}