1use std::{collections::HashMap, path::Path, sync::Arc};
17
18use ahash::AHashMap;
19use futures_util::{Stream, StreamExt, pin_mut};
20use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err};
21use nautilus_model::{
22 data::{Bar, Data, funding::FundingRateUpdate},
23 identifiers::InstrumentId,
24 python::data::data_to_pycapsule,
25};
26use pyo3::{prelude::*, types::PyList};
27
28use crate::{
29 config::BookSnapshotOutput,
30 machine::{
31 Error,
32 client::{TardisMachineClient, determine_instrument_info},
33 message::WsMessage,
34 parse::{parse_tardis_ws_message, parse_tardis_ws_message_funding_rate},
35 replay_normalized, stream_normalized,
36 types::{
37 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
38 TardisInstrumentMiniInfo,
39 },
40 },
41 replay::run_tardis_machine_replay_from_config,
42};
43
44#[pymethods]
45impl ReplayNormalizedRequestOptions {
46 #[staticmethod]
47 #[pyo3(name = "from_json")]
48 fn py_from_json(data: Vec<u8>) -> Self {
49 serde_json::from_slice(&data).expect("Failed to parse JSON")
50 }
51
52 #[pyo3(name = "from_json_array")]
53 #[staticmethod]
54 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
55 serde_json::from_slice(&data).expect("Failed to parse JSON array")
56 }
57}
58
59#[pymethods]
60impl StreamNormalizedRequestOptions {
61 #[staticmethod]
62 #[pyo3(name = "from_json")]
63 fn py_from_json(data: Vec<u8>) -> Self {
64 serde_json::from_slice(&data).expect("Failed to parse JSON")
65 }
66
67 #[pyo3(name = "from_json_array")]
68 #[staticmethod]
69 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
70 serde_json::from_slice(&data).expect("Failed to parse JSON array")
71 }
72}
73
74#[pymethods]
75impl TardisMachineClient {
76 #[new]
77 #[pyo3(signature = (base_url=None, normalize_symbols=true, book_snapshot_output="deltas"))]
78 fn py_new(
79 base_url: Option<&str>,
80 normalize_symbols: bool,
81 book_snapshot_output: &str,
82 ) -> PyResult<Self> {
83 let output = match book_snapshot_output {
84 "depth10" => BookSnapshotOutput::Depth10,
85 "deltas" => BookSnapshotOutput::Deltas,
86 _ => {
87 return Err(to_pyruntime_err(anyhow::anyhow!(
88 "Invalid book_snapshot_output: '{book_snapshot_output}'. Expected 'depth10' or 'deltas'"
89 )));
90 }
91 };
92 Self::new(base_url, normalize_symbols, output).map_err(to_pyruntime_err)
93 }
94
95 #[pyo3(name = "is_closed")]
96 #[must_use]
97 pub fn py_is_closed(&self) -> bool {
98 self.is_closed()
99 }
100
101 #[pyo3(name = "close")]
102 fn py_close(&mut self) {
103 self.close();
104 }
105
106 #[pyo3(name = "replay")]
107 fn py_replay<'py>(
108 &self,
109 instruments: Vec<TardisInstrumentMiniInfo>,
110 options: Vec<ReplayNormalizedRequestOptions>,
111 callback: Py<PyAny>,
112 py: Python<'py>,
113 ) -> PyResult<Bound<'py, PyAny>> {
114 let map = if instruments.is_empty() {
115 self.instruments.clone()
116 } else {
117 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
118 HashMap::new();
119 for inst in instruments {
120 let key = inst.as_tardis_instrument_key();
121 instrument_map.insert(key, Arc::new(inst.clone()));
122 }
123 instrument_map
124 };
125
126 let base_url = self.base_url.clone();
127 let replay_signal = self.replay_signal.clone();
128 let book_snapshot_output = self.book_snapshot_output.clone();
129
130 pyo3_async_runtimes::tokio::future_into_py(py, async move {
131 let stream = replay_normalized(&base_url, options, replay_signal)
132 .await
133 .map_err(to_pyruntime_err)?;
134
135 handle_python_stream(
138 Box::pin(stream),
139 callback,
140 None,
141 Some(map),
142 book_snapshot_output,
143 )
144 .await;
145 Ok(())
146 })
147 }
148
149 #[pyo3(name = "replay_bars")]
150 fn py_replay_bars<'py>(
151 &self,
152 instruments: Vec<TardisInstrumentMiniInfo>,
153 options: Vec<ReplayNormalizedRequestOptions>,
154 py: Python<'py>,
155 ) -> PyResult<Bound<'py, PyAny>> {
156 let map = if instruments.is_empty() {
157 self.instruments.clone()
158 } else {
159 instruments
160 .into_iter()
161 .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
162 .collect()
163 };
164
165 let base_url = self.base_url.clone();
166 let replay_signal = self.replay_signal.clone();
167 let book_snapshot_output = self.book_snapshot_output.clone();
168
169 pyo3_async_runtimes::tokio::future_into_py(py, async move {
170 let stream = replay_normalized(&base_url, options, replay_signal)
171 .await
172 .map_err(to_pyruntime_err)?;
173
174 pin_mut!(stream);
177
178 let mut bars: Vec<Bar> = Vec::new();
179
180 while let Some(result) = stream.next().await {
181 match result {
182 Ok(msg) => {
183 if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
184 .and_then(|info| {
185 parse_tardis_ws_message(msg, info, &book_snapshot_output)
186 })
187 {
188 bars.push(bar);
189 }
190 }
191 Err(e) => {
192 tracing::error!("Error in WebSocket stream: {e:?}");
193 break;
194 }
195 }
196 }
197
198 Python::attach(|py| {
199 let pylist =
200 PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
201 .expect("Invalid `ExactSizeIterator`");
202 Ok(pylist.into_py_any_unwrap(py))
203 })
204 })
205 }
206
207 #[pyo3(name = "stream")]
208 fn py_stream<'py>(
209 &self,
210 instruments: Vec<TardisInstrumentMiniInfo>,
211 options: Vec<StreamNormalizedRequestOptions>,
212 callback: Py<PyAny>,
213 py: Python<'py>,
214 ) -> PyResult<Bound<'py, PyAny>> {
215 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
216 HashMap::new();
217 for inst in instruments {
218 let key = inst.as_tardis_instrument_key();
219 instrument_map.insert(key, Arc::new(inst.clone()));
220 }
221
222 let base_url = self.base_url.clone();
223 let replay_signal = self.replay_signal.clone();
224 let book_snapshot_output = self.book_snapshot_output.clone();
225
226 pyo3_async_runtimes::tokio::future_into_py(py, async move {
227 let stream = stream_normalized(&base_url, options, replay_signal)
228 .await
229 .map_err(to_pyruntime_err)?;
230
231 handle_python_stream(
234 Box::pin(stream),
235 callback,
236 None,
237 Some(instrument_map),
238 book_snapshot_output,
239 )
240 .await;
241 Ok(())
242 })
243 }
244}
245
246#[pyfunction]
252#[pyo3(name = "run_tardis_machine_replay")]
253#[pyo3(signature = (config_filepath))]
254pub fn py_run_tardis_machine_replay(
255 py: Python<'_>,
256 config_filepath: String,
257) -> PyResult<Bound<'_, PyAny>> {
258 tracing_subscriber::fmt()
259 .with_max_level(tracing::Level::DEBUG)
260 .init();
261
262 pyo3_async_runtimes::tokio::future_into_py(py, async move {
263 let config_filepath = Path::new(&config_filepath);
264 run_tardis_machine_replay_from_config(config_filepath)
265 .await
266 .map_err(to_pyruntime_err)?;
267 Ok(())
268 })
269}
270
271async fn handle_python_stream<S>(
272 stream: S,
273 callback: Py<PyAny>,
274 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
275 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
276 book_snapshot_output: BookSnapshotOutput,
277) where
278 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
279{
280 pin_mut!(stream);
281
282 let mut funding_rate_cache: AHashMap<InstrumentId, FundingRateUpdate> = AHashMap::new();
284
285 while let Some(result) = stream.next().await {
286 match result {
287 Ok(msg) => {
288 let info = instrument.clone().or_else(|| {
289 instrument_map
290 .as_ref()
291 .and_then(|map| determine_instrument_info(&msg, map))
292 });
293
294 if let Some(info) = info.clone() {
295 if let Some(data) =
296 parse_tardis_ws_message(msg.clone(), info.clone(), &book_snapshot_output)
297 {
298 Python::attach(|py| {
299 let py_obj = data_to_pycapsule(py, data);
300 call_python(py, &callback, py_obj);
301 });
302 } else if let Some(funding_rate) =
303 parse_tardis_ws_message_funding_rate(msg, info)
304 {
305 let should_emit = if let Some(cached_rate) =
307 funding_rate_cache.get(&funding_rate.instrument_id)
308 {
309 if cached_rate == &funding_rate {
311 false } else {
313 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
314 true
315 }
316 } else {
317 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
319 true
320 };
321
322 if should_emit {
323 Python::attach(|py| {
324 let py_obj = funding_rate.into_py_any_unwrap(py);
325 call_python(py, &callback, py_obj);
326 });
327 }
328 }
329 }
330 }
331 Err(e) => {
332 tracing::error!("Error in WebSocket stream: {e:?}");
333 break;
334 }
335 }
336 }
337}
338
339fn call_python(py: Python, callback: &Py<PyAny>, py_obj: Py<PyAny>) {
340 if let Err(e) = callback.call1(py, (py_obj,)) {
341 tracing::error!("Error calling Python: {e}");
342 }
343}