nautilus_databento/
live.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, RwLock},
18    time::Duration as StdDuration,
19};
20
21use ahash::{AHashMap, HashSet, HashSetExt};
22use databento::{
23    dbn::{self, PitSymbolMap, Record, SymbolIndex},
24    live::Subscription,
25};
26use indexmap::IndexMap;
27use nautilus_core::{UnixNanos, consts::NAUTILUS_USER_AGENT, time::get_atomic_clock_realtime};
28use nautilus_model::{
29    data::{Data, InstrumentStatus, OrderBookDelta, OrderBookDeltas, OrderBookDeltas_API},
30    enums::RecordFlag,
31    identifiers::{InstrumentId, Symbol, Venue},
32    instruments::InstrumentAny,
33};
34use nautilus_network::backoff::ExponentialBackoff;
35use tokio::{
36    sync::mpsc::error::TryRecvError,
37    time::{Duration, Instant},
38};
39
40use super::{
41    decode::{decode_imbalance_msg, decode_statistics_msg, decode_status_msg},
42    types::{DatabentoImbalance, DatabentoStatistics, SubscriptionAckEvent},
43};
44use crate::{
45    decode::{decode_instrument_def_msg, decode_record},
46    types::PublisherId,
47};
48
49#[derive(Debug)]
50pub enum LiveCommand {
51    Subscribe(Subscription),
52    Start,
53    Close,
54}
55
56#[derive(Debug)]
57#[allow(
58    clippy::large_enum_variant,
59    reason = "TODO: Optimize this (largest variant 1096 vs 80 bytes)"
60)]
61pub enum LiveMessage {
62    Data(Data),
63    Instrument(InstrumentAny),
64    Status(InstrumentStatus),
65    Imbalance(DatabentoImbalance),
66    Statistics(DatabentoStatistics),
67    SubscriptionAck(SubscriptionAckEvent),
68    Error(anyhow::Error),
69    Close,
70}
71
72/// Handles a raw TCP data feed from the Databento LSG for a single dataset.
73///
74/// [`LiveCommand`] messages are received synchronously across a channel,
75/// decoded records are sent asynchronously on a tokio channel as [`LiveMessage`]s
76/// back to a message processing task.
77///
78/// # Crash Policy
79///
80/// This handler intentionally crashes on catastrophic feed issues rather than
81/// attempting recovery. If excessive buffering occurs (indicating severe feed
82/// misbehavior), the process will run out of memory and terminate. This is by
83/// design - such scenarios indicate fundamental problems that require external
84/// intervention.
85#[derive(Debug)]
86pub struct DatabentoFeedHandler {
87    key: String,
88    dataset: String,
89    cmd_rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
90    msg_tx: tokio::sync::mpsc::Sender<LiveMessage>,
91    publisher_venue_map: IndexMap<PublisherId, Venue>,
92    symbol_venue_map: Arc<RwLock<AHashMap<Symbol, Venue>>>,
93    replay: bool,
94    use_exchange_as_venue: bool,
95    bars_timestamp_on_close: bool,
96    reconnect_timeout_mins: Option<u64>,
97    backoff: ExponentialBackoff,
98    subscriptions: Vec<Subscription>,
99    buffered_commands: Vec<LiveCommand>,
100}
101
102impl DatabentoFeedHandler {
103    /// Creates a new [`DatabentoFeedHandler`] instance.
104    ///
105    /// # Panics
106    ///
107    /// Panics if exponential backoff creation fails (should never happen with valid hardcoded parameters).
108    #[must_use]
109    #[allow(clippy::too_many_arguments)]
110    pub fn new(
111        key: String,
112        dataset: String,
113        rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
114        tx: tokio::sync::mpsc::Sender<LiveMessage>,
115        publisher_venue_map: IndexMap<PublisherId, Venue>,
116        symbol_venue_map: Arc<RwLock<AHashMap<Symbol, Venue>>>,
117        use_exchange_as_venue: bool,
118        bars_timestamp_on_close: bool,
119        reconnect_timeout_mins: Option<u64>,
120    ) -> Self {
121        // Choose max delay based on timeout configuration:
122        // - With timeout: 60s max (quick recovery to reconnect within window)
123        // - Without timeout (None): 600s max (patient recovery, respectful of infrastructure)
124        let delay_max = if reconnect_timeout_mins.is_some() {
125            Duration::from_secs(60)
126        } else {
127            Duration::from_secs(600)
128        };
129
130        // SAFETY: Hardcoded parameters are all valid
131        let backoff =
132            ExponentialBackoff::new(Duration::from_secs(1), delay_max, 2.0, 1000, false).unwrap();
133
134        Self {
135            key,
136            dataset,
137            cmd_rx: rx,
138            msg_tx: tx,
139            publisher_venue_map,
140            symbol_venue_map,
141            replay: false,
142            use_exchange_as_venue,
143            bars_timestamp_on_close,
144            reconnect_timeout_mins,
145            backoff,
146            subscriptions: Vec::new(),
147            buffered_commands: Vec::new(),
148        }
149    }
150
151    /// Runs the feed handler main loop, processing commands and streaming market data.
152    ///
153    /// Establishes a connection to the Databento LSG, subscribes to requested data feeds,
154    /// and continuously processes incoming market data messages until shutdown.
155    ///
156    /// Implements automatic reconnection with exponential backoff (1s to 60s with jitter).
157    /// Each successful session resets the reconnection cycle, giving the next disconnect
158    /// a fresh timeout window. Gives up after `reconnect_timeout_mins` if configured.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if any client operation or message handling fails.
163    #[allow(clippy::blocks_in_conditions)]
164    pub async fn run(&mut self) -> anyhow::Result<()> {
165        tracing::debug!("Running feed handler");
166
167        let mut reconnect_start: Option<Instant> = None;
168        let mut attempt = 0;
169
170        loop {
171            attempt += 1;
172
173            match self.run_session(attempt).await {
174                Ok(ran_successfully) => {
175                    if ran_successfully {
176                        tracing::info!("Resetting reconnection cycle after successful session");
177                        reconnect_start = None;
178                        attempt = 0;
179                        self.backoff.reset();
180                        continue;
181                    } else {
182                        tracing::info!("Session ended normally");
183                        break Ok(());
184                    }
185                }
186                Err(e) => {
187                    let cycle_start = reconnect_start.get_or_insert_with(Instant::now);
188
189                    if let Some(timeout_mins) = self.reconnect_timeout_mins {
190                        let elapsed = cycle_start.elapsed();
191                        let timeout = Duration::from_secs(timeout_mins * 60);
192
193                        if elapsed >= timeout {
194                            tracing::error!(
195                                "Giving up reconnection after {} minutes",
196                                timeout_mins
197                            );
198                            self.send_msg(LiveMessage::Error(anyhow::anyhow!(
199                                "Reconnection timeout after {timeout_mins} minutes: {e}"
200                            )))
201                            .await;
202                            break Err(e);
203                        }
204                    }
205
206                    let delay = self.backoff.next_duration();
207
208                    tracing::warn!(
209                        "Connection lost (attempt {}): {}. Reconnecting in {}s...",
210                        attempt,
211                        e,
212                        delay.as_secs()
213                    );
214
215                    tokio::select! {
216                        _ = tokio::time::sleep(delay) => {}
217                        cmd = self.cmd_rx.recv() => {
218                            match cmd {
219                                Some(LiveCommand::Close) => {
220                                    tracing::info!("Close received during backoff");
221                                    return Ok(());
222                                }
223                                None => {
224                                    tracing::debug!("Command channel closed during backoff");
225                                    return Ok(());
226                                }
227                                Some(cmd) => {
228                                    tracing::debug!("Buffering command received during backoff: {:?}", cmd);
229                                    self.buffered_commands.push(cmd);
230                                }
231                            }
232                        }
233                    }
234                }
235            }
236        }
237    }
238
239    /// Runs a single session, handling connection, subscriptions, and data streaming.
240    ///
241    /// Returns `Ok(bool)` where the bool indicates if the session ran successfully
242    /// for a meaningful duration (true) or was intentionally closed (false).
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if connection fails, subscription fails, or data streaming encounters an error.
247    async fn run_session(&mut self, attempt: usize) -> anyhow::Result<bool> {
248        if attempt > 1 {
249            tracing::info!("Reconnecting (attempt {})...", attempt);
250        }
251
252        let session_start = Instant::now();
253        let clock = get_atomic_clock_realtime();
254        let mut symbol_map = PitSymbolMap::new();
255        let mut instrument_id_map: AHashMap<u32, InstrumentId> = AHashMap::new();
256
257        let mut buffering_start = None;
258        let mut buffered_deltas: AHashMap<InstrumentId, Vec<OrderBookDelta>> = AHashMap::new();
259        let mut initialized_books = HashSet::new();
260        let timeout = Duration::from_secs(5); // Hardcoded timeout for now
261
262        let result = tokio::time::timeout(
263            timeout,
264            databento::LiveClient::builder()
265                .user_agent_extension(NAUTILUS_USER_AGENT.into())
266                .key(self.key.clone())?
267                .dataset(self.dataset.clone())
268                .build(),
269        )
270        .await?;
271
272        let mut client = match result {
273            Ok(client) => {
274                if attempt > 1 {
275                    tracing::info!("Reconnected successfully");
276                } else {
277                    tracing::info!("Connected");
278                }
279                client
280            }
281            Err(e) => {
282                anyhow::bail!("Failed to connect to Databento LSG: {e}");
283            }
284        };
285
286        // Process any commands buffered during reconnection backoff
287        let mut start_buffered = false;
288        if !self.buffered_commands.is_empty() {
289            tracing::info!(
290                "Processing {} buffered commands",
291                self.buffered_commands.len()
292            );
293            for cmd in self.buffered_commands.drain(..) {
294                match cmd {
295                    LiveCommand::Subscribe(sub) => {
296                        if !self.replay && sub.start.is_some() {
297                            self.replay = true;
298                        }
299                        self.subscriptions.push(sub);
300                    }
301                    LiveCommand::Start => {
302                        start_buffered = true;
303                    }
304                    LiveCommand::Close => {
305                        tracing::warn!("Close command was buffered, shutting down");
306                        return Ok(false);
307                    }
308                }
309            }
310        }
311
312        let timeout = Duration::from_millis(10);
313        let mut running = false;
314
315        if !self.subscriptions.is_empty() {
316            tracing::info!(
317                "Resubscribing to {} subscriptions",
318                self.subscriptions.len()
319            );
320            for sub in self.subscriptions.clone() {
321                client.subscribe(sub).await?;
322            }
323            // Strip start timestamps after successful subscription to avoid replaying history on future reconnects
324            for sub in &mut self.subscriptions {
325                sub.start = None;
326            }
327            client.start().await?;
328            running = true;
329            tracing::info!("Resubscription complete");
330        } else if start_buffered {
331            tracing::info!("Starting session from buffered Start command");
332            buffering_start = if self.replay {
333                Some(clock.get_time_ns())
334            } else {
335                None
336            };
337            client.start().await?;
338            running = true;
339        }
340
341        loop {
342            if self.msg_tx.is_closed() {
343                tracing::debug!("Message channel was closed: stopping");
344                return Ok(false);
345            }
346
347            match self.cmd_rx.try_recv() {
348                Ok(cmd) => {
349                    tracing::debug!("Received command: {cmd:?}");
350                    match cmd {
351                        LiveCommand::Subscribe(sub) => {
352                            if !self.replay && sub.start.is_some() {
353                                self.replay = true;
354                            }
355                            client.subscribe(sub.clone()).await?;
356                            // Store without start to avoid replaying history on reconnect
357                            let mut sub_for_reconnect = sub;
358                            sub_for_reconnect.start = None;
359                            self.subscriptions.push(sub_for_reconnect);
360                        }
361                        LiveCommand::Start => {
362                            buffering_start = if self.replay {
363                                Some(clock.get_time_ns())
364                            } else {
365                                None
366                            };
367                            client.start().await?;
368                            running = true;
369                            tracing::debug!("Started");
370                        }
371                        LiveCommand::Close => {
372                            self.msg_tx.send(LiveMessage::Close).await?;
373                            if running {
374                                client.close().await?;
375                                tracing::debug!("Closed inner client");
376                            }
377                            return Ok(false);
378                        }
379                    }
380                }
381                Err(TryRecvError::Empty) => {}
382                Err(TryRecvError::Disconnected) => {
383                    tracing::debug!("Command channel disconnected");
384                    return Ok(false);
385                }
386            }
387
388            if !running {
389                continue;
390            }
391
392            let result = tokio::time::timeout(timeout, client.next_record()).await;
393            let record_opt = match result {
394                Ok(record_opt) => record_opt,
395                Err(_) => continue,
396            };
397
398            let record = match record_opt {
399                Ok(Some(record)) => record,
400                Ok(None) => {
401                    const SUCCESS_THRESHOLD: Duration = Duration::from_secs(60);
402                    if session_start.elapsed() >= SUCCESS_THRESHOLD {
403                        tracing::info!("Session ended after successful run");
404                        return Ok(true);
405                    }
406                    anyhow::bail!("Session ended by gateway");
407                }
408                Err(e) => {
409                    const SUCCESS_THRESHOLD: Duration = Duration::from_secs(60);
410                    if session_start.elapsed() >= SUCCESS_THRESHOLD {
411                        tracing::info!("Connection error after successful run: {e}");
412                        return Ok(true);
413                    }
414                    anyhow::bail!("Connection error: {e}");
415                }
416            };
417
418            let ts_init = clock.get_time_ns();
419
420            // Decode record
421            if let Some(msg) = record.get::<dbn::ErrorMsg>() {
422                handle_error_msg(msg);
423            } else if let Some(msg) = record.get::<dbn::SystemMsg>() {
424                if let Some(ack) = handle_system_msg(msg, ts_init) {
425                    self.send_msg(LiveMessage::SubscriptionAck(ack)).await;
426                }
427            } else if let Some(msg) = record.get::<dbn::SymbolMappingMsg>() {
428                // Remove instrument ID index as the raw symbol may have changed
429                instrument_id_map.remove(&msg.hd.instrument_id);
430                handle_symbol_mapping_msg(msg, &mut symbol_map, &mut instrument_id_map)?;
431            } else if let Some(msg) = record.get::<dbn::InstrumentDefMsg>() {
432                if self.use_exchange_as_venue {
433                    let exchange = msg.exchange()?;
434                    if !exchange.is_empty() {
435                        update_instrument_id_map_with_exchange(
436                            &symbol_map,
437                            &self.symbol_venue_map,
438                            &mut instrument_id_map,
439                            msg.hd.instrument_id,
440                            exchange,
441                        )?;
442                    }
443                }
444                let data = {
445                    let sym_map = self.read_symbol_venue_map()?;
446                    handle_instrument_def_msg(
447                        msg,
448                        &record,
449                        &symbol_map,
450                        &self.publisher_venue_map,
451                        &sym_map,
452                        &mut instrument_id_map,
453                        ts_init,
454                    )?
455                };
456                self.send_msg(LiveMessage::Instrument(data)).await;
457            } else if let Some(msg) = record.get::<dbn::StatusMsg>() {
458                let data = {
459                    let sym_map = self.read_symbol_venue_map()?;
460                    handle_status_msg(
461                        msg,
462                        &record,
463                        &symbol_map,
464                        &self.publisher_venue_map,
465                        &sym_map,
466                        &mut instrument_id_map,
467                        ts_init,
468                    )?
469                };
470                self.send_msg(LiveMessage::Status(data)).await;
471            } else if let Some(msg) = record.get::<dbn::ImbalanceMsg>() {
472                let data = {
473                    let sym_map = self.read_symbol_venue_map()?;
474                    handle_imbalance_msg(
475                        msg,
476                        &record,
477                        &symbol_map,
478                        &self.publisher_venue_map,
479                        &sym_map,
480                        &mut instrument_id_map,
481                        ts_init,
482                    )?
483                };
484                self.send_msg(LiveMessage::Imbalance(data)).await;
485            } else if let Some(msg) = record.get::<dbn::StatMsg>() {
486                let data = {
487                    let sym_map = self.read_symbol_venue_map()?;
488                    handle_statistics_msg(
489                        msg,
490                        &record,
491                        &symbol_map,
492                        &self.publisher_venue_map,
493                        &sym_map,
494                        &mut instrument_id_map,
495                        ts_init,
496                    )?
497                };
498                self.send_msg(LiveMessage::Statistics(data)).await;
499            } else {
500                // Decode a generic record with possible errors
501                let res = {
502                    let sym_map = self.read_symbol_venue_map()?;
503                    handle_record(
504                        record,
505                        &symbol_map,
506                        &self.publisher_venue_map,
507                        &sym_map,
508                        &mut instrument_id_map,
509                        ts_init,
510                        &initialized_books,
511                        self.bars_timestamp_on_close,
512                    )
513                };
514                let (mut data1, data2) = match res {
515                    Ok(decoded) => decoded,
516                    Err(e) => {
517                        tracing::error!("Error decoding record: {e}");
518                        continue;
519                    }
520                };
521
522                if let Some(msg) = record.get::<dbn::MboMsg>() {
523                    // Check if should mark book initialized
524                    if let Some(Data::Delta(delta)) = &data1 {
525                        initialized_books.insert(delta.instrument_id);
526                    } else {
527                        continue; // No delta yet
528                    }
529
530                    if let Some(Data::Delta(delta)) = &data1 {
531                        let buffer = buffered_deltas.entry(delta.instrument_id).or_default();
532                        buffer.push(*delta);
533
534                        tracing::trace!(
535                            "Buffering delta: {} {buffering_start:?} flags={}",
536                            delta.ts_event,
537                            msg.flags.raw(),
538                        );
539
540                        // Check if last message in the book event
541                        if !RecordFlag::F_LAST.matches(msg.flags.raw()) {
542                            continue; // NOT last message
543                        }
544
545                        // Check if snapshot
546                        if RecordFlag::F_SNAPSHOT.matches(msg.flags.raw()) {
547                            continue; // Buffer snapshot
548                        }
549
550                        // Check if buffering a replay
551                        if let Some(start_ns) = buffering_start {
552                            if delta.ts_event <= start_ns {
553                                continue; // Continue buffering replay
554                            }
555                            buffering_start = None;
556                        }
557
558                        // SAFETY: We can guarantee a deltas vec exists
559                        let buffer =
560                            buffered_deltas
561                                .remove(&delta.instrument_id)
562                                .ok_or_else(|| {
563                                    anyhow::anyhow!(
564                                        "Internal error: no buffered deltas for instrument {id}",
565                                        id = delta.instrument_id
566                                    )
567                                })?;
568                        let deltas = OrderBookDeltas::new(delta.instrument_id, buffer);
569                        let deltas = OrderBookDeltas_API::new(deltas);
570                        data1 = Some(Data::Deltas(deltas));
571                    }
572                }
573
574                if let Some(data) = data1 {
575                    self.send_msg(LiveMessage::Data(data)).await;
576                }
577
578                if let Some(data) = data2 {
579                    self.send_msg(LiveMessage::Data(data)).await;
580                }
581            }
582        }
583    }
584
585    /// Sends a message to the message processing task.
586    async fn send_msg(&mut self, msg: LiveMessage) {
587        tracing::trace!("Sending {msg:?}");
588        match self.msg_tx.send(msg).await {
589            Ok(()) => {}
590            Err(e) => tracing::error!("Error sending message: {e}"),
591        }
592    }
593
594    /// Acquires a read lock on the symbol-venue map with exponential backoff and timeout.
595    ///
596    /// # Errors
597    ///
598    /// Returns an error if the read lock cannot be acquired within the deadline.
599    fn read_symbol_venue_map(
600        &self,
601    ) -> anyhow::Result<std::sync::RwLockReadGuard<'_, AHashMap<Symbol, Venue>>> {
602        // Try to acquire the lock with exponential backoff and deadline
603        const MAX_WAIT_MS: u64 = 500; // Total maximum wait time
604        const INITIAL_DELAY_MICROS: u64 = 10;
605        const MAX_DELAY_MICROS: u64 = 1000;
606
607        let deadline = std::time::Instant::now() + StdDuration::from_millis(MAX_WAIT_MS);
608        let mut delay = INITIAL_DELAY_MICROS;
609
610        loop {
611            match self.symbol_venue_map.try_read() {
612                Ok(guard) => return Ok(guard),
613                Err(std::sync::TryLockError::WouldBlock) => {
614                    if std::time::Instant::now() >= deadline {
615                        break;
616                    }
617
618                    // Yield to other threads first
619                    std::thread::yield_now();
620
621                    // Then sleep with exponential backoff if still blocked
622                    if std::time::Instant::now() < deadline {
623                        let remaining = deadline - std::time::Instant::now();
624                        let sleep_duration = StdDuration::from_micros(delay).min(remaining);
625                        std::thread::sleep(sleep_duration);
626                        // Exponential backoff with cap and jitter
627                        delay = ((delay * 2) + delay / 4).min(MAX_DELAY_MICROS);
628                    }
629                }
630                Err(std::sync::TryLockError::Poisoned(e)) => {
631                    anyhow::bail!("symbol_venue_map lock poisoned: {e}");
632                }
633            }
634        }
635
636        anyhow::bail!(
637            "Failed to acquire read lock on symbol_venue_map after {MAX_WAIT_MS}ms deadline"
638        )
639    }
640}
641
642/// Handles Databento error messages by logging them.
643fn handle_error_msg(msg: &dbn::ErrorMsg) {
644    tracing::error!("{msg:?}");
645}
646
647/// Handles Databento system messages, returning a subscription ack event if applicable.
648fn handle_system_msg(msg: &dbn::SystemMsg, ts_received: UnixNanos) -> Option<SubscriptionAckEvent> {
649    match msg.code() {
650        Ok(dbn::SystemCode::SubscriptionAck) => {
651            let message = msg.msg().unwrap_or("<invalid utf-8>");
652            tracing::info!("Subscription acknowledged: {message}");
653
654            let schema = parse_ack_message(message);
655
656            Some(SubscriptionAckEvent {
657                schema,
658                message: message.to_string(),
659                ts_received,
660            })
661        }
662        Ok(dbn::SystemCode::Heartbeat) => {
663            tracing::trace!("Heartbeat received");
664            None
665        }
666        Ok(dbn::SystemCode::SlowReaderWarning) => {
667            let message = msg.msg().unwrap_or("<invalid utf-8>");
668            tracing::warn!("Slow reader warning: {message}");
669            None
670        }
671        Ok(dbn::SystemCode::ReplayCompleted) => {
672            let message = msg.msg().unwrap_or("<invalid utf-8>");
673            tracing::info!("Replay completed: {message}");
674            None
675        }
676        _ => {
677            tracing::debug!("{msg:?}");
678            None
679        }
680    }
681}
682
/// Extracts the schema name from a subscription acknowledgement message.
///
/// Expects the form `"Subscription request N for <schema> data succeeded"` and returns
/// the trimmed `<schema>`. Any message not matching that form yields an empty string.
fn parse_ack_message(message: &str) -> String {
    let Some(after_prefix) = message.strip_prefix("Subscription request ") else {
        return String::new();
    };
    let Some((_request_number, tail)) = after_prefix.split_once(" for ") else {
        return String::new();
    };
    tail.strip_suffix(" data succeeded")
        .map_or_else(String::new, |schema| schema.trim().to_owned())
}
693
694/// Handles symbol mapping messages and updates the instrument ID map.
695///
696/// # Errors
697///
698/// Returns an error if symbol mapping fails.
699fn handle_symbol_mapping_msg(
700    msg: &dbn::SymbolMappingMsg,
701    symbol_map: &mut PitSymbolMap,
702    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
703) -> anyhow::Result<()> {
704    symbol_map
705        .on_symbol_mapping(msg)
706        .map_err(|e| anyhow::anyhow!("on_symbol_mapping failed for {msg:?}: {e}"))?;
707    instrument_id_map.remove(&msg.header().instrument_id);
708    Ok(())
709}
710
711/// Updates the instrument ID map using exchange information from the symbol map.
712fn update_instrument_id_map_with_exchange(
713    symbol_map: &PitSymbolMap,
714    symbol_venue_map: &RwLock<AHashMap<Symbol, Venue>>,
715    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
716    raw_instrument_id: u32,
717    exchange: &str,
718) -> anyhow::Result<InstrumentId> {
719    let raw_symbol = symbol_map.get(raw_instrument_id).ok_or_else(|| {
720        anyhow::anyhow!("Cannot resolve raw_symbol for instrument_id {raw_instrument_id}")
721    })?;
722    let symbol = Symbol::from(raw_symbol.as_str());
723    let venue = Venue::from_code(exchange)
724        .map_err(|e| anyhow::anyhow!("Invalid venue code '{exchange}': {e}"))?;
725    let instrument_id = InstrumentId::new(symbol, venue);
726    let mut map = symbol_venue_map
727        .write()
728        .map_err(|e| anyhow::anyhow!("symbol_venue_map lock poisoned: {e}"))?;
729    map.entry(symbol).or_insert(venue);
730    instrument_id_map.insert(raw_instrument_id, instrument_id);
731    Ok(instrument_id)
732}
733
734fn update_instrument_id_map(
735    record: &dbn::RecordRef,
736    symbol_map: &PitSymbolMap,
737    publisher_venue_map: &IndexMap<PublisherId, Venue>,
738    symbol_venue_map: &AHashMap<Symbol, Venue>,
739    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
740) -> anyhow::Result<InstrumentId> {
741    let header = record.header();
742
743    // Check if instrument ID is already in the map
744    if let Some(&instrument_id) = instrument_id_map.get(&header.instrument_id) {
745        return Ok(instrument_id);
746    }
747
748    let raw_symbol = symbol_map.get_for_rec(record).ok_or_else(|| {
749        anyhow::anyhow!(
750            "Cannot resolve `raw_symbol` from `symbol_map` for instrument_id {}",
751            header.instrument_id
752        )
753    })?;
754
755    let symbol = Symbol::from_str_unchecked(raw_symbol);
756
757    let publisher_id = header.publisher_id;
758    let venue = if let Some(venue) = symbol_venue_map.get(&symbol) {
759        *venue
760    } else {
761        let venue = publisher_venue_map
762            .get(&publisher_id)
763            .ok_or_else(|| anyhow::anyhow!("No venue found for `publisher_id` {publisher_id}"))?;
764        *venue
765    };
766    let instrument_id = InstrumentId::new(symbol, venue);
767
768    instrument_id_map.insert(header.instrument_id, instrument_id);
769    Ok(instrument_id)
770}
771
772/// Handles instrument definition messages and decodes them into Nautilus instruments.
773///
774/// # Errors
775///
776/// Returns an error if instrument decoding fails.
777fn handle_instrument_def_msg(
778    msg: &dbn::InstrumentDefMsg,
779    record: &dbn::RecordRef,
780    symbol_map: &PitSymbolMap,
781    publisher_venue_map: &IndexMap<PublisherId, Venue>,
782    symbol_venue_map: &AHashMap<Symbol, Venue>,
783    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
784    ts_init: UnixNanos,
785) -> anyhow::Result<InstrumentAny> {
786    let instrument_id = update_instrument_id_map(
787        record,
788        symbol_map,
789        publisher_venue_map,
790        symbol_venue_map,
791        instrument_id_map,
792    )?;
793
794    decode_instrument_def_msg(msg, instrument_id, Some(ts_init))
795}
796
797fn handle_status_msg(
798    msg: &dbn::StatusMsg,
799    record: &dbn::RecordRef,
800    symbol_map: &PitSymbolMap,
801    publisher_venue_map: &IndexMap<PublisherId, Venue>,
802    symbol_venue_map: &AHashMap<Symbol, Venue>,
803    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
804    ts_init: UnixNanos,
805) -> anyhow::Result<InstrumentStatus> {
806    let instrument_id = update_instrument_id_map(
807        record,
808        symbol_map,
809        publisher_venue_map,
810        symbol_venue_map,
811        instrument_id_map,
812    )?;
813
814    decode_status_msg(msg, instrument_id, Some(ts_init))
815}
816
817fn handle_imbalance_msg(
818    msg: &dbn::ImbalanceMsg,
819    record: &dbn::RecordRef,
820    symbol_map: &PitSymbolMap,
821    publisher_venue_map: &IndexMap<PublisherId, Venue>,
822    symbol_venue_map: &AHashMap<Symbol, Venue>,
823    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
824    ts_init: UnixNanos,
825) -> anyhow::Result<DatabentoImbalance> {
826    let instrument_id = update_instrument_id_map(
827        record,
828        symbol_map,
829        publisher_venue_map,
830        symbol_venue_map,
831        instrument_id_map,
832    )?;
833
834    let price_precision = 2; // Hardcoded for now
835
836    decode_imbalance_msg(msg, instrument_id, price_precision, Some(ts_init))
837}
838
839fn handle_statistics_msg(
840    msg: &dbn::StatMsg,
841    record: &dbn::RecordRef,
842    symbol_map: &PitSymbolMap,
843    publisher_venue_map: &IndexMap<PublisherId, Venue>,
844    symbol_venue_map: &AHashMap<Symbol, Venue>,
845    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
846    ts_init: UnixNanos,
847) -> anyhow::Result<DatabentoStatistics> {
848    let instrument_id = update_instrument_id_map(
849        record,
850        symbol_map,
851        publisher_venue_map,
852        symbol_venue_map,
853        instrument_id_map,
854    )?;
855
856    let price_precision = 2; // Hardcoded for now
857
858    decode_statistics_msg(msg, instrument_id, price_precision, Some(ts_init))
859}
860
861#[allow(clippy::too_many_arguments)]
862fn handle_record(
863    record: dbn::RecordRef,
864    symbol_map: &PitSymbolMap,
865    publisher_venue_map: &IndexMap<PublisherId, Venue>,
866    symbol_venue_map: &AHashMap<Symbol, Venue>,
867    instrument_id_map: &mut AHashMap<u32, InstrumentId>,
868    ts_init: UnixNanos,
869    initialized_books: &HashSet<InstrumentId>,
870    bars_timestamp_on_close: bool,
871) -> anyhow::Result<(Option<Data>, Option<Data>)> {
872    let instrument_id = update_instrument_id_map(
873        &record,
874        symbol_map,
875        publisher_venue_map,
876        symbol_venue_map,
877        instrument_id_map,
878    )?;
879
880    let price_precision = 2; // Hardcoded for now
881
882    // For MBP-1 and quote-based schemas, always include trades since they're integral to the data
883    // For MBO, only include trades after the book is initialized to maintain consistency
884    let include_trades = if record.get::<dbn::Mbp1Msg>().is_some()
885        || record.get::<dbn::TbboMsg>().is_some()
886        || record.get::<dbn::Cmbp1Msg>().is_some()
887    {
888        true // These schemas include trade information directly
889    } else {
890        initialized_books.contains(&instrument_id) // MBO requires initialized book
891    };
892
893    decode_record(
894        &record,
895        instrument_id,
896        price_precision,
897        Some(ts_init),
898        include_trades,
899        bars_timestamp_on_close,
900    )
901}
902
#[cfg(test)]
mod tests {
    use databento::live::Subscription;
    use indexmap::IndexMap;
    use rstest::*;
    use time::macros::datetime;

    use super::*;

    /// Builds a `DatabentoFeedHandler` for the `GLBX.MDP3` dataset with empty venue maps.
    ///
    /// The command sender and message receiver are local bindings that are dropped when this
    /// function returns; the tests below only inspect the handler's initial state.
    fn create_test_handler(reconnect_timeout_mins: Option<u64>) -> DatabentoFeedHandler {
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (msg_tx, _msg_rx) = tokio::sync::mpsc::channel(100);
        let publisher_venue_map = IndexMap::new();
        let symbol_venue_map = Arc::new(RwLock::new(AHashMap::new()));

        DatabentoFeedHandler::new(
            String::from("test_key"),
            String::from("GLBX.MDP3"),
            cmd_rx,
            msg_tx,
            publisher_venue_map,
            symbol_venue_map,
            false,
            false,
            reconnect_timeout_mins,
        )
    }

    /// The reconnect timeout is stored as given, and no subscriptions or buffered commands
    /// exist yet.
    #[rstest]
    #[case(Some(10))]
    #[case(None)]
    fn test_backoff_initialization(#[case] reconnect_timeout_mins: Option<u64>) {
        let feed_handler = create_test_handler(reconnect_timeout_mins);

        assert_eq!(feed_handler.reconnect_timeout_mins, reconnect_timeout_mins);
        assert!(feed_handler.subscriptions.is_empty());
        assert!(feed_handler.buffered_commands.is_empty());
    }

    /// Clearing `start` on a cloned subscription leaves its schema and symbols untouched.
    #[rstest]
    fn test_subscription_with_and_without_start() {
        let with_start = Subscription::builder()
            .symbols("ES.FUT")
            .schema(databento::dbn::Schema::Mbp1)
            .start(datetime!(2024-01-01 00:00:00 UTC))
            .build();

        let mut without_start = with_start.clone();
        without_start.start = None;

        assert!(with_start.start.is_some());
        assert!(without_start.start.is_none());
        assert_eq!(with_start.schema, without_start.schema);
        assert_eq!(with_start.symbols, without_start.symbols);
    }

    /// A freshly built handler is not in replay mode and keeps its key and dataset.
    #[rstest]
    fn test_handler_initialization_state() {
        let feed_handler = create_test_handler(Some(10));

        assert!(!feed_handler.replay);
        assert_eq!(feed_handler.dataset, "GLBX.MDP3");
        assert_eq!(feed_handler.key, "test_key");
        assert!(feed_handler.subscriptions.is_empty());
        assert!(feed_handler.buffered_commands.is_empty());
    }

    /// Passing no timeout leaves `reconnect_timeout_mins` unset.
    #[rstest]
    fn test_handler_with_no_timeout() {
        let feed_handler = create_test_handler(None);

        assert!(feed_handler.reconnect_timeout_mins.is_none());
        assert!(!feed_handler.replay);
    }

    /// A zero timeout is stored verbatim rather than being treated as "no timeout".
    #[rstest]
    fn test_handler_with_zero_timeout() {
        let feed_handler = create_test_handler(Some(0));

        assert_eq!(feed_handler.reconnect_timeout_mins, Some(0));
        assert!(!feed_handler.replay);
    }
}