nautilus_tardis/
replay.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::HashMap,
18    fs,
19    path::{Path, PathBuf},
20};
21
22use anyhow::Context;
23use arrow::record_batch::RecordBatch;
24use chrono::{DateTime, Duration, NaiveDate};
25use futures_util::{StreamExt, future::join_all, pin_mut};
26use heck::ToSnakeCase;
27use nautilus_core::{UnixNanos, parsing::precision_from_str};
28use nautilus_model::{
29    data::{
30        Bar, BarType, Data, OrderBookDelta, OrderBookDeltas_API, OrderBookDepth10, QuoteTick,
31        TradeTick,
32    },
33    identifiers::InstrumentId,
34};
35use nautilus_serialization::arrow::{
36    bars_to_arrow_record_batch_bytes, book_deltas_to_arrow_record_batch_bytes,
37    book_depth10_to_arrow_record_batch_bytes, quotes_to_arrow_record_batch_bytes,
38    trades_to_arrow_record_batch_bytes,
39};
40use parquet::{arrow::ArrowWriter, basic::Compression, file::properties::WriterProperties};
41use thousands::Separable;
42use ustr::Ustr;
43
44use super::{enums::TardisExchange, http::models::TardisInstrumentInfo};
45use crate::{
46    config::TardisReplayConfig,
47    http::TardisHttpClient,
48    machine::{TardisMachineClient, types::TardisInstrumentMiniInfo},
49    parse::{normalize_instrument_id, parse_instrument_id},
50};
51
/// Tracks the current UTC date for a data stream and the last valid nanosecond
/// of that date, used to detect when incoming messages roll over to a new day.
struct DateCursor {
    /// Cursor date UTC.
    date_utc: NaiveDate,
    /// Cursor end timestamp UNIX nanoseconds (23:59:59.999999999 of `date_utc`).
    end_ns: UnixNanos,
}
58
59impl DateCursor {
60    /// Creates a new [`DateCursor`] instance.
61    fn new(current_ns: UnixNanos) -> Self {
62        let current_utc = DateTime::from_timestamp_nanos(current_ns.as_i64());
63        let date_utc = current_utc.date_naive();
64
65        // Calculate end of the current UTC day
66        // SAFETY: Known safe input values
67        let end_utc =
68            date_utc.and_hms_opt(23, 59, 59).unwrap() + Duration::nanoseconds(999_999_999);
69        let end_ns = UnixNanos::from(end_utc.and_utc().timestamp_nanos_opt().unwrap() as u64);
70
71        Self { date_utc, end_ns }
72    }
73}
74
75async fn gather_instruments_info(
76    config: &TardisReplayConfig,
77    http_client: &TardisHttpClient,
78) -> HashMap<TardisExchange, Vec<TardisInstrumentInfo>> {
79    let futures = config.options.iter().map(|options| {
80        let exchange = options.exchange;
81        let client = &http_client;
82
83        tracing::info!("Requesting instruments for {exchange}");
84
85        async move {
86            match client.instruments_info(exchange, None, None).await {
87                Ok(instruments) => Some((exchange, instruments)),
88                Err(e) => {
89                    tracing::error!("Error fetching instruments for {exchange}: {e}");
90                    None
91                }
92            }
93        }
94    });
95
96    let results: Vec<(TardisExchange, Vec<TardisInstrumentInfo>)> =
97        join_all(futures).await.into_iter().flatten().collect();
98
99    tracing::info!("Received all instruments");
100
101    results.into_iter().collect()
102}
103
/// Run the Tardis Machine replay from a JSON configuration file.
///
/// # Errors
///
/// Returns an error if reading or parsing the config file fails,
/// or if any downstream replay operation fails.
///
/// # Panics
///
/// Panics if unable to determine the output path (current directory fallback fails).
pub async fn run_tardis_machine_replay_from_config(config_filepath: &Path) -> anyhow::Result<()> {
    tracing::info!("Starting replay");
    tracing::info!("Config filepath: {config_filepath:?}");

    // Load and parse the replay configuration
    let config_data = fs::read_to_string(config_filepath)
        .with_context(|| format!("Failed to read config file: {config_filepath:?}"))?;
    let config: TardisReplayConfig = serde_json::from_str(&config_data)
        .context("Failed to parse config JSON into TardisReplayConfig")?;

    // Resolve output path: explicit config value, then `NAUTILUS_PATH` env var
    // (catalog/data subdirectory), then the current working directory.
    let path = config
        .output_path
        .as_deref()
        .map(Path::new)
        .map(Path::to_path_buf)
        .or_else(|| {
            std::env::var("NAUTILUS_PATH")
                .ok()
                .map(|env_path| PathBuf::from(env_path).join("catalog").join("data"))
        })
        .unwrap_or_else(|| std::env::current_dir().expect("Failed to get current directory"));

    tracing::info!("Output path: {path:?}");

    // Symbol normalization defaults to enabled when not specified in the config
    let normalize_symbols = config.normalize_symbols.unwrap_or(true);
    tracing::info!("normalize_symbols={normalize_symbols}");

    let http_client = TardisHttpClient::new(None, None, None, normalize_symbols)?;
    let mut machine_client =
        TardisMachineClient::new(config.tardis_ws_url.as_deref(), normalize_symbols)?;

    let info_map = gather_instruments_info(&config, &http_client).await;

    // Register per-instrument metadata (identifier, price/size precision)
    // with the machine client before starting the stream
    for (exchange, instruments) in &info_map {
        for inst in instruments {
            let instrument_type = inst.instrument_type;
            let price_precision = precision_from_str(&inst.price_increment.to_string());
            let size_precision = precision_from_str(&inst.amount_increment.to_string());

            let instrument_id = if normalize_symbols {
                normalize_instrument_id(exchange, inst.id, &instrument_type, inst.inverse)
            } else {
                parse_instrument_id(exchange, inst.id)
            };

            let info = TardisInstrumentMiniInfo::new(
                instrument_id,
                Some(Ustr::from(&inst.id)),
                *exchange,
                price_precision,
                size_precision,
            );
            machine_client.add_instrument_info(info);
        }
    }

    tracing::info!("Starting tardis-machine stream");
    let stream = machine_client.replay(config.options).await;
    pin_mut!(stream);

    // Initialize date cursors (per instrument / bar type, tracking UTC day rollover)
    let mut deltas_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
    let mut depths_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
    let mut quotes_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
    let mut trades_cursors: HashMap<InstrumentId, DateCursor> = HashMap::new();
    let mut bars_cursors: HashMap<BarType, DateCursor> = HashMap::new();

    // Initialize date collection maps (buffers flushed to Parquet on day rollover)
    let mut deltas_map: HashMap<InstrumentId, Vec<OrderBookDelta>> = HashMap::new();
    let mut depths_map: HashMap<InstrumentId, Vec<OrderBookDepth10>> = HashMap::new();
    let mut quotes_map: HashMap<InstrumentId, Vec<QuoteTick>> = HashMap::new();
    let mut trades_map: HashMap<InstrumentId, Vec<TradeTick>> = HashMap::new();
    let mut bars_map: HashMap<BarType, Vec<Bar>> = HashMap::new();

    let mut msg_count = 0;

    while let Some(msg) = stream.next().await {
        match msg {
            Data::Deltas(msg) => {
                handle_deltas_msg(msg, &mut deltas_map, &mut deltas_cursors, &path);
            }
            Data::Depth10(msg) => {
                handle_depth10_msg(*msg, &mut depths_map, &mut depths_cursors, &path);
            }
            Data::Quote(msg) => handle_quote_msg(msg, &mut quotes_map, &mut quotes_cursors, &path),
            Data::Trade(msg) => handle_trade_msg(msg, &mut trades_map, &mut trades_cursors, &path),
            Data::Bar(msg) => handle_bar_msg(msg, &mut bars_map, &mut bars_cursors, &path),
            Data::Delta(_) => panic!("Individual delta message not implemented (or required)"),
            _ => panic!("Not implemented"),
        }

        msg_count += 1;
        if msg_count % 100_000 == 0 {
            tracing::debug!("Processed {} messages", msg_count.separate_with_commas());
        }
    }

    // Stream exhausted: flush remaining buffered data for every type sequentially

    for (instrument_id, deltas) in deltas_map {
        let cursor = deltas_cursors.get(&instrument_id).expect("Expected cursor");
        batch_and_write_deltas(deltas, &instrument_id, cursor.date_utc, &path);
    }

    for (instrument_id, depths) in depths_map {
        let cursor = depths_cursors.get(&instrument_id).expect("Expected cursor");
        batch_and_write_depths(depths, &instrument_id, cursor.date_utc, &path);
    }

    for (instrument_id, quotes) in quotes_map {
        let cursor = quotes_cursors.get(&instrument_id).expect("Expected cursor");
        batch_and_write_quotes(quotes, &instrument_id, cursor.date_utc, &path);
    }

    for (instrument_id, trades) in trades_map {
        let cursor = trades_cursors.get(&instrument_id).expect("Expected cursor");
        batch_and_write_trades(trades, &instrument_id, cursor.date_utc, &path);
    }

    for (bar_type, bars) in bars_map {
        let cursor = bars_cursors.get(&bar_type).expect("Expected cursor");
        batch_and_write_bars(bars, &bar_type, cursor.date_utc, &path);
    }

    tracing::info!(
        "Replay completed after {} messages",
        msg_count.separate_with_commas()
    );
    Ok(())
}
245
246fn handle_deltas_msg(
247    deltas: OrderBookDeltas_API,
248    map: &mut HashMap<InstrumentId, Vec<OrderBookDelta>>,
249    cursors: &mut HashMap<InstrumentId, DateCursor>,
250    path: &Path,
251) {
252    let cursor = cursors
253        .entry(deltas.instrument_id)
254        .or_insert_with(|| DateCursor::new(deltas.ts_init));
255
256    if deltas.ts_init > cursor.end_ns {
257        if let Some(deltas_vec) = map.remove(&deltas.instrument_id) {
258            batch_and_write_deltas(deltas_vec, &deltas.instrument_id, cursor.date_utc, path);
259        }
260        // Update cursor
261        *cursor = DateCursor::new(deltas.ts_init);
262    }
263
264    map.entry(deltas.instrument_id)
265        .or_insert_with(|| Vec::with_capacity(1_000_000))
266        .extend(&*deltas.deltas);
267}
268
269fn handle_depth10_msg(
270    depth10: OrderBookDepth10,
271    map: &mut HashMap<InstrumentId, Vec<OrderBookDepth10>>,
272    cursors: &mut HashMap<InstrumentId, DateCursor>,
273    path: &Path,
274) {
275    let cursor = cursors
276        .entry(depth10.instrument_id)
277        .or_insert_with(|| DateCursor::new(depth10.ts_init));
278
279    if depth10.ts_init > cursor.end_ns {
280        if let Some(depths_vec) = map.remove(&depth10.instrument_id) {
281            batch_and_write_depths(depths_vec, &depth10.instrument_id, cursor.date_utc, path);
282        }
283        // Update cursor
284        *cursor = DateCursor::new(depth10.ts_init);
285    }
286
287    map.entry(depth10.instrument_id)
288        .or_insert_with(|| Vec::with_capacity(1_000_000))
289        .push(depth10);
290}
291
292fn handle_quote_msg(
293    quote: QuoteTick,
294    map: &mut HashMap<InstrumentId, Vec<QuoteTick>>,
295    cursors: &mut HashMap<InstrumentId, DateCursor>,
296    path: &Path,
297) {
298    let cursor = cursors
299        .entry(quote.instrument_id)
300        .or_insert_with(|| DateCursor::new(quote.ts_init));
301
302    if quote.ts_init > cursor.end_ns {
303        if let Some(quotes_vec) = map.remove(&quote.instrument_id) {
304            batch_and_write_quotes(quotes_vec, &quote.instrument_id, cursor.date_utc, path);
305        }
306        // Update cursor
307        *cursor = DateCursor::new(quote.ts_init);
308    }
309
310    map.entry(quote.instrument_id)
311        .or_insert_with(|| Vec::with_capacity(1_000_000))
312        .push(quote);
313}
314
315fn handle_trade_msg(
316    trade: TradeTick,
317    map: &mut HashMap<InstrumentId, Vec<TradeTick>>,
318    cursors: &mut HashMap<InstrumentId, DateCursor>,
319    path: &Path,
320) {
321    let cursor = cursors
322        .entry(trade.instrument_id)
323        .or_insert_with(|| DateCursor::new(trade.ts_init));
324
325    if trade.ts_init > cursor.end_ns {
326        if let Some(trades_vec) = map.remove(&trade.instrument_id) {
327            batch_and_write_trades(trades_vec, &trade.instrument_id, cursor.date_utc, path);
328        }
329        // Update cursor
330        *cursor = DateCursor::new(trade.ts_init);
331    }
332
333    map.entry(trade.instrument_id)
334        .or_insert_with(|| Vec::with_capacity(1_000_000))
335        .push(trade);
336}
337
338fn handle_bar_msg(
339    bar: Bar,
340    map: &mut HashMap<BarType, Vec<Bar>>,
341    cursors: &mut HashMap<BarType, DateCursor>,
342    path: &Path,
343) {
344    let cursor = cursors
345        .entry(bar.bar_type)
346        .or_insert_with(|| DateCursor::new(bar.ts_init));
347
348    if bar.ts_init > cursor.end_ns {
349        if let Some(bars_vec) = map.remove(&bar.bar_type) {
350            batch_and_write_bars(bars_vec, &bar.bar_type, cursor.date_utc, path);
351        }
352        // Update cursor
353        *cursor = DateCursor::new(bar.ts_init);
354    }
355
356    map.entry(bar.bar_type)
357        .or_insert_with(|| Vec::with_capacity(1_000_000))
358        .push(bar);
359}
360
361fn batch_and_write_deltas(
362    deltas: Vec<OrderBookDelta>,
363    instrument_id: &InstrumentId,
364    date: NaiveDate,
365    path: &Path,
366) {
367    let typename = stringify!(OrderBookDeltas);
368    match book_deltas_to_arrow_record_batch_bytes(deltas) {
369        Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
370        Err(e) => {
371            tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
372        }
373    }
374}
375
376fn batch_and_write_depths(
377    depths: Vec<OrderBookDepth10>,
378    instrument_id: &InstrumentId,
379    date: NaiveDate,
380    path: &Path,
381) {
382    let typename = stringify!(OrderBookDepth10);
383    match book_depth10_to_arrow_record_batch_bytes(depths) {
384        Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
385        Err(e) => {
386            tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
387        }
388    }
389}
390
391fn batch_and_write_quotes(
392    quotes: Vec<QuoteTick>,
393    instrument_id: &InstrumentId,
394    date: NaiveDate,
395    path: &Path,
396) {
397    let typename = stringify!(QuoteTick);
398    match quotes_to_arrow_record_batch_bytes(quotes) {
399        Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
400        Err(e) => {
401            tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
402        }
403    }
404}
405
406fn batch_and_write_trades(
407    trades: Vec<TradeTick>,
408    instrument_id: &InstrumentId,
409    date: NaiveDate,
410    path: &Path,
411) {
412    let typename = stringify!(TradeTick);
413    match trades_to_arrow_record_batch_bytes(trades) {
414        Ok(batch) => write_batch(batch, typename, instrument_id, date, path),
415        Err(e) => {
416            tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
417        }
418    }
419}
420
421fn batch_and_write_bars(bars: Vec<Bar>, bar_type: &BarType, date: NaiveDate, path: &Path) {
422    let typename = stringify!(Bar);
423    let batch = match bars_to_arrow_record_batch_bytes(bars) {
424        Ok(batch) => batch,
425        Err(e) => {
426            tracing::error!("Error converting `{typename}` to Arrow: {e:?}");
427            return;
428        }
429    };
430
431    let filepath = path.join(parquet_filepath_bars(bar_type, date));
432    if let Err(e) = write_parquet_local(batch, &filepath) {
433        tracing::error!("Error writing {filepath:?}: {e:?}");
434    } else {
435        tracing::info!("File written: {filepath:?}");
436    }
437}
438
439fn parquet_filepath(typename: &str, instrument_id: &InstrumentId, date: NaiveDate) -> PathBuf {
440    let typename = typename.to_snake_case();
441    let instrument_id_str = instrument_id.to_string().replace('/', "");
442    let date_str = date.to_string().replace('-', "");
443    PathBuf::new()
444        .join(typename)
445        .join(instrument_id_str)
446        .join(format!("{date_str}.parquet"))
447}
448
449fn parquet_filepath_bars(bar_type: &BarType, date: NaiveDate) -> PathBuf {
450    let bar_type_str = bar_type.to_string().replace('/', "");
451    let date_str = date.to_string().replace('-', "");
452    PathBuf::new()
453        .join("bar")
454        .join(bar_type_str)
455        .join(format!("{date_str}.parquet"))
456}
457
458fn write_batch(
459    batch: RecordBatch,
460    typename: &str,
461    instrument_id: &InstrumentId,
462    date: NaiveDate,
463    path: &Path,
464) {
465    let filepath = path.join(parquet_filepath(typename, instrument_id, date));
466    if let Err(e) = write_parquet_local(batch, &filepath) {
467        tracing::error!("Error writing {filepath:?}: {e:?}");
468    } else {
469        tracing::info!("File written: {filepath:?}");
470    }
471}
472
473fn write_parquet_local(batch: RecordBatch, file_path: &Path) -> anyhow::Result<()> {
474    if let Some(parent) = file_path.parent() {
475        std::fs::create_dir_all(parent)?;
476    }
477
478    let file = std::fs::File::create(file_path)?;
479    let props = WriterProperties::builder()
480        .set_compression(Compression::SNAPPY)
481        .build();
482
483    let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
484    writer.write(&batch)?;
485    writer.close()?;
486    Ok(())
487}
488
489///////////////////////////////////////////////////////////////////////////////////////////////////
490// Tests
491///////////////////////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use chrono::{TimeZone, Utc};
    use rstest::rstest;

    use super::*;

    // Each case supplies (input timestamp, expected cursor date, expected end-of-day ns).
    // The end-of-day bound is always 23:59:59 UTC plus 999,999,999 ns on the input's date.
    #[rstest]
    #[case(
    // Start of day: 2024-01-01 00:00:00 UTC
    Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
    NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
)]
    #[case(
    // Midday: 2024-01-01 12:00:00 UTC
    Utc.with_ymd_and_hms(2024, 1, 1, 12, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
    NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
)]
    #[case(
    // End of day: 2024-01-01 23:59:59.999999999 UTC (last representable ns of the day)
    Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999,
    NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 1, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
)]
    #[case(
    // Start of new day: 2024-01-02 00:00:00 UTC (cursor rolls to the next date)
    Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap().timestamp_nanos_opt().unwrap() as u64,
    NaiveDate::from_ymd_opt(2024, 1, 2).unwrap(),
    Utc.with_ymd_and_hms(2024, 1, 2, 23, 59, 59).unwrap().timestamp_nanos_opt().unwrap() as u64 + 999_999_999
)]
    fn test_date_cursor(
        #[case] timestamp: u64,
        #[case] expected_date: NaiveDate,
        #[case] expected_end_ns: u64,
    ) {
        let unix_nanos = UnixNanos::from(timestamp);
        let cursor = DateCursor::new(unix_nanos);

        assert_eq!(cursor.date_utc, expected_date);
        assert_eq!(cursor.end_ns, UnixNanos::from(expected_end_ns));
    }
}
535}