nautilus_databento/
loader.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    env, fs,
18    path::{Path, PathBuf},
19};
20
21use ahash::AHashMap;
22use anyhow::Context;
23use databento::dbn::{self, InstrumentDefMsg};
24use dbn::{
25    Publisher,
26    decode::{DbnMetadata, DecodeStream, dbn::Decoder},
27};
28use fallible_streaming_iterator::FallibleStreamingIterator;
29use indexmap::IndexMap;
30use nautilus_model::{
31    data::{Bar, Data, InstrumentStatus, OrderBookDelta, OrderBookDepth10, QuoteTick, TradeTick},
32    identifiers::{InstrumentId, Symbol, Venue},
33    instruments::InstrumentAny,
34    types::Currency,
35};
36
37use super::{
38    decode::{decode_imbalance_msg, decode_record, decode_statistics_msg, decode_status_msg},
39    symbology::decode_nautilus_instrument_id,
40    types::{DatabentoImbalance, DatabentoPublisher, DatabentoStatistics, Dataset, PublisherId},
41};
42use crate::{decode::decode_instrument_def_msg, symbology::MetadataCache};
43
/// A Nautilus data loader for Databento Binary Encoding (DBN) format data.
///
/// The loader keeps several lookup tables (publishers, venue to dataset,
/// publisher to venue, symbol to venue) which are used when resolving the
/// instrument ID of decoded records.
///
/// # Supported schemas:
///  - `MBO` -> `OrderBookDelta`
///  - `MBP_1` -> `(QuoteTick, Option<TradeTick>)`
///  - `MBP_10` -> `OrderBookDepth10`
///  - `BBO_1S` -> `QuoteTick`
///  - `BBO_1M` -> `QuoteTick`
///  - `TBBO` -> `(QuoteTick, TradeTick)`
///  - `TRADES` -> `TradeTick`
///  - `OHLCV_1S` -> `Bar`
///  - `OHLCV_1M` -> `Bar`
///  - `OHLCV_1H` -> `Bar`
///  - `OHLCV_1D` -> `Bar`
///  - `DEFINITION` -> `Instrument`
///  - `IMBALANCE` -> `DatabentoImbalance`
///  - `STATISTICS` -> `DatabentoStatistics`
///  - `STATUS` -> `InstrumentStatus`
///
/// # References
///
/// <https://databento.com/docs/schemas-and-data-formats>
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.databento")
)]
#[derive(Debug)]
pub struct DatabentoDataLoader {
    /// Publisher definitions keyed by publisher ID, as read by `load_publishers`.
    publishers_map: IndexMap<PublisherId, DatabentoPublisher>,
    /// Dataset associated with each venue; filled by `load_publishers` and
    /// adjustable through `set_dataset_for_venue`.
    venue_dataset_map: IndexMap<Venue, Dataset>,
    /// Venue associated with each publisher ID, as read by `load_publishers`.
    publisher_venue_map: IndexMap<PublisherId, Venue>,
    /// Venue recorded per symbol by `read_definition_records` when the exchange
    /// is used as the venue.
    symbol_venue_map: AHashMap<Symbol, Venue>,
}
77
78impl DatabentoDataLoader {
79    /// Creates a new [`DatabentoDataLoader`] instance.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if locating or loading publishers data fails.
84    pub fn new(publishers_filepath: Option<PathBuf>) -> anyhow::Result<Self> {
85        let mut loader = Self {
86            publishers_map: IndexMap::new(),
87            venue_dataset_map: IndexMap::new(),
88            publisher_venue_map: IndexMap::new(),
89            symbol_venue_map: AHashMap::new(),
90        };
91
92        // Load publishers
93        let publishers_filepath = if let Some(p) = publishers_filepath {
94            p
95        } else {
96            // Use built-in publishers path
97            let mut exe_path = env::current_exe()?;
98            exe_path.pop();
99            exe_path.push("publishers.json");
100            exe_path
101        };
102
103        loader
104            .load_publishers(publishers_filepath)
105            .context("Error loading publishers.json")?;
106
107        Ok(loader)
108    }
109
110    /// Load the publishers data from the file at the given `filepath`.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the file cannot be read or parsed as JSON.
115    pub fn load_publishers(&mut self, filepath: PathBuf) -> anyhow::Result<()> {
116        let file_content = fs::read_to_string(filepath)?;
117        let publishers: Vec<DatabentoPublisher> = serde_json::from_str(&file_content)?;
118
119        self.publishers_map = publishers
120            .clone()
121            .into_iter()
122            .map(|p| (p.publisher_id, p))
123            .collect();
124
125        let mut venue_dataset_map = IndexMap::new();
126
127        // Only insert a dataset if the venue key is not already in the map
128        for publisher in &publishers {
129            let venue = Venue::from(publisher.venue.as_str());
130            let dataset = Dataset::from(publisher.dataset.as_str());
131            venue_dataset_map.entry(venue).or_insert(dataset);
132        }
133
134        self.venue_dataset_map = venue_dataset_map;
135
136        // Insert CME Globex exchanges
137        let glbx = Dataset::from("GLBX.MDP3");
138        self.venue_dataset_map.insert(Venue::CBCM(), glbx);
139        self.venue_dataset_map.insert(Venue::GLBX(), glbx);
140        self.venue_dataset_map.insert(Venue::NYUM(), glbx);
141        self.venue_dataset_map.insert(Venue::XCBT(), glbx);
142        self.venue_dataset_map.insert(Venue::XCEC(), glbx);
143        self.venue_dataset_map.insert(Venue::XCME(), glbx);
144        self.venue_dataset_map.insert(Venue::XFXS(), glbx);
145        self.venue_dataset_map.insert(Venue::XNYM(), glbx);
146
147        self.publisher_venue_map = publishers
148            .into_iter()
149            .map(|p| (p.publisher_id, Venue::from(p.venue.as_str())))
150            .collect();
151
152        Ok(())
153    }
154
155    /// Returns the internal Databento publishers currently held by the loader.
156    #[must_use]
157    pub const fn get_publishers(&self) -> &IndexMap<u16, DatabentoPublisher> {
158        &self.publishers_map
159    }
160
161    /// Sets the `venue` to map to the given `dataset`.
162    pub fn set_dataset_for_venue(&mut self, dataset: Dataset, venue: Venue) {
163        _ = self.venue_dataset_map.insert(venue, dataset);
164    }
165
166    /// Returns the dataset which matches the given `venue` (if found).
167    #[must_use]
168    pub fn get_dataset_for_venue(&self, venue: &Venue) -> Option<&Dataset> {
169        self.venue_dataset_map.get(venue)
170    }
171
172    /// Returns the venue which matches the given `publisher_id` (if found).
173    #[must_use]
174    pub fn get_venue_for_publisher(&self, publisher_id: PublisherId) -> Option<&Venue> {
175        self.publisher_venue_map.get(&publisher_id)
176    }
177
178    /// Returns the schema for the given `filepath`.
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if the file cannot be decoded or metadata retrieval fails.
183    pub fn schema_from_file(&self, filepath: &Path) -> anyhow::Result<Option<String>> {
184        let decoder = Decoder::from_zstd_file(filepath)?;
185        let metadata = decoder.metadata();
186        Ok(metadata.schema.map(|schema| schema.to_string()))
187    }
188
189    /// # Errors
190    ///
191    /// Returns an error if decoding the definition records fails.
192    pub fn read_definition_records(
193        &mut self,
194        filepath: &Path,
195        use_exchange_as_venue: bool,
196    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<InstrumentAny>> + '_> {
197        let decoder = Decoder::from_zstd_file(filepath)?;
198        let mut dbn_stream = decoder.decode_stream::<InstrumentDefMsg>();
199
200        Ok(std::iter::from_fn(move || {
201            let result: anyhow::Result<Option<InstrumentAny>> = (|| {
202                dbn_stream
203                    .advance()
204                    .map_err(|e| anyhow::anyhow!("Stream advance error: {e}"))?;
205
206                if let Some(rec) = dbn_stream.get() {
207                    let record = dbn::RecordRef::from(rec);
208                    let msg = record
209                        .get::<InstrumentDefMsg>()
210                        .ok_or_else(|| anyhow::anyhow!("Failed to decode InstrumentDefMsg"))?;
211
212                    // Symbol and venue resolution
213                    let raw_symbol = rec
214                        .raw_symbol()
215                        .map_err(|e| anyhow::anyhow!("Error decoding `raw_symbol`: {e}"))?;
216                    let symbol = Symbol::from(raw_symbol);
217
218                    let publisher = rec
219                        .hd
220                        .publisher()
221                        .map_err(|e| anyhow::anyhow!("Invalid `publisher` for record: {e}"))?;
222                    let venue = match publisher {
223                        Publisher::GlbxMdp3Glbx if use_exchange_as_venue => {
224                            let exchange = rec.exchange().map_err(|e| {
225                                anyhow::anyhow!("Missing `exchange` for record: {e}")
226                            })?;
227                            let venue = Venue::from_code(exchange).map_err(|e| {
228                                anyhow::anyhow!("Venue not found for exchange {exchange}: {e}")
229                            })?;
230                            self.symbol_venue_map.insert(symbol, venue);
231                            venue
232                        }
233                        _ => *self
234                            .publisher_venue_map
235                            .get(&msg.hd.publisher_id)
236                            .ok_or_else(|| {
237                                anyhow::anyhow!(
238                                    "Venue not found for publisher_id {}",
239                                    msg.hd.publisher_id
240                                )
241                            })?,
242                    };
243                    let instrument_id = InstrumentId::new(symbol, venue);
244                    let ts_init = msg.ts_recv.into();
245
246                    let data = decode_instrument_def_msg(rec, instrument_id, Some(ts_init))?;
247                    Ok(Some(data))
248                } else {
249                    // No more records
250                    Ok(None)
251                }
252            })();
253
254            match result {
255                Ok(Some(item)) => Some(Ok(item)),
256                Ok(None) => None,
257                Err(e) => Some(Err(e)),
258            }
259        }))
260    }
261
262    /// # Errors
263    ///
264    /// Returns an error if reading records fails.
265    pub fn read_records<T>(
266        &self,
267        filepath: &Path,
268        instrument_id: Option<InstrumentId>,
269        price_precision: Option<u8>,
270        include_trades: bool,
271        bars_timestamp_on_close: Option<bool>,
272    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<(Option<Data>, Option<Data>)>> + '_>
273    where
274        T: dbn::Record + dbn::HasRType + 'static,
275    {
276        let decoder = Decoder::from_zstd_file(filepath)?;
277        let metadata = decoder.metadata().clone();
278        let mut metadata_cache = MetadataCache::new(metadata);
279        let mut dbn_stream = decoder.decode_stream::<T>();
280
281        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
282
283        Ok(std::iter::from_fn(move || {
284            let result: anyhow::Result<Option<(Option<Data>, Option<Data>)>> = (|| {
285                dbn_stream
286                    .advance()
287                    .map_err(|e| anyhow::anyhow!("Stream advance error: {e}"))?;
288                if let Some(rec) = dbn_stream.get() {
289                    let record = dbn::RecordRef::from(rec);
290                    let instrument_id = if let Some(id) = &instrument_id {
291                        *id
292                    } else {
293                        decode_nautilus_instrument_id(
294                            &record,
295                            &mut metadata_cache,
296                            &self.publisher_venue_map,
297                            &self.symbol_venue_map,
298                        )
299                        .context("Failed to decode instrument id")?
300                    };
301                    let (item1, item2) = decode_record(
302                        &record,
303                        instrument_id,
304                        price_precision,
305                        None,
306                        include_trades,
307                        bars_timestamp_on_close.unwrap_or(true),
308                    )?;
309                    Ok(Some((item1, item2)))
310                } else {
311                    Ok(None)
312                }
313            })();
314            match result {
315                Ok(Some(v)) => Some(Ok(v)),
316                Ok(None) => None,
317                Err(e) => Some(Err(e)),
318            }
319        }))
320    }
321
322    /// # Errors
323    ///
324    /// Returns an error if loading instruments fails.
325    pub fn load_instruments(
326        &mut self,
327        filepath: &Path,
328        use_exchange_as_venue: bool,
329    ) -> anyhow::Result<Vec<InstrumentAny>> {
330        self.read_definition_records(filepath, use_exchange_as_venue)?
331            .collect::<Result<Vec<_>, _>>()
332    }
333
334    // Cannot include trades
335    /// # Errors
336    ///
337    /// Returns an error if loading order book deltas fails.
338    pub fn load_order_book_deltas(
339        &self,
340        filepath: &Path,
341        instrument_id: Option<InstrumentId>,
342        price_precision: Option<u8>,
343    ) -> anyhow::Result<Vec<OrderBookDelta>> {
344        self.read_records::<dbn::MboMsg>(filepath, instrument_id, price_precision, false, None)?
345            .filter_map(|result| match result {
346                Ok((Some(item1), _)) => {
347                    if let Data::Delta(delta) = item1 {
348                        Some(Ok(delta))
349                    } else {
350                        None
351                    }
352                }
353                Ok((None, _)) => None,
354                Err(e) => Some(Err(e)),
355            })
356            .collect()
357    }
358
359    /// # Errors
360    ///
361    /// Returns an error if loading order book depth10 fails.
362    pub fn load_order_book_depth10(
363        &self,
364        filepath: &Path,
365        instrument_id: Option<InstrumentId>,
366        price_precision: Option<u8>,
367    ) -> anyhow::Result<Vec<OrderBookDepth10>> {
368        self.read_records::<dbn::Mbp10Msg>(filepath, instrument_id, price_precision, false, None)?
369            .filter_map(|result| match result {
370                Ok((Some(item1), _)) => {
371                    if let Data::Depth10(depth) = item1 {
372                        Some(Ok(*depth))
373                    } else {
374                        None
375                    }
376                }
377                Ok((None, _)) => None,
378                Err(e) => Some(Err(e)),
379            })
380            .collect()
381    }
382
383    /// # Errors
384    ///
385    /// Returns an error if loading quotes fails.
386    pub fn load_quotes(
387        &self,
388        filepath: &Path,
389        instrument_id: Option<InstrumentId>,
390        price_precision: Option<u8>,
391    ) -> anyhow::Result<Vec<QuoteTick>> {
392        self.read_records::<dbn::Mbp1Msg>(filepath, instrument_id, price_precision, false, None)?
393            .filter_map(|result| match result {
394                Ok((Some(item1), _)) => {
395                    if let Data::Quote(quote) = item1 {
396                        Some(Ok(quote))
397                    } else {
398                        None
399                    }
400                }
401                Ok((None, _)) => None,
402                Err(e) => Some(Err(e)),
403            })
404            .collect()
405    }
406
407    /// # Errors
408    ///
409    /// Returns an error if loading BBO quotes fails.
410    pub fn load_bbo_quotes(
411        &self,
412        filepath: &Path,
413        instrument_id: Option<InstrumentId>,
414        price_precision: Option<u8>,
415    ) -> anyhow::Result<Vec<QuoteTick>> {
416        self.read_records::<dbn::BboMsg>(filepath, instrument_id, price_precision, false, None)?
417            .filter_map(|result| match result {
418                Ok((Some(item1), _)) => {
419                    if let Data::Quote(quote) = item1 {
420                        Some(Ok(quote))
421                    } else {
422                        None
423                    }
424                }
425                Ok((None, _)) => None,
426                Err(e) => Some(Err(e)),
427            })
428            .collect()
429    }
430
431    /// # Errors
432    ///
433    /// Returns an error if loading TBBO trades fails.
434    pub fn load_tbbo_trades(
435        &self,
436        filepath: &Path,
437        instrument_id: Option<InstrumentId>,
438        price_precision: Option<u8>,
439    ) -> anyhow::Result<Vec<TradeTick>> {
440        self.read_records::<dbn::TbboMsg>(filepath, instrument_id, price_precision, false, None)?
441            .filter_map(|result| match result {
442                Ok((_, maybe_item2)) => {
443                    if let Some(Data::Trade(trade)) = maybe_item2 {
444                        Some(Ok(trade))
445                    } else {
446                        None
447                    }
448                }
449                Err(e) => Some(Err(e)),
450            })
451            .collect()
452    }
453
454    /// # Errors
455    ///
456    /// Returns an error if loading trades fails.
457    pub fn load_trades(
458        &self,
459        filepath: &Path,
460        instrument_id: Option<InstrumentId>,
461        price_precision: Option<u8>,
462    ) -> anyhow::Result<Vec<TradeTick>> {
463        self.read_records::<dbn::TradeMsg>(filepath, instrument_id, price_precision, false, None)?
464            .filter_map(|result| match result {
465                Ok((Some(item1), _)) => {
466                    if let Data::Trade(trade) = item1 {
467                        Some(Ok(trade))
468                    } else {
469                        None
470                    }
471                }
472                Ok((None, _)) => None,
473                Err(e) => Some(Err(e)),
474            })
475            .collect()
476    }
477
478    /// # Errors
479    ///
480    /// Returns an error if loading bars fails.
481    pub fn load_bars(
482        &self,
483        filepath: &Path,
484        instrument_id: Option<InstrumentId>,
485        price_precision: Option<u8>,
486        timestamp_on_close: Option<bool>,
487    ) -> anyhow::Result<Vec<Bar>> {
488        self.read_records::<dbn::OhlcvMsg>(
489            filepath,
490            instrument_id,
491            price_precision,
492            false,
493            timestamp_on_close,
494        )?
495        .filter_map(|result| match result {
496            Ok((Some(item1), _)) => {
497                if let Data::Bar(bar) = item1 {
498                    Some(Ok(bar))
499                } else {
500                    None
501                }
502            }
503            Ok((None, _)) => None,
504            Err(e) => Some(Err(e)),
505        })
506        .collect()
507    }
508
509    /// # Errors
510    ///
511    /// Returns an error if loading status records fails.
512    pub fn load_status_records<T>(
513        &self,
514        filepath: &Path,
515        instrument_id: Option<InstrumentId>,
516    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<InstrumentStatus>> + '_>
517    where
518        T: dbn::Record + dbn::HasRType + 'static,
519    {
520        let decoder = Decoder::from_zstd_file(filepath)?;
521        let metadata = decoder.metadata().clone();
522        let mut metadata_cache = MetadataCache::new(metadata);
523        let mut dbn_stream = decoder.decode_stream::<T>();
524
525        Ok(std::iter::from_fn(move || {
526            if let Err(e) = dbn_stream.advance() {
527                return Some(Err(e.into()));
528            }
529            match dbn_stream.get() {
530                Some(rec) => {
531                    let record = dbn::RecordRef::from(rec);
532                    let instrument_id = match &instrument_id {
533                        Some(id) => *id, // Copy
534                        None => match decode_nautilus_instrument_id(
535                            &record,
536                            &mut metadata_cache,
537                            &self.publisher_venue_map,
538                            &self.symbol_venue_map,
539                        ) {
540                            Ok(id) => id,
541                            Err(e) => return Some(Err(e)),
542                        },
543                    };
544
545                    let msg = match record.get::<dbn::StatusMsg>() {
546                        Some(m) => m,
547                        None => return Some(Err(anyhow::anyhow!("Invalid `StatusMsg`"))),
548                    };
549                    let ts_init = msg.ts_recv.into();
550
551                    match decode_status_msg(msg, instrument_id, Some(ts_init)) {
552                        Ok(data) => Some(Ok(data)),
553                        Err(e) => Some(Err(e)),
554                    }
555                }
556                None => None,
557            }
558        }))
559    }
560
561    /// # Errors
562    ///
563    /// Returns an error if reading imbalance records fails.
564    pub fn read_imbalance_records<T>(
565        &self,
566        filepath: &Path,
567        instrument_id: Option<InstrumentId>,
568        price_precision: Option<u8>,
569    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<DatabentoImbalance>> + '_>
570    where
571        T: dbn::Record + dbn::HasRType + 'static,
572    {
573        let decoder = Decoder::from_zstd_file(filepath)?;
574        let metadata = decoder.metadata().clone();
575        let mut metadata_cache = MetadataCache::new(metadata);
576        let mut dbn_stream = decoder.decode_stream::<T>();
577
578        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
579
580        Ok(std::iter::from_fn(move || {
581            if let Err(e) = dbn_stream.advance() {
582                return Some(Err(e.into()));
583            }
584            match dbn_stream.get() {
585                Some(rec) => {
586                    let record = dbn::RecordRef::from(rec);
587                    let instrument_id = match &instrument_id {
588                        Some(id) => *id, // Copy
589                        None => match decode_nautilus_instrument_id(
590                            &record,
591                            &mut metadata_cache,
592                            &self.publisher_venue_map,
593                            &self.symbol_venue_map,
594                        ) {
595                            Ok(id) => id,
596                            Err(e) => return Some(Err(e)),
597                        },
598                    };
599
600                    let msg = match record.get::<dbn::ImbalanceMsg>() {
601                        Some(m) => m,
602                        None => return Some(Err(anyhow::anyhow!("Invalid `ImbalanceMsg`"))),
603                    };
604                    let ts_init = msg.ts_recv.into();
605
606                    match decode_imbalance_msg(msg, instrument_id, price_precision, Some(ts_init)) {
607                        Ok(data) => Some(Ok(data)),
608                        Err(e) => Some(Err(e)),
609                    }
610                }
611                None => None,
612            }
613        }))
614    }
615
616    /// # Errors
617    ///
618    /// Returns an error if reading statistics records fails.
619    pub fn read_statistics_records<T>(
620        &self,
621        filepath: &Path,
622        instrument_id: Option<InstrumentId>,
623        price_precision: Option<u8>,
624    ) -> anyhow::Result<impl Iterator<Item = anyhow::Result<DatabentoStatistics>> + '_>
625    where
626        T: dbn::Record + dbn::HasRType + 'static,
627    {
628        let decoder = Decoder::from_zstd_file(filepath)?;
629        let metadata = decoder.metadata().clone();
630        let mut metadata_cache = MetadataCache::new(metadata);
631        let mut dbn_stream = decoder.decode_stream::<T>();
632
633        let price_precision = price_precision.unwrap_or(Currency::USD().precision);
634
635        Ok(std::iter::from_fn(move || {
636            if let Err(e) = dbn_stream.advance() {
637                return Some(Err(e.into()));
638            }
639            match dbn_stream.get() {
640                Some(rec) => {
641                    let record = dbn::RecordRef::from(rec);
642                    let instrument_id = match &instrument_id {
643                        Some(id) => *id, // Copy
644                        None => match decode_nautilus_instrument_id(
645                            &record,
646                            &mut metadata_cache,
647                            &self.publisher_venue_map,
648                            &self.symbol_venue_map,
649                        ) {
650                            Ok(id) => id,
651                            Err(e) => return Some(Err(e)),
652                        },
653                    };
654                    let msg = match record.get::<dbn::StatMsg>() {
655                        Some(m) => m,
656                        None => return Some(Err(anyhow::anyhow!("Invalid `StatMsg`"))),
657                    };
658                    let ts_init = msg.ts_recv.into();
659
660                    match decode_statistics_msg(msg, instrument_id, price_precision, Some(ts_init))
661                    {
662                        Ok(data) => Some(Ok(data)),
663                        Err(e) => Some(Err(e)),
664                    }
665                }
666                None => None,
667            }
668        }))
669    }
670}
671
672////////////////////////////////////////////////////////////////////////////////
673// Tests
674////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::*;

    /// Number of nanoseconds in one second.
    const ONE_SECOND_NS: u64 = 1_000_000_000;

    /// Directory holding the DBN test data files.
    fn test_data_path() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("test_data")
    }

    /// Instrument ID passed explicitly to the loaders in the tests below.
    fn es_instrument_id() -> InstrumentId {
        InstrumentId::from("ESM4.GLBX")
    }

    /// Loader built from the crate's own `publishers.json`.
    #[fixture]
    fn loader() -> DatabentoDataLoader {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        DatabentoDataLoader::new(Some(manifest_dir.join("publishers.json"))).unwrap()
    }

    // TODO: Improve the below assertions that we've actually read the records we expected

    #[rstest]
    fn test_set_dataset_venue_mapping(mut loader: DatabentoDataLoader) {
        let expected_dataset = Ustr::from("EQUS.PLUS");
        let venue = Venue::from("XNAS");
        loader.set_dataset_for_venue(expected_dataset, venue);

        let actual_dataset = loader.get_dataset_for_venue(&venue).unwrap();
        assert_eq!(*actual_dataset, expected_dataset);
    }

    #[rstest]
    #[case(test_data_path().join("test_data.definition.dbn.zst"))]
    fn test_load_instruments(mut loader: DatabentoDataLoader, #[case] path: PathBuf) {
        let loaded = loader.load_instruments(&path, false).unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    fn test_load_order_book_deltas(loader: DatabentoDataLoader) {
        let path = test_data_path().join("test_data.mbo.dbn.zst");

        let loaded = loader
            .load_order_book_deltas(&path, Some(es_instrument_id()), None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    fn test_load_order_book_depth10(loader: DatabentoDataLoader) {
        let path = test_data_path().join("test_data.mbp-10.dbn.zst");

        let loaded = loader
            .load_order_book_depth10(&path, Some(es_instrument_id()), None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    fn test_load_quotes(loader: DatabentoDataLoader) {
        let path = test_data_path().join("test_data.mbp-1.dbn.zst");

        let loaded = loader
            .load_quotes(&path, Some(es_instrument_id()), None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    #[case(test_data_path().join("test_data.bbo-1s.dbn.zst"))]
    #[case(test_data_path().join("test_data.bbo-1m.dbn.zst"))]
    fn test_load_bbo_quotes(loader: DatabentoDataLoader, #[case] path: PathBuf) {
        let loaded = loader
            .load_bbo_quotes(&path, Some(es_instrument_id()), None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    fn test_load_tbbo_trades(loader: DatabentoDataLoader) {
        let path = test_data_path().join("test_data.tbbo.dbn.zst");

        let _trades = loader
            .load_tbbo_trades(&path, Some(es_instrument_id()), None)
            .unwrap();

        // assert_eq!(trades.len(), 2);  TODO: No records?
    }

    #[rstest]
    fn test_load_trades(loader: DatabentoDataLoader) {
        let path = test_data_path().join("test_data.trades.dbn.zst");

        let loaded = loader
            .load_trades(&path, Some(es_instrument_id()), None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    // #[case(test_data_path().join("test_data.ohlcv-1d.dbn.zst"))]  // TODO: Needs new data
    #[case(test_data_path().join("test_data.ohlcv-1h.dbn.zst"))]
    #[case(test_data_path().join("test_data.ohlcv-1m.dbn.zst"))]
    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
    fn test_load_bars(loader: DatabentoDataLoader, #[case] path: PathBuf) {
        let loaded = loader
            .load_bars(&path, Some(es_instrument_id()), None, None)
            .unwrap();

        assert_eq!(loaded.len(), 2);
    }

    #[rstest]
    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
    fn test_load_bars_timestamp_on_close_true(loader: DatabentoDataLoader, #[case] path: PathBuf) {
        let loaded = loader
            .load_bars(&path, Some(es_instrument_id()), None, Some(true))
            .unwrap();

        assert_eq!(loaded.len(), 2);

        // With close timestamping requested, ts_event and ts_init must agree on every bar
        for bar in &loaded {
            assert_eq!(
                bar.ts_event, bar.ts_init,
                "ts_event and ts_init should be equal when bars_timestamp_on_close=true"
            );
        }
    }

    #[rstest]
    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"))]
    fn test_load_bars_timestamp_on_close_false(loader: DatabentoDataLoader, #[case] path: PathBuf) {
        let loaded = loader
            .load_bars(&path, Some(es_instrument_id()), None, Some(false))
            .unwrap();

        assert_eq!(loaded.len(), 2);

        // With open timestamping requested, ts_event and ts_init must agree on every bar
        for bar in &loaded {
            assert_eq!(
                bar.ts_event, bar.ts_init,
                "ts_event and ts_init should be equal when bars_timestamp_on_close=false"
            );
        }
    }

    #[rstest]
    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"), 0)]
    #[case(test_data_path().join("test_data.ohlcv-1s.dbn.zst"), 1)]
    fn test_load_bars_timestamp_comparison(
        loader: DatabentoDataLoader,
        #[case] path: PathBuf,
        #[case] bar_index: usize,
    ) {
        let instrument_id = es_instrument_id();

        let on_close = loader
            .load_bars(&path, Some(instrument_id), None, Some(true))
            .unwrap();
        let on_open = loader
            .load_bars(&path, Some(instrument_id), None, Some(false))
            .unwrap();

        assert_eq!(on_close.len(), on_open.len());
        assert_eq!(on_close.len(), 2);

        let closed = &on_close[bar_index];
        let opened = &on_open[bar_index];

        // The timestamping mode must not affect the OHLCV values
        assert_eq!(closed.open, opened.open);
        assert_eq!(closed.high, opened.high);
        assert_eq!(closed.low, opened.low);
        assert_eq!(closed.close, opened.close);
        assert_eq!(closed.volume, opened.volume);

        // A close-timestamped bar must be stamped later than its open-timestamped twin
        assert!(
            closed.ts_event > opened.ts_event,
            "Close-timestamped bar should have later timestamp than open-timestamped bar"
        );

        // For 1s bars the two timestamps must be exactly one second apart
        assert_eq!(
            closed.ts_event.as_u64() - opened.ts_event.as_u64(),
            ONE_SECOND_NS,
            "Timestamp difference should be exactly 1 second for 1s bars"
        );
    }
}