1use std::{collections::HashMap, path::Path, sync::Arc};
17
18use ahash::AHashMap;
19use futures_util::{Stream, StreamExt, pin_mut};
20use nautilus_core::python::{IntoPyObjectNautilusExt, to_pyruntime_err};
21use nautilus_model::{
22 data::{Bar, Data, funding::FundingRateUpdate},
23 identifiers::InstrumentId,
24 python::data::data_to_pycapsule,
25};
26use pyo3::{prelude::*, types::PyList};
27
28use crate::{
29 machine::{
30 Error,
31 client::{TardisMachineClient, determine_instrument_info},
32 message::WsMessage,
33 parse::{parse_tardis_ws_message, parse_tardis_ws_message_funding_rate},
34 replay_normalized, stream_normalized,
35 types::{
36 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
37 TardisInstrumentMiniInfo,
38 },
39 },
40 replay::run_tardis_machine_replay_from_config,
41};
42
43#[pymethods]
44impl ReplayNormalizedRequestOptions {
45 #[staticmethod]
46 #[pyo3(name = "from_json")]
47 fn py_from_json(data: Vec<u8>) -> Self {
48 serde_json::from_slice(&data).expect("Failed to parse JSON")
49 }
50
51 #[pyo3(name = "from_json_array")]
52 #[staticmethod]
53 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
54 serde_json::from_slice(&data).expect("Failed to parse JSON array")
55 }
56}
57
58#[pymethods]
59impl StreamNormalizedRequestOptions {
60 #[staticmethod]
61 #[pyo3(name = "from_json")]
62 fn py_from_json(data: Vec<u8>) -> Self {
63 serde_json::from_slice(&data).expect("Failed to parse JSON")
64 }
65
66 #[pyo3(name = "from_json_array")]
67 #[staticmethod]
68 fn py_from_json_array(data: Vec<u8>) -> Vec<Self> {
69 serde_json::from_slice(&data).expect("Failed to parse JSON array")
70 }
71}
72
73#[pymethods]
74impl TardisMachineClient {
75 #[new]
76 #[pyo3(signature = (base_url=None, normalize_symbols=true))]
77 fn py_new(base_url: Option<&str>, normalize_symbols: bool) -> PyResult<Self> {
78 Self::new(base_url, normalize_symbols).map_err(to_pyruntime_err)
79 }
80
81 #[pyo3(name = "is_closed")]
82 #[must_use]
83 pub fn py_is_closed(&self) -> bool {
84 self.is_closed()
85 }
86
87 #[pyo3(name = "close")]
88 fn py_close(&mut self) {
89 self.close();
90 }
91
92 #[pyo3(name = "replay")]
93 fn py_replay<'py>(
94 &self,
95 instruments: Vec<TardisInstrumentMiniInfo>,
96 options: Vec<ReplayNormalizedRequestOptions>,
97 callback: PyObject,
98 py: Python<'py>,
99 ) -> PyResult<Bound<'py, PyAny>> {
100 let map = if instruments.is_empty() {
101 self.instruments.clone()
102 } else {
103 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
104 HashMap::new();
105 for inst in instruments {
106 let key = inst.as_tardis_instrument_key();
107 instrument_map.insert(key, Arc::new(inst.clone()));
108 }
109 instrument_map
110 };
111
112 let base_url = self.base_url.clone();
113 let replay_signal = self.replay_signal.clone();
114
115 pyo3_async_runtimes::tokio::future_into_py(py, async move {
116 let stream = replay_normalized(&base_url, options, replay_signal)
117 .await
118 .map_err(to_pyruntime_err)?;
119
120 handle_python_stream(Box::pin(stream), callback, None, Some(map)).await;
123 Ok(())
124 })
125 }
126
127 #[pyo3(name = "replay_bars")]
128 fn py_replay_bars<'py>(
129 &self,
130 instruments: Vec<TardisInstrumentMiniInfo>,
131 options: Vec<ReplayNormalizedRequestOptions>,
132 py: Python<'py>,
133 ) -> PyResult<Bound<'py, PyAny>> {
134 let map = if instruments.is_empty() {
135 self.instruments.clone()
136 } else {
137 instruments
138 .into_iter()
139 .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
140 .collect()
141 };
142
143 let base_url = self.base_url.clone();
144 let replay_signal = self.replay_signal.clone();
145
146 pyo3_async_runtimes::tokio::future_into_py(py, async move {
147 let stream = replay_normalized(&base_url, options, replay_signal)
148 .await
149 .map_err(to_pyruntime_err)?;
150
151 pin_mut!(stream);
154
155 let mut bars: Vec<Bar> = Vec::new();
156
157 while let Some(result) = stream.next().await {
158 match result {
159 Ok(msg) => {
160 if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
161 .and_then(|info| parse_tardis_ws_message(msg, info))
162 {
163 bars.push(bar);
164 }
165 }
166 Err(e) => {
167 tracing::error!("Error in WebSocket stream: {e:?}");
168 break;
169 }
170 }
171 }
172
173 Python::with_gil(|py| {
174 let pylist =
175 PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
176 .expect("Invalid `ExactSizeIterator`");
177 Ok(pylist.into_py_any_unwrap(py))
178 })
179 })
180 }
181
182 #[pyo3(name = "stream")]
183 fn py_stream<'py>(
184 &self,
185 instruments: Vec<TardisInstrumentMiniInfo>,
186 options: Vec<StreamNormalizedRequestOptions>,
187 callback: PyObject,
188 py: Python<'py>,
189 ) -> PyResult<Bound<'py, PyAny>> {
190 let mut instrument_map: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
191 HashMap::new();
192 for inst in instruments {
193 let key = inst.as_tardis_instrument_key();
194 instrument_map.insert(key, Arc::new(inst.clone()));
195 }
196
197 let base_url = self.base_url.clone();
198 let replay_signal = self.replay_signal.clone();
199
200 pyo3_async_runtimes::tokio::future_into_py(py, async move {
201 let stream = stream_normalized(&base_url, options, replay_signal)
202 .await
203 .map_err(to_pyruntime_err)?;
204
205 handle_python_stream(Box::pin(stream), callback, None, Some(instrument_map)).await;
208 Ok(())
209 })
210 }
211}
212
213#[pyfunction]
219#[pyo3(name = "run_tardis_machine_replay")]
220#[pyo3(signature = (config_filepath))]
221pub fn py_run_tardis_machine_replay(
222 py: Python<'_>,
223 config_filepath: String,
224) -> PyResult<Bound<'_, PyAny>> {
225 tracing_subscriber::fmt()
226 .with_max_level(tracing::Level::DEBUG)
227 .init();
228
229 pyo3_async_runtimes::tokio::future_into_py(py, async move {
230 let config_filepath = Path::new(&config_filepath);
231 run_tardis_machine_replay_from_config(config_filepath)
232 .await
233 .map_err(to_pyruntime_err)?;
234 Ok(())
235 })
236}
237
238async fn handle_python_stream<S>(
239 stream: S,
240 callback: PyObject,
241 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
242 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
243) where
244 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
245{
246 pin_mut!(stream);
247
248 let mut funding_rate_cache: AHashMap<InstrumentId, FundingRateUpdate> = AHashMap::new();
250
251 while let Some(result) = stream.next().await {
252 match result {
253 Ok(msg) => {
254 let info = instrument.clone().or_else(|| {
255 instrument_map
256 .as_ref()
257 .and_then(|map| determine_instrument_info(&msg, map))
258 });
259
260 if let Some(info) = info.clone() {
261 if let Some(data) = parse_tardis_ws_message(msg.clone(), info.clone()) {
262 Python::with_gil(|py| {
263 let py_obj = data_to_pycapsule(py, data);
264 call_python(py, &callback, py_obj);
265 });
266 } else if let Some(funding_rate) =
267 parse_tardis_ws_message_funding_rate(msg, info)
268 {
269 let should_emit = if let Some(cached_rate) =
271 funding_rate_cache.get(&funding_rate.instrument_id)
272 {
273 if cached_rate == &funding_rate {
275 false } else {
277 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
278 true
279 }
280 } else {
281 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
283 true
284 };
285
286 if should_emit {
287 Python::with_gil(|py| {
288 let py_obj = funding_rate.into_py_any_unwrap(py);
289 call_python(py, &callback, py_obj);
290 });
291 }
292 }
293 }
294 }
295 Err(e) => {
296 tracing::error!("Error in WebSocket stream: {e:?}");
297 break;
298 }
299 }
300 }
301}
302
303fn call_python(py: Python, callback: &PyObject, py_obj: PyObject) {
304 if let Err(e) = callback.call1(py, (py_obj,)) {
305 tracing::error!("Error calling Python: {e}");
306 }
307}