1use std::{collections::HashMap, path::Path, sync::Arc};
17
18use futures_util::{Stream, StreamExt, pin_mut};
19use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err};
20use nautilus_model::{
21 data::{Bar, Data},
22 python::data::data_to_pycapsule,
23};
24use pyo3::{prelude::*, types::PyList};
25
26use crate::{
27 machine::{
28 Error,
29 client::{TardisMachineClient, determine_instrument_info},
30 message::WsMessage,
31 parse::parse_tardis_ws_message,
32 replay_normalized, stream_normalized,
33 types::{
34 InstrumentMiniInfo, ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions,
35 TardisInstrumentKey,
36 },
37 },
38 replay::run_tardis_machine_replay_from_config,
39};
40
41#[pymethods]
42impl ReplayNormalizedRequestOptions {
43 #[staticmethod]
44 #[pyo3(name = "from_json")]
45 fn py_from_json(data: Vec<u8>) -> Self {
46 serde_json::from_slice(&data).expect("Failed to parse JSON")
47 }
48
49 #[pyo3(name = "from_json_array")]
50 #[staticmethod]
51 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
52 serde_json::from_slice(&data).expect("Failed to parse JSON array")
53 }
54}
55
56#[pymethods]
57impl StreamNormalizedRequestOptions {
58 #[staticmethod]
59 #[pyo3(name = "from_json")]
60 fn py_from_json(data: Vec<u8>) -> Self {
61 serde_json::from_slice(&data).expect("Failed to parse JSON")
62 }
63
64 #[pyo3(name = "from_json_array")]
65 #[staticmethod]
66 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
67 serde_json::from_slice(&data).expect("Failed to parse JSON array")
68 }
69}
70
71#[pymethods]
72impl TardisMachineClient {
73 #[new]
74 #[pyo3(signature = (base_url=None, normalize_symbols=true))]
75 fn py_new(base_url: Option<&str>, normalize_symbols: bool) -> PyResult<Self> {
76 Self::new(base_url, normalize_symbols).map_err(to_pyruntime_err)
77 }
78
79 #[pyo3(name = "is_closed")]
80 #[must_use]
81 pub fn py_is_closed(&self) -> bool {
82 self.is_closed()
83 }
84
85 #[pyo3(name = "close")]
86 fn py_close(&mut self) {
87 self.close();
88 }
89
90 #[pyo3(name = "replay")]
91 fn py_replay<'py>(
92 &self,
93 instruments: Vec<InstrumentMiniInfo>,
94 options: Vec<ReplayNormalizedRequestOptions>,
95 callback: PyObject,
96 py: Python<'py>,
97 ) -> PyResult<Bound<'py, PyAny>> {
98 let map = if instruments.is_empty() {
99 self.instruments.clone()
100 } else {
101 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
102 HashMap::new();
103 for inst in instruments {
104 let key = inst.as_tardis_instrument_key();
105 instrument_map.insert(key, Arc::new(inst.clone()));
106 }
107 instrument_map
108 };
109
110 let base_url = self.base_url.clone();
111 let replay_signal = self.replay_signal.clone();
112
113 pyo3_async_runtimes::tokio::future_into_py(py, async move {
114 let stream = replay_normalized(&base_url, options, replay_signal)
115 .await
116 .map_err(to_pyruntime_err)?;
117
118 handle_python_stream(Box::pin(stream), callback, None, Some(map)).await;
121 Ok(())
122 })
123 }
124
125 #[pyo3(name = "replay_bars")]
126 fn py_replay_bars<'py>(
127 &self,
128 instruments: Vec<InstrumentMiniInfo>,
129 options: Vec<ReplayNormalizedRequestOptions>,
130 py: Python<'py>,
131 ) -> PyResult<Bound<'py, PyAny>> {
132 let map = if instruments.is_empty() {
133 self.instruments.clone()
134 } else {
135 instruments
136 .into_iter()
137 .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
138 .collect()
139 };
140
141 let base_url = self.base_url.clone();
142 let replay_signal = self.replay_signal.clone();
143
144 pyo3_async_runtimes::tokio::future_into_py(py, async move {
145 let stream = replay_normalized(&base_url, options, replay_signal)
146 .await
147 .map_err(to_pyruntime_err)?;
148
149 pin_mut!(stream);
152
153 let mut bars: Vec<Bar> = Vec::new();
154
155 while let Some(result) = stream.next().await {
156 match result {
157 Ok(msg) => {
158 if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
159 .and_then(|info| parse_tardis_ws_message(msg, info))
160 {
161 bars.push(bar);
162 }
163 }
164 Err(e) => {
165 tracing::error!("Error in WebSocket stream: {e:?}");
166 break;
167 }
168 }
169 }
170
171 Python::with_gil(|py| {
172 let pylist =
173 PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
174 .expect("Invalid `ExactSizeIterator`");
175 Ok(pylist.into_py_any_unwrap(py))
176 })
177 })
178 }
179
180 #[pyo3(name = "stream")]
181 fn py_stream<'py>(
182 &self,
183 instruments: Vec<InstrumentMiniInfo>,
184 options: Vec<StreamNormalizedRequestOptions>,
185 callback: PyObject,
186 py: Python<'py>,
187 ) -> PyResult<Bound<'py, PyAny>> {
188 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>> =
189 HashMap::new();
190 for inst in instruments {
191 let key = inst.as_tardis_instrument_key();
192 instrument_map.insert(key, Arc::new(inst.clone()));
193 }
194
195 let base_url = self.base_url.clone();
196 let replay_signal = self.replay_signal.clone();
197
198 pyo3_async_runtimes::tokio::future_into_py(py, async move {
199 let stream = stream_normalized(&base_url, options, replay_signal)
200 .await
201 .expect("Failed to connect to WebSocket");
202
203 handle_python_stream(Box::pin(stream), callback, None, Some(instrument_map)).await;
206 Ok(())
207 })
208 }
209}
210
211#[pyfunction]
212#[pyo3(name = "run_tardis_machine_replay")]
213#[pyo3(signature = (config_filepath))]
214pub fn py_run_tardis_machine_replay(
215 py: Python<'_>,
216 config_filepath: String,
217) -> PyResult<Bound<'_, PyAny>> {
218 tracing_subscriber::fmt()
219 .with_max_level(tracing::Level::DEBUG)
220 .init();
221
222 pyo3_async_runtimes::tokio::future_into_py(py, async move {
223 let config_filepath = Path::new(&config_filepath);
224 run_tardis_machine_replay_from_config(config_filepath)
225 .await
226 .map_err(to_pyruntime_err)?;
227 Ok(())
228 })
229}
230
231async fn handle_python_stream<S>(
232 stream: S,
233 callback: PyObject,
234 instrument: Option<Arc<InstrumentMiniInfo>>,
235 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>>,
236) where
237 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
238{
239 pin_mut!(stream);
240
241 while let Some(result) = stream.next().await {
242 match result {
243 Ok(msg) => {
244 let info = instrument.clone().or_else(|| {
245 instrument_map
246 .as_ref()
247 .and_then(|map| determine_instrument_info(&msg, map))
248 });
249
250 if let Some(info) = info {
251 if let Some(data) = parse_tardis_ws_message(msg, info) {
252 Python::with_gil(|py| {
253 let py_obj = data_to_pycapsule(py, data);
254 call_python(py, &callback, py_obj);
255 });
256 }
257 }
258 }
259 Err(e) => {
260 tracing::error!("Error in WebSocket stream: {e:?}");
261 break;
262 }
263 }
264 }
265}
266
267fn call_python(py: Python, callback: &PyObject, py_obj: PyObject) {
268 if let Err(e) = callback.call1(py, (py_obj,)) {
269 tracing::error!("Error calling Python: {e}");
270 }
271}