1use std::{
17 collections::HashMap,
18 fs,
19 path::{Path, PathBuf},
20};
21
22use arrow::array::RecordBatch;
23use chrono::{DateTime, Duration, NaiveDate};
24use futures_util::{StreamExt, future::join_all, pin_mut};
25use heck::ToSnakeCase;
26use nautilus_core::{UnixNanos, parsing::precision_from_str};
27use nautilus_model::{
28 data::{
29 Bar, BarType, Data, OrderBookDelta, OrderBookDeltas_API, OrderBookDepth10, QuoteTick,
30 TradeTick,
31 },
32 identifiers::InstrumentId,
33};
34use nautilus_serialization::{
35 arrow::{
36 bars_to_arrow_record_batch_bytes, book_deltas_to_arrow_record_batch_bytes,
37 book_depth10_to_arrow_record_batch_bytes, quotes_to_arrow_record_batch_bytes,
38 trades_to_arrow_record_batch_bytes,
39 },
40 parquet::write_batch_to_parquet,
41};
42use thousands::Separable;
43use ustr::Ustr;
44
45use super::{enums::Exchange, http::models::InstrumentInfo};
46use crate::{
47 config::TardisReplayConfig,
48 http::TardisHttpClient,
49 machine::{TardisMachineClient, types::InstrumentMiniInfo},
50 parse::{normalize_instrument_id, parse_instrument_id},
51};
52
/// Tracks the UTC date currently being buffered for a stream key, together with
/// the last nanosecond timestamp that still belongs to that date.
struct DateCursor {
    // Current UTC date for the data being buffered
    date_utc: NaiveDate,
    // Inclusive end of `date_utc` in UNIX nanoseconds (23:59:59.999999999 UTC)
    end_ns: UnixNanos,
}
59
60impl DateCursor {
61 fn new(current_ns: UnixNanos) -> Self {
63 let current_utc = DateTime::from_timestamp_nanos(current_ns.as_i64());
64 let date_utc = current_utc.date_naive();
65
66 let end_utc =
69 date_utc.and_hms_opt(23, 59, 59).unwrap() + Duration::nanoseconds(999_999_999);
70 let end_ns = UnixNanos::from(end_utc.and_utc().timestamp_nanos_opt().unwrap() as u64);
71
72 Self { date_utc, end_ns }
73 }
74}
75
76async fn gather_instruments_info(
77 config: &TardisReplayConfig,
78 http_client: &TardisHttpClient,
79) -> HashMap<Exchange, Vec<InstrumentInfo>> {
80 let futures = config.options.iter().map(|options| {
81 let exchange = options.exchange.clone();
82 let client = &http_client;
83
84 tracing::info!("Requesting instruments for {exchange}");
85
86 async move {
87 match client.instruments_info(exchange.clone(), None, None).await {
88 Ok(instruments) => Some((exchange, instruments)),
89 Err(e) => {
90 tracing::error!("Error fetching instruments for {exchange}: {e}");
91 None
92 }
93 }
94 }
95 });
96
97 let results: Vec<(Exchange, Vec<InstrumentInfo>)> =
98 join_all(futures).await.into_iter().flatten().collect();
99
100 tracing::info!("Received all instruments");
101
102 results.into_iter().collect()
103}
104
105pub async fn run_tardis_machine_replay_from_config(config_filepath: &Path) -> anyhow::Result<()> {
106 tracing::info!("Starting replay");
107 tracing::info!("Config filepath: {config_filepath:?}");
108
109 let config_data = fs::read_to_string(config_filepath).expect("Failed to read config file");
110 let config: TardisReplayConfig =
111 serde_json::from_str(&config_data).expect("Failed to parse config JSON");
112
113 let path = config
114 .output_path
115 .as_deref()
116 .map(Path::new)
117 .map(Path::to_path_buf)
118 .or_else(|| {
119 std::env::var("NAUTILUS_CATALOG_PATH")
120 .ok()
121 .map(|env_path| PathBuf::from(env_path).join("data"))
122 })
123 .unwrap_or_else(|| std::env::current_dir().expect("Failed to get current directory"));
124
125 tracing::info!("Output path: {path:?}");
126
127 let normalize_symbols = config.normalize_symbols.unwrap_or(true);
128 tracing::info!("normalize_symbols={normalize_symbols}");
129
130 let http_client = TardisHttpClient::new(None, None, None, normalize_symbols)?;
131 let mut machine_client =
132 TardisMachineClient::new(config.tardis_ws_url.as_deref(), normalize_symbols)?;
133
134 let info_map = gather_instruments_info(&config, &http_client).await;
135
136 for (exchange, instruments) in &info_map {
137 for inst in instruments {
138 let instrument_type = inst.instrument_type.clone();
139 let price_precision = precision_from_str(&inst.price_increment.to_string());
140 let size_precision = precision_from_str(&inst.amount_increment.to_string());
141
142 let instrument_id = if normalize_symbols {
143 normalize_instrument_id(exchange, inst.id, &instrument_type, inst.inverse)
144 } else {
145 parse_instrument_id(exchange, inst.id)
146 };
147
148 let info = InstrumentMiniInfo::new(
149 instrument_id,
150 Some(Ustr::from(&inst.id)),
151 exchange.clone(),
152 price_precision,
153 size_precision,
154 );
155 machine_client.add_instrument_info(info);
156 }
157 }
158
159 tracing::info!("Starting tardis-machine stream");
160 let stream = machine_client.replay(config.options).await;
161 pin_mut!(stream);
162
163 let mut deltas_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
165 let mut depths_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
166 let mut quotes_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
167 let mut trades_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
168 let mut bars_cursors: HashMap<BarType, DateCursor> = HashMap::new();
169
170 let mut deltas_map: HashMap<InstrumentId, Vec<OrderBookDelta>> = HashMap::new();
172 let mut depths_map: HashMap<InstrumentId, Vec<OrderBookDepth10>> = HashMap::new();
173 let mut quotes_map: HashMap<InstrumentId, Vec<QuoteTick>> = HashMap::new();
174 let mut trades_map: HashMap<InstrumentId, Vec<TradeTick>> = HashMap::new();
175 let mut bars_map: HashMap<BarType, Vec<Bar>> = HashMap::new();
176
177 let mut msg_count = 0;
178
179 while let Some(msg) = stream.next().await {
180 match msg {
181 Data::Deltas(msg) => {
182 handle_deltas_msg(msg, &mut deltas_map, &mut deltas_cursors, &path);
183 }
184 Data::Depth10(msg) => {
185 handle_depth10_msg(*msg, &mut depths_map, &mut depths_cursors, &path);
186 }
187 Data::Quote(msg) => handle_quote_msg(msg, &mut quotes_map, &mut quotes_cursors, &path),
188 Data::Trade(msg) => handle_trade_msg(msg, &mut trades_map, &mut trades_cursors, &path),
189 Data::Bar(msg) => handle_bar_msg(msg, &mut bars_map, &mut bars_cursors, &path),
190 Data::Delta(_) => panic!("Individual delta message not implemented (or required)"),
191 _ => panic!("Not implemented"),
192 }
193
194 msg_count += 1;
195 if msg_count % 100_000 == 0 {
196 tracing::debug!("Processed {} messages", msg_count.separate_with_commas());
197 }
198 }
199
200 for (instrument_id, deltas) in deltas_map {
203 let cursor = deltas_cursors.get(&instrument_id).expect("Expected cursor");
204 batch_and_write_deltas(deltas, &instrument_id, cursor.date_utc, &path);
205 }
206
207 for (instrument_id, depths) in depths_map {
208 let cursor = depths_cursors.get(&instrument_id).expect("Expected cursor");
209 batch_and_write_depths(depths, &instrument_id, cursor.date_utc, &path);
210 }
211
212 for (instrument_id, quotes) in quotes_map {
213 let cursor = quotes_cursors.get(&instrument_id).expect("Expected cursor");
214 batch_and_write_quotes(quotes, &instrument_id, cursor.date_utc, &path);
215 }
216
217 for (instrument_id, trades) in trades_map {
218 let cursor = trades_cursors.get(&instrument_id).expect("Expected cursor");
219 batch_and_write_trades(trades, &instrument_id, cursor.date_utc, &path);
220 }
221
222 for (bar_type, bars) in bars_map {
223 let cursor = bars_cursors.get(&bar_type).expect("Expected cursor");
224 batch_and_write_bars(bars, &bar_type, cursor.date_utc, &path);
225 }
226
227 tracing::info!(
228 "Replay completed after {} messages",
229 msg_count.separate_with_commas()
230 );
231 Ok(())
232}
233
234fn handle_deltas_msg(
235 deltas: OrderBookDeltas_API,
236 map: &mut HashMap<InstrumentId, Vec<OrderBookDelta>>,
237 cursors: &mut HashMap<InstrumentId, DateCursor>,
238 path: &Path,
239) {
240 let cursor = cursors
241 .entry(deltas.instrument_id)
242 .or_insert_with(|| DateCursor::new(deltas.ts_init));
243
244 if deltas.ts_init > cursor.end_ns {
245 if let Some(deltas_vec) = map.remove(&deltas.instrument_id) {
246 batch_and_write_deltas(deltas_vec, &deltas.instrument_id, cursor.date_utc, path);
247 }
248 *cursor = DateCursor::new(deltas.ts_init);
250 }
251
252 map.entry(deltas.instrument_id)
253 .or_insert_with(|| Vec::with_capacity(1_000_000))
254 .extend(&*deltas.deltas);
255}
256
257fn handle_depth10_msg(
258 depth10: OrderBookDepth10,
259 map: &mut HashMap<InstrumentId, Vec<OrderBookDepth10>>,
260 cursors: &mut HashMap<InstrumentId, DateCursor>,
261 path: &Path,
262) {
263 let cursor = cursors
264 .entry(depth10.instrument_id)
265 .or_insert_with(|| DateCursor::new(depth10.ts_init));
266
267 if depth10.ts_init > cursor.end_ns {
268 if let Some(depths_vec) = map.remove(&depth10.instrument_id) {
269 batch_and_write_depths(depths_vec, &depth10.instrument_id, cursor.date_utc, path);
270 }
271 *cursor = DateCursor::new(depth10.ts_init);
273 }
274
275 map.entry(depth10.instrument_id)
276 .or_insert_with(|| Vec::with_capacity(1_000_000))
277 .push(depth10);
278}
279
280fn handle_quote_msg(
281 quote: QuoteTick,
282 map: &mut HashMap<InstrumentId, Vec<QuoteTick>>,
283 cursors: &mut HashMap<InstrumentId, DateCursor>,
284 path: &Path,
285) {
286 let cursor = cursors
287 .entry(quote.instrument_id)
288 .or_insert_with(|| DateCursor::new(quote.ts_init));
289
290 if quote.ts_init > cursor.end_ns {
291 if let Some(quotes_vec) = map.remove("e.instrument_id) {
292 batch_and_write_quotes(quotes_vec, "e.instrument_id, cursor.date_utc, path);
293 }
294 *cursor = DateCursor::new(quote.ts_init);
296 }
297
298 map.entry(quote.instrument_id)
299 .or_insert_with(|| Vec::with_capacity(1_000_000))
300 .push(quote);
301}
302
303fn handle_trade_msg(
304 trade: TradeTick,
305 map: &mut HashMap<InstrumentId, Vec<TradeTick>>,
306 cursors: &mut HashMap<InstrumentId, DateCursor>,
307 path: &Path,
308) {
309 let cursor = cursors
310 .entry(trade.instrument_id)
311 .or_insert_with(|| DateCursor::new(trade.ts_init));
312
313 if trade.ts_init > cursor.end_ns {
314 if let Some(trades_vec) = map.remove(&trade.instrument_id) {
315 batch_and_write_trades(trades_vec, &trade.instrument_id, cursor.date_utc, path);
316 }
317 *cursor = DateCursor::new(trade.ts_init);
319 }
320
321 map.entry(trade.instrument_id)
322 .or_insert_with(|| Vec::with_capacity(1_000_000))
323 .push(trade);
324}
325
326fn handle_bar_msg(
327 bar: Bar,
328 map: &mut HashMap<BarType, Vec<Bar>>,
329 cursors: &mut HashMap<BarType, DateCursor>,
330 path: &Path,
331) {
332 let cursor = cursors
333 .entry(bar.bar_type)
334 .or_insert_with(|| DateCursor::new(bar.ts_init));
335
336 if bar.ts_init > cursor.end_ns {
337 if let Some(bars_vec) = map.remove(&bar.bar_type) {
338 batch_and_write_bars(bars_vec, &bar.bar_type, cursor.date_utc, path);
339 }
340 *cursor = DateCursor::new(bar.ts_init);
342 }
343
344 map.entry(bar.bar_type)
345 .or_insert_with(|| Vec::with_capacity(1_000_000))
346 .push(bar);
347}
348
349fn batch_and_write_deltas(
350 deltas: Vec<OrderBookDelta>,
351 instrument_id: &InstrumentId,
352 date: NaiveDate,
353 path: &Path,
354) {
355 let typename = stringify!(OrderBookDeltas);
356 match book_deltas_to_arrow_record_batch_bytes(deltas) {
357 Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
358 Err(e) => {
359 tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
360 }
361 }
362}
363
364fn batch_and_write_depths(
365 depths: Vec<OrderBookDepth10>,
366 instrument_id: &InstrumentId,
367 date: NaiveDate,
368 path: &Path,
369) {
370 let typename = stringify!(OrderBookDepth10);
371 match book_depth10_to_arrow_record_batch_bytes(depths) {
372 Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
373 Err(e) => {
374 tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
375 }
376 }
377}
378
379fn batch_and_write_quotes(
380 quotes: Vec<QuoteTick>,
381 instrument_id: &InstrumentId,
382 date: NaiveDate,
383 path: &Path,
384) {
385 let typename = stringify!(QuoteTick);
386 match quotes_to_arrow_record_batch_bytes(quotes) {
387 Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
388 Err(e) => {
389 tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
390 }
391 }
392}
393
394fn batch_and_write_trades(
395 trades: Vec<TradeTick>,
396 instrument_id: &InstrumentId,
397 date: NaiveDate,
398 path: &Path,
399) {
400 let typename = stringify!(TradeTick);
401 match trades_to_arrow_record_batch_bytes(trades) {
402 Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
403 Err(e) => {
404 tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
405 }
406 }
407}
408
409fn batch_and_write_bars(bars: Vec<Bar>, bar_type: &BarType, date: NaiveDate, path: &Path) {
410 let typename = stringify!(Bar);
411 let batch = match bars_to_arrow_record_batch_bytes(bars) {
412 Ok(batch) => batch,
413 Err(e) => {
414 tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
415 return;
416 }
417 };
418
419 let filepath = path.join(parquet_filepath_bars(bar_type, date));
420 match write_batch_to_parquet(batch, &filepath, None, None, None) {
421 Ok(()) => tracing::info!("File written: {filepath:?}"),
422 Err(e) => tracing::error!("Error writing {filepath:?}: {e:?}"),
423 }
424}
425
426fn parquet_filepath(typename: &str, instrument_id: &InstrumentId, date: NaiveDate) -> PathBuf {
427 let typename = typename.to_snake_case();
428 let instrument_id_str = instrument_id.to_string().replace('/', "");
429 let date_str = date.to_string().replace('-', "");
430 PathBuf::new()
431 .join(typename)
432 .join(instrument_id_str)
433 .join(format!("{date_str}.parquet"))
434}
435
436fn parquet_filepath_bars(bar_type: &BarType, date: NaiveDate) -> PathBuf {
437 let bar_type_str = bar_type.to_string().replace('/', "");
438 let date_str = date.to_string().replace('-', "");
439 PathBuf::new()
440 .join("bar")
441 .join(bar_type_str)
442 .join(format!("{date_str}.parquet"))
443}
444
445fn write_batch(
446 batch: RecordBatch,
447 typename: &str,
448 instrument_id: &InstrumentId,
449 date: NaiveDate,
450 path: &Path,
451) {
452 let filepath = path.join(parquet_filepath(typename, instrument_id, date));
453 match write_batch_to_parquet(batch, &filepath, None, None, None) {
454 Ok(()) => tracing::info!("File written: {filepath:?}"),
455 Err(e) => tracing::error!("Error writing {filepath:?}: {e:?}"),
456 }
457}
458
#[cfg(test)]
mod tests {
    use chrono::{TimeZone, Utc};
    use rstest::rstest;

    use super::*;

    // Each case supplies: an input timestamp (ns), the expected cursor date,
    // and the expected inclusive end-of-day (23:59:59.999999999 UTC) in ns.
    #[rstest]
    // Start of day maps to that same date
    #[case(
        Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
        NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
        Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
    )]
    // Midday maps to the same date
    #[case(
        Utc.with_ymd_and_hms(2024, 1, 1, 12, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
        NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
        Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
    )]
    // The last nanosecond of the day still belongs to that date
    #[case(
        Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999,
        NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
        Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
    )]
    // The first nanosecond of the next day rolls over to the new date
    #[case(
        Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
        NaiveDate::from_ymd_opt(2024, 1, 2).unwrap(),
        Utc.with_ymd_and_hms(2024, 1, 2, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
    )]
    fn test_date_cursor(
        #[case] timestamp: u64,
        #[case] expected_date: NaiveDate,
        #[case] expected_end_ns: u64,
    ) {
        let unix_nanos = UnixNanos::from(timestamp);
        let cursor = DateCursor::new(unix_nanos);

        assert_eq!(cursor.date_utc, expected_date);
        assert_eq!(cursor.end_ns, UnixNanos::from(expected_end_ns));
    }
}