use std::{collections::HashMap, ffi::CStr};
use databento::{
dbn,
dbn::{PitSymbolMap, Record, SymbolIndex, VersionUpgradePolicy},
live::Subscription,
};
use indexmap::IndexMap;
use nautilus_core::{
python::{to_pyruntime_err, to_pyvalue_err},
time::{get_atomic_clock_realtime, AtomicTime},
};
use nautilus_model::{
data::{
delta::OrderBookDelta,
deltas::{OrderBookDeltas, OrderBookDeltas_API},
status::InstrumentStatus,
Data,
},
enums::RecordFlag,
identifiers::{InstrumentId, Symbol, Venue},
instruments::any::InstrumentAny,
};
use tokio::{
sync::mpsc::error::TryRecvError,
time::{timeout, Duration},
};
use super::{
decode::{decode_imbalance_msg, decode_statistics_msg, decode_status_msg},
types::{DatabentoImbalance, DatabentoStatistics},
};
use crate::{
decode::{decode_instrument_def_msg, decode_record},
types::PublisherId,
};
/// Commands sent to a [`DatabentoFeedHandler`] over its command channel.
#[derive(Debug)]
pub enum LiveCommand {
    /// Adds a subscription to the live client.
    ///
    /// A subscription carrying a `start` time switches the handler into replay mode.
    Subscribe(Subscription),
    /// Starts the live session so that records begin streaming.
    Start,
    /// Sends [`LiveMessage::Close`] to the consumer, closes the live client if it was
    /// started, and stops the handler loop.
    Close,
}
/// Messages emitted by a [`DatabentoFeedHandler`] over its message channel.
#[derive(Debug)]
// NOTE(review): presumably some payloads (e.g. `InstrumentAny`) are much larger than the unit
// `Close` variant; the size difference is accepted rather than boxing every message.
#[allow(clippy::large_enum_variant)]
pub enum LiveMessage {
    /// Decoded market data: a single decoded record, or buffered MBO deltas as `Data::Deltas`.
    Data(Data),
    /// A decoded instrument definition.
    Instrument(InstrumentAny),
    /// A decoded instrument status update.
    Status(InstrumentStatus),
    /// A decoded Databento imbalance record.
    Imbalance(DatabentoImbalance),
    /// A decoded Databento statistics record.
    Statistics(DatabentoStatistics),
    /// An error reading from the live client; the handler stops after sending it.
    Error(anyhow::Error),
    /// The handler is shutting down (or failed to connect); no further messages follow.
    Close,
}
/// Drives a Databento live session.
///
/// Reads [`LiveCommand`]s from a command channel, decodes incoming DBN records, and forwards
/// the results as [`LiveMessage`]s on a message channel.
pub struct DatabentoFeedHandler {
    /// API key handed to the live client builder.
    key: String,
    /// Dataset handed to the live client builder.
    dataset: String,
    /// Commands from the owner of this handler.
    cmd_rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
    /// Decoded messages for the consumer.
    msg_tx: tokio::sync::mpsc::Sender<LiveMessage>,
    /// Resolves Databento publisher IDs to venues when building instrument IDs.
    publisher_venue_map: IndexMap<PublisherId, Venue>,
    /// Set once any subscription specifies a start time (replay mode).
    replay: bool,
}
impl DatabentoFeedHandler {
#[must_use]
pub const fn new(
key: String,
dataset: String,
rx: tokio::sync::mpsc::UnboundedReceiver<LiveCommand>,
tx: tokio::sync::mpsc::Sender<LiveMessage>,
publisher_venue_map: IndexMap<PublisherId, Venue>,
) -> Self {
Self {
key,
dataset,
cmd_rx: rx,
msg_tx: tx,
publisher_venue_map,
replay: false,
}
}
pub async fn run(&mut self) -> anyhow::Result<()> {
tracing::debug!("Running feed handler");
let clock = get_atomic_clock_realtime();
let mut symbol_map = PitSymbolMap::new();
let mut instrument_id_map: HashMap<u32, InstrumentId> = HashMap::new();
let mut buffering_start = None;
let mut buffered_deltas: HashMap<InstrumentId, Vec<OrderBookDelta>> = HashMap::new();
let mut deltas_count = 0_u64;
let result = timeout(
Duration::from_secs(5), databento::LiveClient::builder()
.key(self.key.clone())?
.dataset(self.dataset.clone())
.upgrade_policy(VersionUpgradePolicy::Upgrade)
.build(),
)
.await?;
tracing::info!("Connected");
let mut client = if let Ok(client) = result {
client
} else {
self.msg_tx.send(LiveMessage::Close).await?;
self.cmd_rx.close();
return Err(anyhow::anyhow!("Timeout connecting to LSG"));
};
let timeout_duration = Duration::from_millis(10);
let mut running = false;
loop {
if self.msg_tx.is_closed() {
tracing::debug!("Message channel was closed: stopping");
break;
};
match self.cmd_rx.try_recv() {
Ok(cmd) => {
tracing::debug!("Received command: {cmd:?}");
match cmd {
LiveCommand::Subscribe(sub) => {
if !self.replay & sub.start.is_some() {
self.replay = true;
}
client.subscribe(&sub).await.map_err(to_pyruntime_err)?;
}
LiveCommand::Start => {
buffering_start = match self.replay {
true => Some(clock.get_time_ns()),
false => None,
};
client.start().await.map_err(to_pyruntime_err)?;
running = true;
tracing::debug!("Started");
}
LiveCommand::Close => {
self.msg_tx.send(LiveMessage::Close).await?;
if running {
client.close().await.map_err(to_pyruntime_err)?;
tracing::debug!("Closed inner client");
}
break;
}
}
}
Err(TryRecvError::Empty) => {} Err(TryRecvError::Disconnected) => {
tracing::debug!("Disconnected");
break;
}
}
if !running {
continue;
};
let result = timeout(timeout_duration, client.next_record()).await;
let record_opt = match result {
Ok(record_opt) => record_opt,
Err(_) => continue, };
let record = match record_opt {
Ok(Some(record)) => record,
Ok(None) => break, Err(e) => {
self.send_msg(LiveMessage::Error(anyhow::anyhow!(e))).await;
break;
}
};
if let Some(msg) = record.get::<dbn::ErrorMsg>() {
handle_error_msg(msg);
} else if let Some(msg) = record.get::<dbn::SystemMsg>() {
handle_system_msg(msg);
} else if let Some(msg) = record.get::<dbn::SymbolMappingMsg>() {
instrument_id_map.remove(&msg.hd.instrument_id);
handle_symbol_mapping_msg(msg, &mut symbol_map, &mut instrument_id_map);
} else if let Some(msg) = record.get::<dbn::InstrumentDefMsg>() {
let data = handle_instrument_def_msg(msg, &self.publisher_venue_map, clock)?;
self.send_msg(LiveMessage::Instrument(data)).await;
} else if let Some(msg) = record.get::<dbn::StatusMsg>() {
let data = handle_status_msg(
msg,
&record,
&symbol_map,
&self.publisher_venue_map,
&mut instrument_id_map,
clock,
)?;
self.send_msg(LiveMessage::Status(data)).await;
} else if let Some(msg) = record.get::<dbn::ImbalanceMsg>() {
let data = handle_imbalance_msg(
msg,
&record,
&symbol_map,
&self.publisher_venue_map,
&mut instrument_id_map,
clock,
)?;
self.send_msg(LiveMessage::Imbalance(data)).await;
} else if let Some(msg) = record.get::<dbn::StatMsg>() {
let data = handle_statistics_msg(
msg,
&record,
&symbol_map,
&self.publisher_venue_map,
&mut instrument_id_map,
clock,
)?;
self.send_msg(LiveMessage::Statistics(data)).await;
} else {
let (mut data1, data2) = match handle_record(
record,
&symbol_map,
&self.publisher_venue_map,
&mut instrument_id_map,
clock,
) {
Ok(decoded) => decoded,
Err(e) => {
tracing::error!("Error decoding record: {e}");
continue;
}
};
if let Some(msg) = record.get::<dbn::MboMsg>() {
if let Data::Delta(delta) = data1.clone().expect("MBO should decode a delta") {
let buffer = buffered_deltas.entry(delta.instrument_id).or_default();
buffer.push(delta);
deltas_count += 1;
tracing::trace!(
"Buffering delta: {deltas_count} {} {buffering_start:?} flags={}",
delta.ts_event,
msg.flags.raw(),
);
if !RecordFlag::F_LAST.matches(msg.flags.raw()) {
continue; }
if RecordFlag::F_SNAPSHOT.matches(msg.flags.raw()) {
continue; }
if let Some(start_ns) = buffering_start {
if delta.ts_event <= start_ns {
continue; }
buffering_start = None;
}
let buffer = buffered_deltas.remove(&delta.instrument_id).unwrap();
let deltas = OrderBookDeltas::new(delta.instrument_id, buffer);
let deltas = OrderBookDeltas_API::new(deltas);
data1 = Some(Data::Deltas(deltas));
}
};
if let Some(data) = data1 {
self.send_msg(LiveMessage::Data(data)).await;
};
if let Some(data) = data2 {
self.send_msg(LiveMessage::Data(data)).await;
};
}
}
self.cmd_rx.close();
tracing::debug!("Closed command receiver");
Ok(())
}
async fn send_msg(&mut self, msg: LiveMessage) {
tracing::trace!("Sending {msg:?}");
match self.msg_tx.send(msg).await {
Ok(()) => {}
Err(e) => tracing::error!("Error sending message: {e}"),
}
}
}
fn handle_error_msg(msg: &dbn::ErrorMsg) {
tracing::error!("{msg:?}");
}
/// Logs a gateway [`dbn::SystemMsg`] at info level using its `Debug` representation.
fn handle_system_msg(msg: &dbn::SystemMsg) {
    tracing::info!("{:?}", msg);
}
/// Applies a [`dbn::SymbolMappingMsg`] to the point-in-time `symbol_map`.
///
/// Also evicts the cached [`InstrumentId`] for the message's instrument ID, so the next
/// record for it is resolved against the updated mapping.
///
/// # Panics
///
/// Panics if `symbol_map` rejects the mapping. The panic message includes the underlying
/// error.
fn handle_symbol_mapping_msg(
    msg: &dbn::SymbolMappingMsg,
    symbol_map: &mut PitSymbolMap,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
) {
    if let Err(e) = symbol_map.on_symbol_mapping(msg) {
        // Keep the cause: previously the error was discarded with `|_|`
        panic!("Error updating `symbol_map` with {msg:?}: {e}");
    }
    instrument_id_map.remove(&msg.header().instrument_id);
}
/// Returns the [`InstrumentId`] for `record`, resolving and caching it on first use.
///
/// Cache hits are keyed by the header's raw instrument ID. On a miss, the symbol comes from
/// `symbol_map` and the venue from `publisher_venue_map` (keyed by the header's publisher ID).
///
/// # Panics
///
/// Panics if the symbol cannot be resolved from `symbol_map`, or if no venue is mapped for
/// the record's publisher ID.
fn update_instrument_id_map(
    record: &dbn::RecordRef,
    symbol_map: &PitSymbolMap,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
) -> InstrumentId {
    let hd = record.header();
    *instrument_id_map
        .entry(hd.instrument_id)
        .or_insert_with(|| {
            let symbol = Symbol::from_str_unchecked(
                symbol_map
                    .get_for_rec(record)
                    .expect("Cannot resolve `raw_symbol` from `symbol_map`"),
            );
            let publisher_id = hd.publisher_id;
            let venue = publisher_venue_map
                .get(&publisher_id)
                .unwrap_or_else(|| panic!("No venue found for `publisher_id` {publisher_id}"));
            InstrumentId::new(symbol, *venue)
        })
}
/// Decodes a [`dbn::InstrumentDefMsg`] into an [`InstrumentAny`].
///
/// The instrument ID is built from the message's `raw_symbol` and the venue mapped from its
/// publisher ID. `ts_init` is read from `clock` at decode time.
///
/// # Errors
///
/// Returns an error if `raw_symbol` has no NUL terminator within the field or is not valid
/// UTF-8, or if `decode_instrument_def_msg` fails.
///
/// # Panics
///
/// Panics if no venue is mapped for the message's publisher ID.
fn handle_instrument_def_msg(
    msg: &dbn::InstrumentDefMsg,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    clock: &AtomicTime,
) -> anyhow::Result<InstrumentAny> {
    // Bound the NUL scan to the fixed-size field instead of `CStr::from_ptr`, which reads
    // past the end of the array (undefined behavior) if the field is not NUL-terminated.
    let raw_symbol_bytes = msg.raw_symbol.map(|c| c as u8);
    let raw_symbol: &str = CStr::from_bytes_until_nul(&raw_symbol_bytes)
        .map_err(to_pyvalue_err)?
        .to_str()
        .map_err(to_pyvalue_err)?;
    let symbol = Symbol::from(raw_symbol);

    let publisher_id = msg.header().publisher_id;
    let venue = publisher_venue_map
        .get(&publisher_id)
        .unwrap_or_else(|| panic!("No venue found for `publisher_id` {publisher_id}"));
    let instrument_id = InstrumentId::new(symbol, *venue);

    let ts_init = clock.get_time_ns();
    decode_instrument_def_msg(msg, instrument_id, ts_init)
}
/// Decodes a [`dbn::StatusMsg`] into an [`InstrumentStatus`].
///
/// Resolves (and caches) the instrument ID for `record` first, then stamps `ts_init` from
/// `clock`.
///
/// # Errors
///
/// Returns an error if `decode_status_msg` fails.
///
/// # Panics
///
/// Panics if the record's symbol or publisher venue cannot be resolved.
fn handle_status_msg(
    msg: &dbn::StatusMsg,
    record: &dbn::RecordRef,
    symbol_map: &PitSymbolMap,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
    clock: &AtomicTime,
) -> anyhow::Result<InstrumentStatus> {
    let id = update_instrument_id_map(record, symbol_map, publisher_venue_map, instrument_id_map);
    decode_status_msg(msg, id, clock.get_time_ns())
}
/// Decodes a [`dbn::ImbalanceMsg`] into a [`DatabentoImbalance`].
///
/// Resolves (and caches) the instrument ID for `record` first, then stamps `ts_init` from
/// `clock`.
///
/// # Errors
///
/// Returns an error if `decode_imbalance_msg` fails.
///
/// # Panics
///
/// Panics if the record's symbol or publisher venue cannot be resolved.
fn handle_imbalance_msg(
    msg: &dbn::ImbalanceMsg,
    record: &dbn::RecordRef,
    symbol_map: &PitSymbolMap,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
    clock: &AtomicTime,
) -> anyhow::Result<DatabentoImbalance> {
    let id = update_instrument_id_map(record, symbol_map, publisher_venue_map, instrument_id_map);
    // NOTE(review): price precision is hard-coded to 2 rather than taken from the instrument
    decode_imbalance_msg(msg, id, 2, clock.get_time_ns())
}
/// Decodes a [`dbn::StatMsg`] into a [`DatabentoStatistics`].
///
/// Resolves (and caches) the instrument ID for `record` first, then stamps `ts_init` from
/// `clock`.
///
/// # Errors
///
/// Returns an error if `decode_statistics_msg` fails.
///
/// # Panics
///
/// Panics if the record's symbol or publisher venue cannot be resolved.
fn handle_statistics_msg(
    msg: &dbn::StatMsg,
    record: &dbn::RecordRef,
    symbol_map: &PitSymbolMap,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
    clock: &AtomicTime,
) -> anyhow::Result<DatabentoStatistics> {
    let id = update_instrument_id_map(record, symbol_map, publisher_venue_map, instrument_id_map);
    // NOTE(review): price precision is hard-coded to 2 rather than taken from the instrument
    decode_statistics_msg(msg, id, 2, clock.get_time_ns())
}
/// Decodes any remaining record type into up to two [`Data`] values via `decode_record`.
///
/// Resolves (and caches) the instrument ID for `record` first, then stamps `ts_init` from
/// `clock`. NOTE(review): price precision is hard-coded to 2, and the trailing `true` flag
/// is forwarded unchanged (presumably `include_trades` — confirm against `decode_record`).
///
/// # Errors
///
/// Returns an error if `decode_record` fails.
///
/// # Panics
///
/// Panics if the record's symbol or publisher venue cannot be resolved.
fn handle_record(
    record: dbn::RecordRef,
    symbol_map: &PitSymbolMap,
    publisher_venue_map: &IndexMap<PublisherId, Venue>,
    instrument_id_map: &mut HashMap<u32, InstrumentId>,
    clock: &AtomicTime,
) -> anyhow::Result<(Option<Data>, Option<Data>)> {
    let id = update_instrument_id_map(&record, symbol_map, publisher_venue_map, instrument_id_map);
    let ts_init = Some(clock.get_time_ns());
    decode_record(&record, id, 2, ts_init, true)
}