1use std::{
17 sync::{Arc, RwLock},
18 time::Duration as StdDuration,
19};
20
21use ahash::{AHashMap, HashSet, HashSetExt};
22use databento::{
23 dbn::{self, PitSymbolMap, Record, SymbolIndex},
24 live::Subscription,
25};
26use indexmap::IndexMap;
27use nautilus_core::{UnixNanos, consts::NAUTILUS_USER_AGENT, time::get_atomic_clock_realtime};
28use nautilus_model::{
29 data::{Data, InstrumentStatus, OrderBookDelta, OrderBookDeltas, OrderBookDeltas_API},
30 enums::RecordFlag,
31 identifiers::{InstrumentId, Symbol, Venue},
32 instruments::InstrumentAny,
33};
34use nautilus_network::backoff::ExponentialBackoff;
35use tokio::{
36 sync::mpsc::error::TryRecvError,
37 time::{Duration, Instant},
38};
39
40use super::{
41 decode::{decode_imbalance_msg, decode_statistics_msg, decode_status_msg},
42 types::{DatabentoImbalance, DatabentoStatistics},
43};
44use crate::{
45 decode::{decode_instrument_def_msg, decode_record},
46 types::PublisherId,
47};
48
/// Commands sent to a [`DatabentoFeedHandler`] over its command channel.
#[derive(Debug)]
pub enum LiveCommand {
    /// Subscribes to a feed; the handler records subscriptions so they can be
    /// re-sent after a reconnect.
    Subscribe(Subscription),
    /// Starts streaming records for the current subscriptions.
    Start,
    /// Closes the session and stops the handler.
    Close,
}
55
/// Messages emitted by a [`DatabentoFeedHandler`] on its message channel.
#[derive(Debug)]
#[allow(
    clippy::large_enum_variant,
    reason = "TODO: Optimize this (largest variant 1096 vs 80 bytes)"
)]
pub enum LiveMessage {
    /// Decoded market data, including batched order book deltas.
    Data(Data),
    /// An instrument decoded from an `InstrumentDefMsg` record.
    Instrument(InstrumentAny),
    /// An instrument status update decoded from a `StatusMsg` record.
    Status(InstrumentStatus),
    /// Imbalance data decoded from an `ImbalanceMsg` record.
    Imbalance(DatabentoImbalance),
    /// Statistics decoded from a `StatMsg` record.
    Statistics(DatabentoStatistics),
    /// An error reported by the handler (e.g. when it gives up reconnecting).
    Error(anyhow::Error),
    /// Acknowledges a `Close` command before the session stops.
    Close,
}
70
/// Drives a live Databento session: receives [`LiveCommand`]s, streams records from
/// the live gateway, decodes them into Nautilus types and forwards them as
/// [`LiveMessage`]s, reconnecting with exponential backoff when the connection fails.
#[derive(Debug)]
pub struct DatabentoFeedHandler {
    /// API key passed to the live client builder.
    key: String,
    /// Dataset code passed to the live client builder.
    dataset: String,
    /// Incoming commands.
    cmd_rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
    /// Outgoing decoded data and status messages.
    msg_tx: tokio::sync::mpsc::Sender<LiveMessage>,
    /// Fallback venue per publisher ID, used when a symbol has no entry in
    /// `symbol_venue_map`.
    publisher_venue_map: IndexMap<PublisherId, Venue>,
    /// Shared symbol -> venue overrides; consulted before `publisher_venue_map` and
    /// populated from instrument definitions when `use_exchange_as_venue` is set.
    symbol_venue_map: Arc<RwLock<AHashMap<Symbol, Venue>>>,
    /// Set once any subscription requested a `start` time.
    replay: bool,
    /// Derive the venue from the instrument definition's exchange code.
    use_exchange_as_venue: bool,
    /// Forwarded to bar decoding.
    bars_timestamp_on_close: bool,
    /// Give up reconnecting after this many minutes of consecutive failures
    /// (`None` retries indefinitely).
    reconnect_timeout_mins: Option<u64>,
    /// Delay schedule between reconnection attempts.
    backoff: ExponentialBackoff,
    /// Subscriptions to re-send on reconnect.
    subscriptions: Vec<Subscription>,
    /// Commands received while waiting out a reconnection backoff.
    buffered_commands: Vec<LiveCommand>,
}
100
101impl DatabentoFeedHandler {
102 #[must_use]
108 #[allow(clippy::too_many_arguments)]
109 pub fn new(
110 key: String,
111 dataset: String,
112 rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
113 tx: tokio::sync::mpsc::Sender<LiveMessage>,
114 publisher_venue_map: IndexMap<PublisherId, Venue>,
115 symbol_venue_map: Arc<RwLock<AHashMap<Symbol, Venue>>>,
116 use_exchange_as_venue: bool,
117 bars_timestamp_on_close: bool,
118 reconnect_timeout_mins: Option<u64>,
119 ) -> Self {
120 let delay_max = if reconnect_timeout_mins.is_some() {
124 Duration::from_secs(60)
125 } else {
126 Duration::from_secs(600)
127 };
128
129 let backoff =
131 ExponentialBackoff::new(Duration::from_secs(1), delay_max, 2.0, 1000, false).unwrap();
132
133 Self {
134 key,
135 dataset,
136 cmd_rx: rx,
137 msg_tx: tx,
138 publisher_venue_map,
139 symbol_venue_map,
140 replay: false,
141 use_exchange_as_venue,
142 bars_timestamp_on_close,
143 reconnect_timeout_mins,
144 backoff,
145 subscriptions: Vec::new(),
146 buffered_commands: Vec::new(),
147 }
148 }
149
150 #[allow(clippy::blocks_in_conditions)]
163 pub async fn run(&mut self) -> anyhow::Result<()> {
164 tracing::debug!("Running feed handler");
165
166 let mut reconnect_start: Option<Instant> = None;
167 let mut attempt = 0;
168
169 loop {
170 attempt += 1;
171
172 match self.run_session(attempt).await {
173 Ok(ran_successfully) => {
174 if ran_successfully {
175 tracing::info!("Resetting reconnection cycle after successful session");
176 reconnect_start = None;
177 attempt = 0;
178 self.backoff.reset();
179 continue;
180 } else {
181 tracing::info!("Session ended normally");
182 break Ok(());
183 }
184 }
185 Err(e) => {
186 let cycle_start = reconnect_start.get_or_insert_with(Instant::now);
187
188 if let Some(timeout_mins) = self.reconnect_timeout_mins {
189 let elapsed = cycle_start.elapsed();
190 let timeout = Duration::from_secs(timeout_mins * 60);
191
192 if elapsed >= timeout {
193 tracing::error!(
194 "Giving up reconnection after {} minutes",
195 timeout_mins
196 );
197 self.send_msg(LiveMessage::Error(anyhow::anyhow!(
198 "Reconnection timeout after {} minutes: {}",
199 timeout_mins,
200 e
201 )))
202 .await;
203 break Err(e);
204 }
205 }
206
207 let delay = self.backoff.next_duration();
208
209 tracing::warn!(
210 "Connection lost (attempt {}): {}. Reconnecting in {}s...",
211 attempt,
212 e,
213 delay.as_secs()
214 );
215
216 tokio::select! {
217 _ = tokio::time::sleep(delay) => {}
218 cmd = self.cmd_rx.recv() => {
219 match cmd {
220 Some(LiveCommand::Close) => {
221 tracing::info!("Close received during backoff");
222 return Ok(());
223 }
224 None => {
225 tracing::debug!("Command channel closed during backoff");
226 return Ok(());
227 }
228 Some(cmd) => {
229 tracing::debug!("Buffering command received during backoff: {:?}", cmd);
230 self.buffered_commands.push(cmd);
231 }
232 }
233 }
234 }
235 }
236 }
237 }
238 }
239
240 async fn run_session(&mut self, attempt: usize) -> anyhow::Result<bool> {
249 if attempt > 1 {
250 tracing::info!("Reconnecting (attempt {})...", attempt);
251 }
252
253 let session_start = Instant::now();
254 let clock = get_atomic_clock_realtime();
255 let mut symbol_map = PitSymbolMap::new();
256 let mut instrument_id_map: AHashMap<u32, InstrumentId> = AHashMap::new();
257
258 let mut buffering_start = None;
259 let mut buffered_deltas: AHashMap<InstrumentId, Vec<OrderBookDelta>> = AHashMap::new();
260 let mut initialized_books = HashSet::new();
261 let timeout = Duration::from_secs(5); let result = tokio::time::timeout(
264 timeout,
265 databento::LiveClient::builder()
266 .user_agent_extension(NAUTILUS_USER_AGENT.into())
267 .key(self.key.clone())?
268 .dataset(self.dataset.clone())
269 .build(),
270 )
271 .await?;
272
273 let mut client = match result {
274 Ok(client) => {
275 if attempt > 1 {
276 tracing::info!("Reconnected successfully");
277 } else {
278 tracing::info!("Connected");
279 }
280 client
281 }
282 Err(e) => {
283 anyhow::bail!("Failed to connect to Databento LSG: {e}");
284 }
285 };
286
287 let mut start_buffered = false;
289 if !self.buffered_commands.is_empty() {
290 tracing::info!(
291 "Processing {} buffered commands",
292 self.buffered_commands.len()
293 );
294 for cmd in self.buffered_commands.drain(..) {
295 match cmd {
296 LiveCommand::Subscribe(sub) => {
297 if !self.replay && sub.start.is_some() {
298 self.replay = true;
299 }
300 self.subscriptions.push(sub);
301 }
302 LiveCommand::Start => {
303 start_buffered = true;
304 }
305 LiveCommand::Close => {
306 tracing::warn!("Close command was buffered, shutting down");
307 return Ok(false);
308 }
309 }
310 }
311 }
312
313 let timeout = Duration::from_millis(10);
314 let mut running = false;
315
316 if !self.subscriptions.is_empty() {
317 tracing::info!(
318 "Resubscribing to {} subscriptions",
319 self.subscriptions.len()
320 );
321 for sub in self.subscriptions.clone() {
322 client.subscribe(sub).await?;
323 }
324 for sub in &mut self.subscriptions {
326 sub.start = None;
327 }
328 client.start().await?;
329 running = true;
330 tracing::info!("Resubscription complete");
331 } else if start_buffered {
332 tracing::info!("Starting session from buffered Start command");
333 buffering_start = if self.replay {
334 Some(clock.get_time_ns())
335 } else {
336 None
337 };
338 client.start().await?;
339 running = true;
340 }
341
342 loop {
343 if self.msg_tx.is_closed() {
344 tracing::debug!("Message channel was closed: stopping");
345 return Ok(false);
346 }
347
348 match self.cmd_rx.try_recv() {
349 Ok(cmd) => {
350 tracing::debug!("Received command: {cmd:?}");
351 match cmd {
352 LiveCommand::Subscribe(sub) => {
353 if !self.replay && sub.start.is_some() {
354 self.replay = true;
355 }
356 client.subscribe(sub.clone()).await?;
357 let mut sub_for_reconnect = sub;
359 sub_for_reconnect.start = None;
360 self.subscriptions.push(sub_for_reconnect);
361 }
362 LiveCommand::Start => {
363 buffering_start = if self.replay {
364 Some(clock.get_time_ns())
365 } else {
366 None
367 };
368 client.start().await?;
369 running = true;
370 tracing::debug!("Started");
371 }
372 LiveCommand::Close => {
373 self.msg_tx.send(LiveMessage::Close).await?;
374 if running {
375 client.close().await?;
376 tracing::debug!("Closed inner client");
377 }
378 return Ok(false);
379 }
380 }
381 }
382 Err(TryRecvError::Empty) => {}
383 Err(TryRecvError::Disconnected) => {
384 tracing::debug!("Command channel disconnected");
385 return Ok(false);
386 }
387 }
388
389 if !running {
390 continue;
391 }
392
393 let result = tokio::time::timeout(timeout, client.next_record()).await;
394 let record_opt = match result {
395 Ok(record_opt) => record_opt,
396 Err(_) => continue,
397 };
398
399 let record = match record_opt {
400 Ok(Some(record)) => record,
401 Ok(None) => {
402 const SUCCESS_THRESHOLD: Duration = Duration::from_secs(60);
403 if session_start.elapsed() >= SUCCESS_THRESHOLD {
404 tracing::info!("Session ended after successful run");
405 return Ok(true);
406 }
407 anyhow::bail!("Session ended by gateway");
408 }
409 Err(e) => {
410 const SUCCESS_THRESHOLD: Duration = Duration::from_secs(60);
411 if session_start.elapsed() >= SUCCESS_THRESHOLD {
412 tracing::info!("Connection error after successful run: {e}");
413 return Ok(true);
414 }
415 anyhow::bail!("Connection error: {e}");
416 }
417 };
418
419 let ts_init = clock.get_time_ns();
420
421 if let Some(msg) = record.get::<dbn::ErrorMsg>() {
423 handle_error_msg(msg);
424 } else if let Some(msg) = record.get::<dbn::SystemMsg>() {
425 handle_system_msg(msg);
426 } else if let Some(msg) = record.get::<dbn::SymbolMappingMsg>() {
427 instrument_id_map.remove(&msg.hd.instrument_id);
429 handle_symbol_mapping_msg(msg, &mut symbol_map, &mut instrument_id_map)?;
430 } else if let Some(msg) = record.get::<dbn::InstrumentDefMsg>() {
431 if self.use_exchange_as_venue {
432 let exchange = msg.exchange()?;
433 if !exchange.is_empty() {
434 update_instrument_id_map_with_exchange(
435 &symbol_map,
436 &self.symbol_venue_map,
437 &mut instrument_id_map,
438 msg.hd.instrument_id,
439 exchange,
440 )?;
441 }
442 }
443 let data = {
444 let sym_map = self.read_symbol_venue_map()?;
445 handle_instrument_def_msg(
446 msg,
447 &record,
448 &symbol_map,
449 &self.publisher_venue_map,
450 &sym_map,
451 &mut instrument_id_map,
452 ts_init,
453 )?
454 };
455 self.send_msg(LiveMessage::Instrument(data)).await;
456 } else if let Some(msg) = record.get::<dbn::StatusMsg>() {
457 let data = {
458 let sym_map = self.read_symbol_venue_map()?;
459 handle_status_msg(
460 msg,
461 &record,
462 &symbol_map,
463 &self.publisher_venue_map,
464 &sym_map,
465 &mut instrument_id_map,
466 ts_init,
467 )?
468 };
469 self.send_msg(LiveMessage::Status(data)).await;
470 } else if let Some(msg) = record.get::<dbn::ImbalanceMsg>() {
471 let data = {
472 let sym_map = self.read_symbol_venue_map()?;
473 handle_imbalance_msg(
474 msg,
475 &record,
476 &symbol_map,
477 &self.publisher_venue_map,
478 &sym_map,
479 &mut instrument_id_map,
480 ts_init,
481 )?
482 };
483 self.send_msg(LiveMessage::Imbalance(data)).await;
484 } else if let Some(msg) = record.get::<dbn::StatMsg>() {
485 let data = {
486 let sym_map = self.read_symbol_venue_map()?;
487 handle_statistics_msg(
488 msg,
489 &record,
490 &symbol_map,
491 &self.publisher_venue_map,
492 &sym_map,
493 &mut instrument_id_map,
494 ts_init,
495 )?
496 };
497 self.send_msg(LiveMessage::Statistics(data)).await;
498 } else {
499 let res = {
501 let sym_map = self.read_symbol_venue_map()?;
502 handle_record(
503 record,
504 &symbol_map,
505 &self.publisher_venue_map,
506 &sym_map,
507 &mut instrument_id_map,
508 ts_init,
509 &initialized_books,
510 self.bars_timestamp_on_close,
511 )
512 };
513 let (mut data1, data2) = match res {
514 Ok(decoded) => decoded,
515 Err(e) => {
516 tracing::error!("Error decoding record: {e}");
517 continue;
518 }
519 };
520
521 if let Some(msg) = record.get::<dbn::MboMsg>() {
522 if let Some(Data::Delta(delta)) = &data1 {
524 initialized_books.insert(delta.instrument_id);
525 } else {
526 continue; }
528
529 if let Some(Data::Delta(delta)) = &data1 {
530 let buffer = buffered_deltas.entry(delta.instrument_id).or_default();
531 buffer.push(*delta);
532
533 tracing::trace!(
534 "Buffering delta: {} {buffering_start:?} flags={}",
535 delta.ts_event,
536 msg.flags.raw(),
537 );
538
539 if !RecordFlag::F_LAST.matches(msg.flags.raw()) {
541 continue; }
543
544 if RecordFlag::F_SNAPSHOT.matches(msg.flags.raw()) {
546 continue; }
548
549 if let Some(start_ns) = buffering_start {
551 if delta.ts_event <= start_ns {
552 continue; }
554 buffering_start = None;
555 }
556
557 let buffer =
559 buffered_deltas
560 .remove(&delta.instrument_id)
561 .ok_or_else(|| {
562 anyhow::anyhow!(
563 "Internal error: no buffered deltas for instrument {id}",
564 id = delta.instrument_id
565 )
566 })?;
567 let deltas = OrderBookDeltas::new(delta.instrument_id, buffer);
568 let deltas = OrderBookDeltas_API::new(deltas);
569 data1 = Some(Data::Deltas(deltas));
570 }
571 }
572
573 if let Some(data) = data1 {
574 self.send_msg(LiveMessage::Data(data)).await;
575 }
576
577 if let Some(data) = data2 {
578 self.send_msg(LiveMessage::Data(data)).await;
579 }
580 }
581 }
582 }
583
584 async fn send_msg(&mut self, msg: LiveMessage) {
586 tracing::trace!("Sending {msg:?}");
587 match self.msg_tx.send(msg).await {
588 Ok(()) => {}
589 Err(e) => tracing::error!("Error sending message: {e}"),
590 }
591 }
592
593 fn read_symbol_venue_map(
599 &self,
600 ) -> anyhow::Result<std::sync::RwLockReadGuard<'_, AHashMap<Symbol, Venue>>> {
601 const MAX_WAIT_MS: u64 = 500; const INITIAL_DELAY_MICROS: u64 = 10;
604 const MAX_DELAY_MICROS: u64 = 1000;
605
606 let deadline = std::time::Instant::now() + StdDuration::from_millis(MAX_WAIT_MS);
607 let mut delay = INITIAL_DELAY_MICROS;
608
609 loop {
610 match self.symbol_venue_map.try_read() {
611 Ok(guard) => return Ok(guard),
612 Err(std::sync::TryLockError::WouldBlock) => {
613 if std::time::Instant::now() >= deadline {
614 break;
615 }
616
617 std::thread::yield_now();
619
620 if std::time::Instant::now() < deadline {
622 let remaining = deadline - std::time::Instant::now();
623 let sleep_duration = StdDuration::from_micros(delay).min(remaining);
624 std::thread::sleep(sleep_duration);
625 delay = ((delay * 2) + delay / 4).min(MAX_DELAY_MICROS);
627 }
628 }
629 Err(std::sync::TryLockError::Poisoned(e)) => {
630 anyhow::bail!("symbol_venue_map lock poisoned: {e}");
631 }
632 }
633 }
634
635 anyhow::bail!(
636 "Failed to acquire read lock on symbol_venue_map after {MAX_WAIT_MS}ms deadline"
637 )
638 }
639}
640
641fn handle_error_msg(msg: &dbn::ErrorMsg) {
643 tracing::error!("{msg:?}");
644}
645
646fn handle_system_msg(msg: &dbn::SystemMsg) {
648 tracing::info!("{msg:?}");
649}
650
651fn handle_symbol_mapping_msg(
657 msg: &dbn::SymbolMappingMsg,
658 symbol_map: &mut PitSymbolMap,
659 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
660) -> anyhow::Result<()> {
661 symbol_map
662 .on_symbol_mapping(msg)
663 .map_err(|e| anyhow::anyhow!("on_symbol_mapping failed for {msg:?}: {e}"))?;
664 instrument_id_map.remove(&msg.header().instrument_id);
665 Ok(())
666}
667
668fn update_instrument_id_map_with_exchange(
670 symbol_map: &PitSymbolMap,
671 symbol_venue_map: &RwLock<AHashMap<Symbol, Venue>>,
672 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
673 raw_instrument_id: u32,
674 exchange: &str,
675) -> anyhow::Result<InstrumentId> {
676 let raw_symbol = symbol_map.get(raw_instrument_id).ok_or_else(|| {
677 anyhow::anyhow!("Cannot resolve raw_symbol for instrument_id {raw_instrument_id}")
678 })?;
679 let symbol = Symbol::from(raw_symbol.as_str());
680 let venue = Venue::from_code(exchange)
681 .map_err(|e| anyhow::anyhow!("Invalid venue code '{exchange}': {e}"))?;
682 let instrument_id = InstrumentId::new(symbol, venue);
683 let mut map = symbol_venue_map
684 .write()
685 .map_err(|e| anyhow::anyhow!("symbol_venue_map lock poisoned: {e}"))?;
686 map.entry(symbol).or_insert(venue);
687 instrument_id_map.insert(raw_instrument_id, instrument_id);
688 Ok(instrument_id)
689}
690
691fn update_instrument_id_map(
692 record: &dbn::RecordRef,
693 symbol_map: &PitSymbolMap,
694 publisher_venue_map: &IndexMap<PublisherId, Venue>,
695 symbol_venue_map: &AHashMap<Symbol, Venue>,
696 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
697) -> anyhow::Result<InstrumentId> {
698 let header = record.header();
699
700 if let Some(&instrument_id) = instrument_id_map.get(&header.instrument_id) {
702 return Ok(instrument_id);
703 }
704
705 let raw_symbol = symbol_map.get_for_rec(record).ok_or_else(|| {
706 anyhow::anyhow!(
707 "Cannot resolve `raw_symbol` from `symbol_map` for instrument_id {}",
708 header.instrument_id
709 )
710 })?;
711
712 let symbol = Symbol::from_str_unchecked(raw_symbol);
713
714 let publisher_id = header.publisher_id;
715 let venue = if let Some(venue) = symbol_venue_map.get(&symbol) {
716 *venue
717 } else {
718 let venue = publisher_venue_map
719 .get(&publisher_id)
720 .ok_or_else(|| anyhow::anyhow!("No venue found for `publisher_id` {publisher_id}"))?;
721 *venue
722 };
723 let instrument_id = InstrumentId::new(symbol, venue);
724
725 instrument_id_map.insert(header.instrument_id, instrument_id);
726 Ok(instrument_id)
727}
728
729fn handle_instrument_def_msg(
735 msg: &dbn::InstrumentDefMsg,
736 record: &dbn::RecordRef,
737 symbol_map: &PitSymbolMap,
738 publisher_venue_map: &IndexMap<PublisherId, Venue>,
739 symbol_venue_map: &AHashMap<Symbol, Venue>,
740 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
741 ts_init: UnixNanos,
742) -> anyhow::Result<InstrumentAny> {
743 let instrument_id = update_instrument_id_map(
744 record,
745 symbol_map,
746 publisher_venue_map,
747 symbol_venue_map,
748 instrument_id_map,
749 )?;
750
751 decode_instrument_def_msg(msg, instrument_id, Some(ts_init))
752}
753
754fn handle_status_msg(
755 msg: &dbn::StatusMsg,
756 record: &dbn::RecordRef,
757 symbol_map: &PitSymbolMap,
758 publisher_venue_map: &IndexMap<PublisherId, Venue>,
759 symbol_venue_map: &AHashMap<Symbol, Venue>,
760 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
761 ts_init: UnixNanos,
762) -> anyhow::Result<InstrumentStatus> {
763 let instrument_id = update_instrument_id_map(
764 record,
765 symbol_map,
766 publisher_venue_map,
767 symbol_venue_map,
768 instrument_id_map,
769 )?;
770
771 decode_status_msg(msg, instrument_id, Some(ts_init))
772}
773
774fn handle_imbalance_msg(
775 msg: &dbn::ImbalanceMsg,
776 record: &dbn::RecordRef,
777 symbol_map: &PitSymbolMap,
778 publisher_venue_map: &IndexMap<PublisherId, Venue>,
779 symbol_venue_map: &AHashMap<Symbol, Venue>,
780 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
781 ts_init: UnixNanos,
782) -> anyhow::Result<DatabentoImbalance> {
783 let instrument_id = update_instrument_id_map(
784 record,
785 symbol_map,
786 publisher_venue_map,
787 symbol_venue_map,
788 instrument_id_map,
789 )?;
790
791 let price_precision = 2; decode_imbalance_msg(msg, instrument_id, price_precision, Some(ts_init))
794}
795
796fn handle_statistics_msg(
797 msg: &dbn::StatMsg,
798 record: &dbn::RecordRef,
799 symbol_map: &PitSymbolMap,
800 publisher_venue_map: &IndexMap<PublisherId, Venue>,
801 symbol_venue_map: &AHashMap<Symbol, Venue>,
802 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
803 ts_init: UnixNanos,
804) -> anyhow::Result<DatabentoStatistics> {
805 let instrument_id = update_instrument_id_map(
806 record,
807 symbol_map,
808 publisher_venue_map,
809 symbol_venue_map,
810 instrument_id_map,
811 )?;
812
813 let price_precision = 2; decode_statistics_msg(msg, instrument_id, price_precision, Some(ts_init))
816}
817
818#[allow(clippy::too_many_arguments)]
819fn handle_record(
820 record: dbn::RecordRef,
821 symbol_map: &PitSymbolMap,
822 publisher_venue_map: &IndexMap<PublisherId, Venue>,
823 symbol_venue_map: &AHashMap<Symbol, Venue>,
824 instrument_id_map: &mut AHashMap<u32, InstrumentId>,
825 ts_init: UnixNanos,
826 initialized_books: &HashSet<InstrumentId>,
827 bars_timestamp_on_close: bool,
828) -> anyhow::Result<(Option<Data>, Option<Data>)> {
829 let instrument_id = update_instrument_id_map(
830 &record,
831 symbol_map,
832 publisher_venue_map,
833 symbol_venue_map,
834 instrument_id_map,
835 )?;
836
837 let price_precision = 2; let include_trades = if record.get::<dbn::Mbp1Msg>().is_some()
842 || record.get::<dbn::TbboMsg>().is_some()
843 || record.get::<dbn::Cmbp1Msg>().is_some()
844 {
845 true } else {
847 initialized_books.contains(&instrument_id) };
849
850 decode_record(
851 &record,
852 instrument_id,
853 price_precision,
854 Some(ts_init),
855 include_trades,
856 bars_timestamp_on_close,
857 )
858}
859
#[cfg(test)]
mod tests {
    use databento::live::Subscription;
    use indexmap::IndexMap;
    use rstest::*;
    use time::macros::datetime;

    use super::*;

    /// Builds a handler wired to throwaway channels, an empty publisher map and an
    /// empty shared symbol-venue map.
    fn create_test_handler(reconnect_timeout_mins: Option<u64>) -> DatabentoFeedHandler {
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (msg_tx, _msg_rx) = tokio::sync::mpsc::channel(100);

        DatabentoFeedHandler::new(
            "test_key".to_string(),
            "GLBX.MDP3".to_string(),
            cmd_rx,
            msg_tx,
            IndexMap::new(),
            Arc::new(RwLock::new(AHashMap::new())),
            false,
            false,
            reconnect_timeout_mins,
        )
    }

    #[rstest]
    #[case(Some(10))]
    #[case(None)]
    fn test_backoff_initialization(#[case] reconnect_timeout_mins: Option<u64>) {
        let handler = create_test_handler(reconnect_timeout_mins);

        assert_eq!(handler.reconnect_timeout_mins, reconnect_timeout_mins);
        assert!(handler.subscriptions.is_empty());
        assert!(handler.buffered_commands.is_empty());
    }

    #[rstest]
    fn test_subscription_with_and_without_start() {
        let start_time = datetime!(2024-01-01 00:00:00 UTC);
        let sub_with_start = Subscription::builder()
            .symbols("ES.FUT")
            .schema(databento::dbn::Schema::Mbp1)
            .start(start_time)
            .build();

        let mut sub_without_start = sub_with_start.clone();
        sub_without_start.start = None;

        assert!(sub_with_start.start.is_some());
        assert!(sub_without_start.start.is_none());
        assert_eq!(sub_with_start.schema, sub_without_start.schema);
        assert_eq!(sub_with_start.symbols, sub_without_start.symbols);
    }

    #[rstest]
    fn test_handler_initialization_state() {
        let handler = create_test_handler(Some(10));

        assert!(!handler.replay);
        assert_eq!(handler.dataset, "GLBX.MDP3");
        assert_eq!(handler.key, "test_key");
        assert!(handler.subscriptions.is_empty());
        assert!(handler.buffered_commands.is_empty());
    }

    #[rstest]
    fn test_handler_with_no_timeout() {
        let handler = create_test_handler(None);

        assert_eq!(handler.reconnect_timeout_mins, None);
        assert!(!handler.replay);
    }

    #[rstest]
    fn test_handler_with_zero_timeout() {
        let handler = create_test_handler(Some(0));

        assert_eq!(handler.reconnect_timeout_mins, Some(0));
        assert!(!handler.replay);
    }

    #[rstest]
    fn test_read_symbol_venue_map_uncontended() {
        let handler = create_test_handler(None);

        let guard = handler
            .read_symbol_venue_map()
            .expect("uncontended lock should be acquired");

        assert!(guard.is_empty());
    }

    #[rstest]
    fn test_read_symbol_venue_map_poisoned() {
        let handler = create_test_handler(None);
        let shared = Arc::clone(&handler.symbol_venue_map);

        // Panic while holding the write lock to poison it
        let _ = std::thread::spawn(move || {
            let _guard = shared.write().unwrap();
            panic!("poisoning symbol_venue_map for test");
        })
        .join();

        let Err(err) = handler.read_symbol_venue_map() else {
            panic!("expected an error for a poisoned lock");
        };
        assert!(err.to_string().contains("poisoned"));
    }
}
945}