nautilus_databento/
loader.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    env, fs,
18    path::{Path, PathBuf},
19};
20
21use ahash::AHashMap;
22use anyhow::Context;
23use databento::dbn::{self, InstrumentDefMsg};
24use dbn::{
25    Publisher,
26    decode::{DbnMetadata, DecodeStream, dbn::Decoder},
27};
28use fallible_streaming_iterator::FallibleStreamingIterator;
29use indexmap::IndexMap;
30use nautilus_model::{
31    data::{Bar, Data, InstrumentStatus, OrderBookDelta, OrderBookDepth10, QuoteTick, TradeTick},
32    identifiers::{InstrumentId, Symbol, Venue},
33    instruments::InstrumentAny,
34    types::Currency,
35};
36
37use super::{
38    decode::{decode_imbalance_msg, decode_record, decode_statistics_msg, decode_status_msg},
39    symbology::decode_nautilus_instrument_id,
40    types::{DatabentoImbalance, DatabentoPublisher, DatabentoStatistics, Dataset, PublisherId},
41};
42use crate::{decode::decode_instrument_def_msg, symbology::MetadataCache};
43
/// A Nautilus data loader for Databento Binary Encoding (DBN) format data.
///
/// The loader holds publisher, venue and dataset lookup tables (populated from a
/// publishers JSON file) which are used to resolve instrument IDs while decoding.
///
/// # Supported schemas:
///  - `MBO` -> `OrderBookDelta`
///  - `MBP_1` -> `(QuoteTick, Option<TradeTick>)`
///  - `MBP_10` -> `OrderBookDepth10`
///  - `BBO_1S` -> `QuoteTick`
///  - `BBO_1M` -> `QuoteTick`
///  - `CMBP_1` -> `(QuoteTick, Option<TradeTick>)`
///  - `CBBO_1S` -> `QuoteTick`
///  - `CBBO_1M` -> `QuoteTick`
///  - `TCBBO` -> `(QuoteTick, TradeTick)`
///  - `TBBO` -> `(QuoteTick, TradeTick)`
///  - `TRADES` -> `TradeTick`
///  - `OHLCV_1S` -> `Bar`
///  - `OHLCV_1M` -> `Bar`
///  - `OHLCV_1H` -> `Bar`
///  - `OHLCV_1D` -> `Bar`
///  - `OHLCV_EOD` -> `Bar`
///  - `DEFINITION` -> `Instrument`
///  - `IMBALANCE` -> `DatabentoImbalance`
///  - `STATISTICS` -> `DatabentoStatistics`
///  - `STATUS` -> `InstrumentStatus`
///
/// # References
///
/// <https://databento.com/docs/schemas-and-data-formats>
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.databento")
)]
#[derive(Debug)]
pub struct DatabentoDataLoader {
    /// Publishers keyed by their publisher ID, as read from the publishers file.
    publishers_map: IndexMap<PublisherId, DatabentoPublisher>,
    /// Dataset for each venue: the first dataset seen per venue in the publishers
    /// file, plus the CME Globex entries and any explicitly set mappings.
    venue_dataset_map: IndexMap<Venue, Dataset>,
    /// Venue for each publisher ID, as read from the publishers file.
    publisher_venue_map: IndexMap<PublisherId, Venue>,
    /// Venue for each symbol, recorded while reading definition records with
    /// `use_exchange_as_venue` enabled.
    symbol_venue_map: AHashMap<Symbol, Venue>,
}
82
83impl DatabentoDataLoader {
84    /// Creates a new [`DatabentoDataLoader`] instance.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if locating or loading publishers data fails.
89    pub fn new(publishers_filepath: Option<PathBuf>) -> anyhow::Result<Self> {
90        let mut loader = Self {
91            publishers_map: IndexMap::new(),
92            venue_dataset_map: IndexMap::new(),
93            publisher_venue_map: IndexMap::new(),
94            symbol_venue_map: AHashMap::new(),
95        };
96
97        // Load publishers
98        let publishers_filepath = if let Some(p) = publishers_filepath {
99            p
100        } else {
101            // Use built-in publishers path
102            let mut exe_path = env::current_exe()?;
103            exe_path.pop();
104            exe_path.push("publishers.json");
105            exe_path
106        };
107
108        loader
109            .load_publishers(publishers_filepath)
110            .context("Error loading publishers.json")?;
111
112        Ok(loader)
113    }
114
115    /// Load the publishers data from the file at the given `filepath`.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if the file cannot be read or parsed as JSON.
120    pub fn load_publishers(&mut self, filepath: PathBuf) -> anyhow::Result<()> {
121        let file_content = fs::read_to_string(filepath)?;
122        let publishers: Vec<DatabentoPublisher> = serde_json::from_str(&file_content)?;
123
124        self.publishers_map = publishers
125            .clone()
126            .into_iter()
127            .map(|p| (p.publisher_id, p))
128            .collect();
129
130        let mut venue_dataset_map = IndexMap::new();
131
132        // Only insert a dataset if the venue key is not already in the map
133        for publisher in &publishers {
134            let venue = Venue::from(publisher.venue.as_str());
135            let dataset = Dataset::from(publisher.dataset.as_str());
136            venue_dataset_map.entry(venue).or_insert(dataset);
137        }
138
139        self.venue_dataset_map = venue_dataset_map;
140
141        // Insert CME Globex exchanges
142        let glbx = Dataset::from("GLBX.MDP3");
143        self.venue_dataset_map.insert(Venue::CBCM(), glbx);
144        self.venue_dataset_map.insert(Venue::GLBX(), glbx);
145        self.venue_dataset_map.insert(Venue::NYUM(), glbx);
146        self.venue_dataset_map.insert(Venue::XCBT(), glbx);
147        self.venue_dataset_map.insert(Venue::XCEC(), glbx);
148        self.venue_dataset_map.insert(Venue::XCME(), glbx);
149        self.venue_dataset_map.insert(Venue::XFXS(), glbx);
150        self.venue_dataset_map.insert(Venue::XNYM(), glbx);
151
152        self.publisher_venue_map = publishers
153            .into_iter()
154            .map(|p| (p.publisher_id, Venue::from(p.venue.as_str())))
155            .collect();
156
157        Ok(())
158    }
159
    /// Returns the internal Databento publishers currently held by the loader,
    /// keyed by publisher ID.
    #[must_use]
    pub const fn get_publishers(&self) -> &IndexMap<u16, DatabentoPublisher> {
        &self.publishers_map
    }
165
166    /// Sets the `venue` to map to the given `dataset`.
167    pub fn set_dataset_for_venue(&mut self, dataset: Dataset, venue: Venue) {
168        _ = self.venue_dataset_map.insert(venue, dataset);
169    }
170
    /// Returns the dataset which matches the given `venue` (if found).
    ///
    /// Returns `None` when no dataset is mapped to the venue.
    #[must_use]
    pub fn get_dataset_for_venue(&self, venue: &Venue) -> Option<&Dataset> {
        self.venue_dataset_map.get(venue)
    }
176
    /// Returns the venue which matches the given `publisher_id` (if found).
    ///
    /// Returns `None` when no venue is mapped to the publisher ID.
    #[must_use]
    pub fn get_venue_for_publisher(&self, publisher_id: PublisherId) -> Option<&Venue> {
        self.publisher_venue_map.get(&publisher_id)
    }
182
183    /// Returns the schema for the given `filepath`.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the file cannot be decoded or metadata retrieval fails.
188    pub fn schema_from_file(&self, filepath: &Path) -> anyhow::Result<Option<String>> {
189        let decoder = Decoder::from_zstd_file(filepath)?;
190        let metadata = decoder.metadata();
191        Ok(metadata.schema.map(|schema| schema.to_string()))
192    }
193
194    /// # Errors
195    ///
196    /// Returns an error if decoding the definition records fails.
197    pub fn read_definition_records(
198        &mut self,
199        filepath: &Path,
200        use_exchange_as_venue: bool,
201    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<InstrumentAny>> + '_> {
202        let decoder = Decoder::from_zstd_file(filepath)?;
203        let mut dbn_stream = decoder.decode_stream::<InstrumentDefMsg>();
204
205        Ok(std::iter::from_fn(move || {
206            let result: anyhow::Result<Option<InstrumentAny>> = (|| {
207                dbn_stream
208                    .advance()
209                    .map_err(|e| anyhow::anyhow!("Stream advance error: {e}"))?;
210
211                if let Some(rec) = dbn_stream.get() {
212                    let record = dbn::RecordRef::from(rec);
213                    let msg = record
214                        .get::<InstrumentDefMsg>()
215                        .ok_or_else(|| anyhow::anyhow!("Failed to decode InstrumentDefMsg"))?;
216
217                    // Symbol and venue resolution
218                    let raw_symbol = rec
219                        .raw_symbol()
220                        .map_err(|e| anyhow::anyhow!("Error decoding `raw_symbol`: {e}"))?;
221                    let symbol = Symbol::from(raw_symbol);
222
223                    let publisher = rec
224                        .hd
225                        .publisher()
226                        .map_err(|e| anyhow::anyhow!("Invalid `publisher` for record: {e}"))?;
227                    let venue = match publisher {
228                        Publisher::GlbxMdp3Glbx if use_exchange_as_venue => {
229                            let exchange = rec.exchange().map_err(|e| {
230                                anyhow::anyhow!("Missing `exchange` for record: {e}")
231                            })?;
232                            let venue = Venue::from_code(exchange).map_err(|e| {
233                                anyhow::anyhow!("Venue not found for exchange {exchange}: {e}")
234                            })?;
235                            self.symbol_venue_map.insert(symbol, venue);
236                            venue
237                        }
238                        _ => *self
239                            .publisher_venue_map
240                            .get(&msg.hd.publisher_id)
241                            .ok_or_else(|| {
242                                anyhow::anyhow!(
243                                    "Venue not found for publisher_id {}",
244                                    msg.hd.publisher_id
245                                )
246                            })?,
247                    };
248                    let instrument_id = InstrumentId::new(symbol, venue);
249                    let ts_init = msg.ts_recv.into();
250
251                    let data = decode_instrument_def_msg(rec, instrument_id, Some(ts_init))?;
252                    Ok(Some(data))
253                } else {
254                    // No more records
255                    Ok(None)
256                }
257            })();
258
259            match result {
260                Ok(Some(item)) => Some(Ok(item)),
261                Ok(None) => None,
262                Err(e) => Some(Err(e)),
263            }
264        }))
265    }
266
267    /// # Errors
268    ///
269    /// Returns an error if reading records fails.
270    pub fn read_records<T>(
271        &self,
272        filepath: &Path,
273        instrument_id: Option<InstrumentId>,
274        price_precision: Option<u8>,
275        include_trades: bool,
276        bars_timestamp_on_close: Option<bool>,
277    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<(Option<Data>, Option<Data>)>> + '_>
278    where
279        T: dbn::Record + dbn::HasRType + 'static,
280    {
281        let decoder = Decoder::from_zstd_file(filepath)?;
282        let metadata = decoder.metadata().clone();
283        let mut metadata_cache = MetadataCache::new(metadata);
284        let mut dbn_stream = decoder.decode_stream::<T>();
285
286        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
287
288        Ok(std::iter::from_fn(move || {
289            let result: anyhow::Result<Option<(Option<Data>, Option<Data>)>> = (|| {
290                dbn_stream
291                    .advance()
292                    .map_err(|e| anyhow::anyhow!("Stream advance error: {e}"))?;
293                if let Some(rec) = dbn_stream.get() {
294                    let record = dbn::RecordRef::from(rec);
295                    let instrument_id = if let Some(id) = &instrument_id {
296                        *id
297                    } else {
298                        decode_nautilus_instrument_id(
299                            &record,
300                            &mut metadata_cache,
301                            &self.publisher_venue_map,
302                            &self.symbol_venue_map,
303                        )
304                        .context("Failed to decode instrument id")?
305                    };
306                    let (item1, item2) = decode_record(
307                        &record,
308                        instrument_id,
309                        price_precision,
310                        None,
311                        include_trades,
312                        bars_timestamp_on_close.unwrap_or(true),
313                    )?;
314                    Ok(Some((item1, item2)))
315                } else {
316                    Ok(None)
317                }
318            })();
319            match result {
320                Ok(Some(v)) => Some(Ok(v)),
321                Ok(None) => None,
322                Err(e) => Some(Err(e)),
323            }
324        }))
325    }
326
327    /// # Errors
328    ///
329    /// Returns an error if loading instruments fails.
330    pub fn load_instruments(
331        &mut self,
332        filepath: &Path,
333        use_exchange_as_venue: bool,
334    ) -> anyhow::Result<Vec<InstrumentAny>> {
335        self.read_definition_records(filepath, use_exchange_as_venue)?
336            .collect::<Result<Vec<_>, _>>()
337    }
338
339    // Cannot include trades
340    /// # Errors
341    ///
342    /// Returns an error if loading order book deltas fails.
343    pub fn load_order_book_deltas(
344        &self,
345        filepath: &Path,
346        instrument_id: Option<InstrumentId>,
347        price_precision: Option<u8>,
348    ) -> anyhow::Result<Vec<OrderBookDelta>> {
349        self.read_records::<dbn::MboMsg>(filepath, instrument_id, price_precision, false, None)?
350            .filter_map(|result| match result {
351                Ok((Some(item1), _)) => {
352                    if let Data::Delta(delta) = item1 {
353                        Some(Ok(delta))
354                    } else {
355                        None
356                    }
357                }
358                Ok((None, _)) => None,
359                Err(e) => Some(Err(e)),
360            })
361            .collect()
362    }
363
364    /// # Errors
365    ///
366    /// Returns an error if loading order book depth10 fails.
367    pub fn load_order_book_depth10(
368        &self,
369        filepath: &Path,
370        instrument_id: Option<InstrumentId>,
371        price_precision: Option<u8>,
372    ) -> anyhow::Result<Vec<OrderBookDepth10>> {
373        self.read_records::<dbn::Mbp10Msg>(filepath, instrument_id, price_precision, false, None)?
374            .filter_map(|result| match result {
375                Ok((Some(item1), _)) => {
376                    if let Data::Depth10(depth) = item1 {
377                        Some(Ok(*depth))
378                    } else {
379                        None
380                    }
381                }
382                Ok((None, _)) => None,
383                Err(e) => Some(Err(e)),
384            })
385            .collect()
386    }
387
388    /// # Errors
389    ///
390    /// Returns an error if loading quotes fails.
391    pub fn load_quotes(
392        &self,
393        filepath: &Path,
394        instrument_id: Option<InstrumentId>,
395        price_precision: Option<u8>,
396    ) -> anyhow::Result<Vec<QuoteTick>> {
397        self.read_records::<dbn::Mbp1Msg>(filepath, instrument_id, price_precision, false, None)?
398            .filter_map(|result| match result {
399                Ok((Some(item1), _)) => {
400                    if let Data::Quote(quote) = item1 {
401                        Some(Ok(quote))
402                    } else {
403                        None
404                    }
405                }
406                Ok((None, _)) => None,
407                Err(e) => Some(Err(e)),
408            })
409            .collect()
410    }
411
412    /// # Errors
413    ///
414    /// Returns an error if loading BBO quotes fails.
415    pub fn load_bbo_quotes(
416        &self,
417        filepath: &Path,
418        instrument_id: Option<InstrumentId>,
419        price_precision: Option<u8>,
420    ) -> anyhow::Result<Vec<QuoteTick>> {
421        self.read_records::<dbn::BboMsg>(filepath, instrument_id, price_precision, false, None)?
422            .filter_map(|result| match result {
423                Ok((Some(item1), _)) => {
424                    if let Data::Quote(quote) = item1 {
425                        Some(Ok(quote))
426                    } else {
427                        None
428                    }
429                }
430                Ok((None, _)) => None,
431                Err(e) => Some(Err(e)),
432            })
433            .collect()
434    }
435
436    /// # Errors
437    ///
438    /// Returns an error if loading TBBO trades fails.
439    pub fn load_tbbo_trades(
440        &self,
441        filepath: &Path,
442        instrument_id: Option<InstrumentId>,
443        price_precision: Option<u8>,
444    ) -> anyhow::Result<Vec<TradeTick>> {
445        self.read_records::<dbn::TbboMsg>(filepath, instrument_id, price_precision, false, None)?
446            .filter_map(|result| match result {
447                Ok((_, maybe_item2)) => {
448                    if let Some(Data::Trade(trade)) = maybe_item2 {
449                        Some(Ok(trade))
450                    } else {
451                        None
452                    }
453                }
454                Err(e) => Some(Err(e)),
455            })
456            .collect()
457    }
458
459    /// # Errors
460    ///
461    /// Returns an error if loading TCBBO trades fails.
462    pub fn load_tcbbo_trades(
463        &self,
464        filepath: &Path,
465        instrument_id: Option<InstrumentId>,
466        price_precision: Option<u8>,
467    ) -> anyhow::Result<Vec<TradeTick>> {
468        self.read_records::<dbn::CbboMsg>(filepath, instrument_id, price_precision, false, None)?
469            .filter_map(|result| match result {
470                Ok((_, maybe_item2)) => {
471                    if let Some(Data::Trade(trade)) = maybe_item2 {
472                        Some(Ok(trade))
473                    } else {
474                        None
475                    }
476                }
477                Err(e) => Some(Err(e)),
478            })
479            .collect()
480    }
481
482    /// # Errors
483    ///
484    /// Returns an error if loading trades fails.
485    pub fn load_trades(
486        &self,
487        filepath: &Path,
488        instrument_id: Option<InstrumentId>,
489        price_precision: Option<u8>,
490    ) -> anyhow::Result<Vec<TradeTick>> {
491        self.read_records::<dbn::TradeMsg>(filepath, instrument_id, price_precision, false, None)?
492            .filter_map(|result| match result {
493                Ok((Some(item1), _)) => {
494                    if let Data::Trade(trade) = item1 {
495                        Some(Ok(trade))
496                    } else {
497                        None
498                    }
499                }
500                Ok((None, _)) => None,
501                Err(e) => Some(Err(e)),
502            })
503            .collect()
504    }
505
506    /// # Errors
507    ///
508    /// Returns an error if loading bars fails.
509    pub fn load_bars(
510        &self,
511        filepath: &Path,
512        instrument_id: Option<InstrumentId>,
513        price_precision: Option<u8>,
514        timestamp_on_close: Option<bool>,
515    ) -> anyhow::Result<Vec<Bar>> {
516        self.read_records::<dbn::OhlcvMsg>(
517            filepath,
518            instrument_id,
519            price_precision,
520            false,
521            timestamp_on_close,
522        )?
523        .filter_map(|result| match result {
524            Ok((Some(item1), _)) => {
525                if let Data::Bar(bar) = item1 {
526                    Some(Ok(bar))
527                } else {
528                    None
529                }
530            }
531            Ok((None, _)) => None,
532            Err(e) => Some(Err(e)),
533        })
534        .collect()
535    }
536
537    /// # Errors
538    ///
539    /// Returns an error if loading status records fails.
540    pub fn load_status_records<T>(
541        &self,
542        filepath: &Path,
543        instrument_id: Option<InstrumentId>,
544    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<InstrumentStatus>> + '_>
545    where
546        T: dbn::Record + dbn::HasRType + 'static,
547    {
548        let decoder = Decoder::from_zstd_file(filepath)?;
549        let metadata = decoder.metadata().clone();
550        let mut metadata_cache = MetadataCache::new(metadata);
551        let mut dbn_stream = decoder.decode_stream::<T>();
552
553        Ok(std::iter::from_fn(move || {
554            if let Err(e) = dbn_stream.advance() {
555                return Some(Err(e.into()));
556            }
557            match dbn_stream.get() {
558                Some(rec) => {
559                    let record = dbn::RecordRef::from(rec);
560                    let instrument_id = match &instrument_id {
561                        Some(id) => *id, // Copy
562                        None => match decode_nautilus_instrument_id(
563                            &record,
564                            &mut metadata_cache,
565                            &self.publisher_venue_map,
566                            &self.symbol_venue_map,
567                        ) {
568                            Ok(id) => id,
569                            Err(e) => return Some(Err(e)),
570                        },
571                    };
572
573                    let msg = match record.get::<dbn::StatusMsg>() {
574                        Some(m) => m,
575                        None => return Some(Err(anyhow::anyhow!("Invalid `StatusMsg`"))),
576                    };
577                    let ts_init = msg.ts_recv.into();
578
579                    match decode_status_msg(msg, instrument_id, Some(ts_init)) {
580                        Ok(data) => Some(Ok(data)),
581                        Err(e) => Some(Err(e)),
582                    }
583                }
584                None => None,
585            }
586        }))
587    }
588
589    /// # Errors
590    ///
591    /// Returns an error if reading imbalance records fails.
592    pub fn read_imbalance_records<T>(
593        &self,
594        filepath: &Path,
595        instrument_id: Option<InstrumentId>,
596        price_precision: Option<u8>,
597    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<DatabentoImbalance>> + '_>
598    where
599        T: dbn::Record + dbn::HasRType + 'static,
600    {
601        let decoder = Decoder::from_zstd_file(filepath)?;
602        let metadata = decoder.metadata().clone();
603        let mut metadata_cache = MetadataCache::new(metadata);
604        let mut dbn_stream = decoder.decode_stream::<T>();
605
606        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
607
608        Ok(std::iter::from_fn(move || {
609            if let Err(e) = dbn_stream.advance() {
610                return Some(Err(e.into()));
611            }
612            match dbn_stream.get() {
613                Some(rec) => {
614                    let record = dbn::RecordRef::from(rec);
615                    let instrument_id = match &instrument_id {
616                        Some(id) => *id, // Copy
617                        None => match decode_nautilus_instrument_id(
618                            &record,
619                            &mut metadata_cache,
620                            &self.publisher_venue_map,
621                            &self.symbol_venue_map,
622                        ) {
623                            Ok(id) => id,
624                            Err(e) => return Some(Err(e)),
625                        },
626                    };
627
628                    let msg = match record.get::<dbn::ImbalanceMsg>() {
629                        Some(m) => m,
630                        None => return Some(Err(anyhow::anyhow!("Invalid `ImbalanceMsg`"))),
631                    };
632                    let ts_init = msg.ts_recv.into();
633
634                    match decode_imbalance_msg(msg, instrument_id, price_precision, Some(ts_init)) {
635                        Ok(data) => Some(Ok(data)),
636                        Err(e) => Some(Err(e)),
637                    }
638                }
639                None => None,
640            }
641        }))
642    }
643
644    /// # Errors
645    ///
646    /// Returns an error if reading statistics records fails.
647    pub fn read_statistics_records<T>(
648        &self,
649        filepath: &Path,
650        instrument_id: Option<InstrumentId>,
651        price_precision: Option<u8>,
652    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<DatabentoStatistics>> + '_>
653    where
654        T: dbn::Record + dbn::HasRType + 'static,
655    {
656        let decoder = Decoder::from_zstd_file(filepath)?;
657        let metadata = decoder.metadata().clone();
658        let mut metadata_cache = MetadataCache::new(metadata);
659        let mut dbn_stream = decoder.decode_stream::<T>();
660
661        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
662
663        Ok(std::iter::from_fn(move || {
664            if let Err(e) = dbn_stream.advance() {
665                return Some(Err(e.into()));
666            }
667            match dbn_stream.get() {
668                Some(rec) => {
669                    let record = dbn::RecordRef::from(rec);
670                    let instrument_id = match &instrument_id {
671                        Some(id) => *id, // Copy
672                        None => match decode_nautilus_instrument_id(
673                            &record,
674                            &mut metadata_cache,
675                            &self.publisher_venue_map,
676                            &self.symbol_venue_map,
677                        ) {
678                            Ok(id) => id,
679                            Err(e) => return Some(Err(e)),
680                        },
681                    };
682                    let msg = match record.get::<dbn::StatMsg>() {
683                        Some(m) => m,
684                        None => return Some(Err(anyhow::anyhow!("Invalid `StatMsg`"))),
685                    };
686                    let ts_init = msg.ts_recv.into();
687
688                    match decode_statistics_msg(msg, instrument_id, price_precision, Some(ts_init))
689                    {
690                        Ok(data) => Some(Ok(data)),
691                        Err(e) => Some(Err(e)),
692                    }
693                }
694                None => None,
695            }
696        }))
697    }
698}
699
700////////////////////////////////////////////////////////////////////////////////
701// Tests
702////////////////////////////////////////////////////////////////////////////////
703#[cfg(test)]
704mod tests {
705    use std::path::{Path, PathBuf};
706
707    use rstest::{fixture, rstest};
708    use ustr::Ustr;
709
710    use super::*;
711
712    fn test_data_path() -> PathBuf {
713        Path::new(env!("CARGO_MANIFEST_DIR")).join("test_data")
714    }
715
716    #[fixture]
717    fn loader() -> DatabentoDataLoader {
718        let publishers_filepath = Path::new(env!("CARGO_MANIFEST_DIR")).join("publishers.json");
719        DatabentoDataLoader::new(Some(publishers_filepath)).unwrap()
720    }
721
    // TODO: Strengthen the assertions below to verify that the expected records were actually read
723
724    #[rstest]
725    fn test_set_dataset_venue_mapping(mut loader: DatabentoDataLoader) {
726        let dataset = Ustr::from("EQUS.PLUS");
727        let venue = Venue::from("XNAS");
728        loader.set_dataset_for_venue(dataset, venue);
729
730        let result = loader.get_dataset_for_venue(&venue).unwrap();
731        assert_eq!(*result, dataset);
732    }
733
734    #[rstest]
735    #[case(test_data_path().join("test_data.definition.dbn.zst"))]
736    fn test_load_instruments(mut loader: DatabentoDataLoader, #[case] path: PathBuf) {
737        let instruments = loader.load_instruments(&path, false).unwrap();
738
739        assert_eq!(instruments.len(), 2);
740    }
741
742    #[rstest]
743    fn test_load_order_book_deltas(loader: DatabentoDataLoader) {
744        let path = test_data_path().join("test_data.mbo.dbn.zst");
745        let instrument_id = InstrumentId::from("ESM4.GLBX");
746
747        let deltas = loader
748            .load_order_book_deltas(&path, Some(instrument_id), None)
749            .unwrap();
750
751        assert_eq!(deltas.len(), 2);
752    }
753
754    #[rstest]
755    fn test_load_order_book_depth10(loader: DatabentoDataLoader) {
756        let path = test_data_path().join("test_data.mbp-10.dbn.zst");
757        let instrument_id = InstrumentId::from("ESM4.GLBX");
758
759        let depths = loader
760            .load_order_book_depth10(&path, Some(instrument_id), None)
761            .unwrap();
762
763        assert_eq!(depths.len(), 2);
764    }
765
766    #[rstest]
767    fn test_load_quotes(loader: DatabentoDataLoader) {
768        let path = test_data_path().join("test_data.mbp-1.dbn.zst");
769        let instrument_id = InstrumentId::from("ESM4.GLBX");
770
771        let quotes = loader
772            .load_quotes(&path, Some(instrument_id), None)
773            .unwrap();
774
775        assert_eq!(quotes.len(), 2);
776    }
777
778    #[rstest]
779    #[case(test_data_path().join("test_data.bbo-1s.dbn.zst"))]
780    #[case(test_data_path().join("test_data.bbo-1m.dbn.zst"))]
781    fn test_load_bbo_quotes(loader: DatabentoDataLoader, #[case] path: PathBuf) {
782        let instrument_id = InstrumentId::from("ESM4.GLBX");
783
784        let quotes = loader
785            .load_bbo_quotes(&path, Some(instrument_id), None)
786            .unwrap();
787
788        assert_eq!(quotes.len(), 2);
789    }
790
791    #[rstest]
792    fn test_load_tbbo_trades(loader: DatabentoDataLoader) {
793        let path = test_data_path().join("test_data.tbbo.dbn.zst");
794        let instrument_id = InstrumentId::from("ESM4.GLBX");
795
796        let _trades = loader
797            .load_tbbo_trades(&path, Some(instrument_id), None)
798            .unwrap();
799
800        // assert_eq!(trades.len(), 2);  TODO: No records?
801    }
802
803    // TODO: Re-enable this test once proper TCBBO/CBBO test data is available
804    #[rstest]
805    #[ignore]
806    fn test_load_tcbbo_trades(loader: DatabentoDataLoader) {
807        // Since we don't have dedicated TCBBO test data, we'll use CBBO data
808        // In practice, TCBBO would be CBBO messages with trade data
809        let path = test_data_path().join("test_data.cbbo.dbn.zst");
810        let instrument_id = InstrumentId::from("ESM4.GLBX");
811
812        let result = loader.load_tcbbo_trades(&path, Some(instrument_id), None);
813
814        // This should work without error even if no trades are found
815        assert!(result.is_ok());
816        let trades = result.unwrap();
817        // The test data might not have trade messages, so we just verify it loads without error
818        // The actual count doesn't matter for this test
819        let _ = trades.len();
820    }
821
822    #[rstest]
823    fn test_load_trades(loader: DatabentoDataLoader) {
824        let path = test_data_path().join("test_data.trades.dbn.zst");
825        let instrument_id = InstrumentId::from("ESM4.GLBX");
826        let trades = loader
827            .load_trades(&path, Some(instrument_id), None)
828            .unwrap();
829
830        assert_eq!(trades.len(), 2);
831    }
832
833    #[rstest]
834    // #[case(test_data_path().join("test_data.ohlcv-1d.dbn.zst"))]  // TODO: Needs new data
835    #[case(test_data_path().join("test_data.ohlcv-1h.dbn.zst"))]
836    #[case(test_data_path().join("test_data.ohlcv-1m.dbn.zst"))]
837    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
838    fn test_load_bars(loader: DatabentoDataLoader, #[case] path: PathBuf) {
839        let instrument_id = InstrumentId::from("ESM4.GLBX");
840        let bars = loader
841            .load_bars(&path, Some(instrument_id), None, None)
842            .unwrap();
843
844        assert_eq!(bars.len(), 2);
845    }
846
847    #[rstest]
848    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
849    fn test_load_bars_timestamp_on_close_true(loader: DatabentoDataLoader, #[case] path: PathBuf) {
850        let instrument_id = InstrumentId::from("ESM4.GLBX");
851        let bars = loader
852            .load_bars(&path, Some(instrument_id), None, Some(true))
853            .unwrap();
854
855        assert_eq!(bars.len(), 2);
856
857        // When bars_timestamp_on_close is true, both ts_event and ts_init should be close time
858        for bar in &bars {
859            assert_eq!(
860                bar.ts_event, bar.ts_init,
861                "ts_event and ts_init should both be close time when bars_timestamp_on_close=true"
862            );
863        }
864    }
865
866    #[rstest]
867    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
868    fn test_load_bars_timestamp_on_close_false(loader: DatabentoDataLoader, #[case] path: PathBuf) {
869        let instrument_id = InstrumentId::from("ESM4.GLBX");
870        let bars = loader
871            .load_bars(&path, Some(instrument_id), None, Some(false))
872            .unwrap();
873
874        assert_eq!(bars.len(), 2);
875
876        // When bars_timestamp_on_close is false, ts_event is open time and ts_init is close time
877        for bar in &bars {
878            assert_ne!(
879                bar.ts_event, bar.ts_init,
880                "ts_event should be open time and ts_init should be close time when bars_timestamp_on_close=false"
881            );
882            // For 1-second bars, ts_init (close) should be 1 second after ts_event (open)
883            assert_eq!(bar.ts_init.as_u64(), bar.ts_event.as_u64() + 1_000_000_000);
884        }
885    }
886
887    #[rstest]
888    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"), 0)]
889    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"), 1)]
890    fn test_load_bars_timestamp_comparison(
891        loader: DatabentoDataLoader,
892        #[case] path: PathBuf,
893        #[case] bar_index: usize,
894    ) {
895        let instrument_id = InstrumentId::from("ESM4.GLBX");
896
897        let bars_close = loader
898            .load_bars(&path, Some(instrument_id), None, Some(true))
899            .unwrap();
900
901        let bars_open = loader
902            .load_bars(&path, Some(instrument_id), None, Some(false))
903            .unwrap();
904
905        assert_eq!(bars_close.len(), bars_open.len());
906        assert_eq!(bars_close.len(), 2);
907
908        let bar_close = &bars_close[bar_index];
909        let bar_open = &bars_open[bar_index];
910
911        // Bars should have the same OHLCV data
912        assert_eq!(bar_close.open, bar_open.open);
913        assert_eq!(bar_close.high, bar_open.high);
914        assert_eq!(bar_close.low, bar_open.low);
915        assert_eq!(bar_close.close, bar_open.close);
916        assert_eq!(bar_close.volume, bar_open.volume);
917
918        // The close-timestamped bar should have later timestamp than open-timestamped bar
919        // For 1-second bars, this should be exactly 1 second difference
920        assert!(
921            bar_close.ts_event > bar_open.ts_event,
922            "Close-timestamped bar should have later timestamp than open-timestamped bar"
923        );
924
925        // The difference should be exactly 1 second (1_000_000_000 nanoseconds) for 1s bars
926        const ONE_SECOND_NS: u64 = 1_000_000_000;
927        assert_eq!(
928            bar_close.ts_event.as_u64() - bar_open.ts_event.as_u64(),
929            ONE_SECOND_NS,
930            "Timestamp difference should be exactly 1 second for 1s bars"
931        );
932    }
933}