use std::{fs, num::NonZeroU64, path::PathBuf, str::FromStr, sync::Arc};
use databento::{
dbn::{self, SType},
historical::timeseries::GetRangeParams,
};
use indexmap::IndexMap;
use nautilus_core::{
python::to_pyvalue_err,
time::{get_atomic_clock_realtime, AtomicTime},
};
use nautilus_model::{
data::{bar::Bar, quote::QuoteTick, status::InstrumentStatus, trade::TradeTick, Data},
enums::BarAggregation,
identifiers::{InstrumentId, Symbol, Venue},
python::instruments::instrument_any_to_pyobject,
types::currency::Currency,
};
use pyo3::{
exceptions::PyException,
prelude::*,
types::{PyDict, PyList},
};
use tokio::sync::Mutex;
use crate::{
common::get_date_time_range,
decode::{
decode_imbalance_msg, decode_instrument_def_msg, decode_record, decode_statistics_msg,
decode_status_msg,
},
symbology::{check_consistent_symbology, decode_nautilus_instrument_id, infer_symbology_type},
types::{DatabentoImbalance, DatabentoPublisher, DatabentoStatistics, PublisherId},
};
/// Client for requesting historical data through the `databento` crate,
/// exposed to Python as a `pyclass` when the `python` feature is enabled.
///
/// Each request method returns a Python awaitable; the underlying
/// `databento::HistoricalClient` is shared behind an async mutex so the
/// spawned futures can take turns using it.
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.databento")
)]
pub struct DatabentoHistoricalClient {
/// Key handed to the `databento` client builder; readable from Python.
#[pyo3(get)]
pub key: String,
// Real-time clock used for `ts_init` stamps and as the default request `end`.
clock: &'static AtomicTime,
// Shared client; locked by each request future for the duration of the request.
inner: Arc<Mutex<databento::HistoricalClient>>,
// Publisher ID -> venue lookup, loaded from the publishers file at construction.
publisher_venue_map: Arc<IndexMap<PublisherId, Venue>>,
}
#[pymethods]
impl DatabentoHistoricalClient {
#[new]
fn py_new(key: String, publishers_filepath: PathBuf) -> PyResult<Self> {
let client = databento::HistoricalClient::builder()
.key(key.clone())
.map_err(to_pyvalue_err)?
.build()
.map_err(to_pyvalue_err)?;
let file_content = fs::read_to_string(publishers_filepath)?;
let publishers_vec: Vec<DatabentoPublisher> =
serde_json::from_str(&file_content).map_err(to_pyvalue_err)?;
let publisher_venue_map = publishers_vec
.into_iter()
.map(|p| (p.publisher_id, Venue::from(p.venue.as_str())))
.collect::<IndexMap<u16, Venue>>();
Ok(Self {
clock: get_atomic_clock_realtime(),
inner: Arc::new(Mutex::new(client)),
publisher_venue_map: Arc::new(publisher_venue_map),
key,
})
}
/// Requests the available date range for `dataset`.
///
/// Returns a Python awaitable that resolves to a dict with the string keys
/// `"start"` and `"end"`, each holding the string form of the corresponding
/// value from the metadata response.
///
/// # Errors
///
/// The awaitable raises a generic Python `Exception` if the metadata request fails.
#[pyo3(name = "get_dataset_range")]
fn py_get_dataset_range<'py>(
    &self,
    py: Python<'py>,
    dataset: String,
) -> PyResult<Bound<'py, PyAny>> {
    let shared_client = self.inner.clone();

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        // The guard stays alive until the end of the future, as before.
        let mut guard = shared_client.lock().await;

        let range = match guard.metadata().get_dataset_range(&dataset).await {
            Ok(range) => range,
            Err(e) => {
                return Err(PyErr::new::<PyException, _>(format!(
                    "Error handling response: {e}"
                )))
            }
        };

        Python::with_gil(|py| {
            let dict = PyDict::new_bound(py);
            dict.set_item("start", range.start.to_string())?;
            dict.set_item("end", range.end.to_string())?;
            Ok(dict.to_object(py))
        })
    })
}
#[pyo3(name = "get_range_instruments")]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None))]
fn py_get_range_instruments<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn::Schema::Definition)
.limit(limit.and_then(NonZeroU64::new))
.build();
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
decoder.set_upgrade_policy(dbn::VersionUpgradePolicy::Upgrade);
let mut instruments = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::InstrumentDefMsg>().await {
let raw_symbol = msg.raw_symbol().expect("Error decoding `raw_symbol`");
let symbol = Symbol::from(raw_symbol);
let publisher = msg.hd.publisher().expect("Invalid `publisher` for record");
let venue = publisher_venue_map
.get(&msg.hd.publisher_id)
.unwrap_or_else(|| panic!("`Venue` not found for `publisher` {publisher}"));
let instrument_id = InstrumentId::new(symbol, *venue);
let result = decode_instrument_def_msg(msg, instrument_id, ts_init);
match result {
Ok(instrument) => instruments.push(instrument),
Err(e) => tracing::error!("{e:?}"),
};
}
Python::with_gil(|py| {
let py_results: PyResult<Vec<PyObject>> = instruments
.into_iter()
.map(|result| instrument_any_to_pyobject(py, result))
.collect();
py_results.map(|objs| PyList::new_bound(py, &objs).to_object(py))
})
})
}
#[pyo3(name = "get_range_quotes")]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None, price_precision=None, schema=None))]
#[allow(clippy::too_many_arguments)]
fn py_get_range_quotes<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
price_precision: Option<u8>,
schema: Option<String>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let schema = schema.unwrap_or_else(|| "mbp-1".to_string());
let dbn_schema = dbn::Schema::from_str(&schema).map_err(to_pyvalue_err)?;
match dbn_schema {
dbn::Schema::Mbp1 | dbn::Schema::Bbo1S | dbn::Schema::Bbo1M => (),
_ => {
return Err(to_pyvalue_err(
"Invalid schema. Must be one of: mbp-1, bbo-1s, bbo-1m",
))
}
};
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn_schema)
.limit(limit.and_then(NonZeroU64::new))
.build();
let price_precision = price_precision.unwrap_or(Currency::USD().precision);
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<QuoteTick> = Vec::new();
let mut process_record = |record: dbn::RecordRef| -> PyResult<()> {
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let (data, _) = decode_record(
&record,
instrument_id,
price_precision,
Some(ts_init),
false, )
.map_err(to_pyvalue_err)?;
match data {
Some(Data::Quote(quote)) => {
result.push(quote);
Ok(())
}
_ => panic!("Invalid data element not `QuoteTick`, was {data:?}"),
}
};
match dbn_schema {
dbn::Schema::Mbp1 => {
while let Ok(Some(msg)) = decoder.decode_record::<dbn::Mbp1Msg>().await {
process_record(dbn::RecordRef::from(msg))?;
}
}
dbn::Schema::Bbo1M => {
while let Ok(Some(msg)) = decoder.decode_record::<dbn::Bbo1MMsg>().await {
process_record(dbn::RecordRef::from(msg))?;
}
}
dbn::Schema::Bbo1S => {
while let Ok(Some(msg)) = decoder.decode_record::<dbn::Bbo1SMsg>().await {
process_record(dbn::RecordRef::from(msg))?;
}
}
_ => panic!("Invalid schema {dbn_schema}"),
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
#[pyo3(name = "get_range_trades")]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None, price_precision=None))]
#[allow(clippy::too_many_arguments)]
fn py_get_range_trades<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
price_precision: Option<u8>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn::Schema::Trades)
.limit(limit.and_then(NonZeroU64::new))
.build();
let price_precision = price_precision.unwrap_or(Currency::USD().precision);
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<TradeTick> = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::TradeMsg>().await {
let record = dbn::RecordRef::from(msg);
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let (data, _) = decode_record(
&record,
instrument_id,
price_precision,
Some(ts_init),
false, )
.map_err(to_pyvalue_err)?;
match data {
Some(Data::Trade(trade)) => {
result.push(trade);
}
_ => panic!("Invalid data element not `TradeTick`, was {data:?}"),
}
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
#[pyo3(name = "get_range_bars")]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (dataset, symbols, aggregation, start, end=None, limit=None, price_precision=None))]
fn py_get_range_bars<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
aggregation: BarAggregation,
start: u64,
end: Option<u64>,
limit: Option<u64>,
price_precision: Option<u8>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let schema = match aggregation {
BarAggregation::Second => dbn::Schema::Ohlcv1S,
BarAggregation::Minute => dbn::Schema::Ohlcv1M,
BarAggregation::Hour => dbn::Schema::Ohlcv1H,
BarAggregation::Day => dbn::Schema::Ohlcv1D,
_ => panic!("Invalid `BarAggregation` for request, was {aggregation}"),
};
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(schema)
.limit(limit.and_then(NonZeroU64::new))
.build();
let price_precision = price_precision.unwrap_or(Currency::USD().precision);
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<Bar> = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::OhlcvMsg>().await {
let record = dbn::RecordRef::from(msg);
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let (data, _) = decode_record(
&record,
instrument_id,
price_precision,
Some(ts_init),
false, )
.map_err(to_pyvalue_err)?;
match data {
Some(Data::Bar(bar)) => {
result.push(bar);
}
_ => panic!("Invalid data element not `Bar`, was {data:?}"),
}
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
#[pyo3(name = "get_range_imbalance")]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None, price_precision=None))]
fn py_get_range_imbalance<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
price_precision: Option<u8>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn::Schema::Imbalance)
.limit(limit.and_then(NonZeroU64::new))
.build();
let price_precision = price_precision.unwrap_or(Currency::USD().precision);
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<DatabentoImbalance> = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::ImbalanceMsg>().await {
let record = dbn::RecordRef::from(msg);
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let imbalance = decode_imbalance_msg(msg, instrument_id, price_precision, ts_init)
.map_err(to_pyvalue_err)?;
result.push(imbalance);
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
#[pyo3(name = "get_range_statistics")]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None, price_precision=None))]
fn py_get_range_statistics<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
price_precision: Option<u8>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn::Schema::Statistics)
.limit(limit.and_then(NonZeroU64::new))
.build();
let price_precision = price_precision.unwrap_or(Currency::USD().precision);
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<DatabentoStatistics> = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::StatMsg>().await {
let record = dbn::RecordRef::from(msg);
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let statistics =
decode_statistics_msg(msg, instrument_id, price_precision, ts_init)
.map_err(to_pyvalue_err)?;
result.push(statistics);
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
#[pyo3(name = "get_range_status")]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (dataset, symbols, start, end=None, limit=None))]
fn py_get_range_status<'py>(
&self,
py: Python<'py>,
dataset: String,
symbols: Vec<String>,
start: u64,
end: Option<u64>,
limit: Option<u64>,
) -> PyResult<Bound<'py, PyAny>> {
let client = self.inner.clone();
let stype_in = infer_symbology_type(symbols.first().unwrap());
let symbols: Vec<&str> = symbols.iter().map(std::string::String::as_str).collect();
check_consistent_symbology(symbols.as_slice()).map_err(to_pyvalue_err)?;
let end = end.unwrap_or(self.clock.get_time_ns().as_u64());
let time_range = get_date_time_range(start.into(), end.into()).map_err(to_pyvalue_err)?;
let params = GetRangeParams::builder()
.dataset(dataset)
.date_time_range(time_range)
.symbols(symbols)
.stype_in(SType::from_str(&stype_in).map_err(to_pyvalue_err)?)
.schema(dbn::Schema::Status)
.limit(limit.and_then(NonZeroU64::new))
.build();
let publisher_venue_map = self.publisher_venue_map.clone();
let ts_init = self.clock.get_time_ns();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut client = client.lock().await; let mut decoder = client
.timeseries()
.get_range(¶ms)
.await
.map_err(to_pyvalue_err)?;
let metadata = decoder.metadata().clone();
let mut result: Vec<InstrumentStatus> = Vec::new();
while let Ok(Some(msg)) = decoder.decode_record::<dbn::StatusMsg>().await {
let record = dbn::RecordRef::from(msg);
let instrument_id =
decode_nautilus_instrument_id(&record, &metadata, &publisher_venue_map)
.map_err(to_pyvalue_err)?;
let status =
decode_status_msg(msg, instrument_id, ts_init).map_err(to_pyvalue_err)?;
result.push(status);
}
Python::with_gil(|py| Ok(result.into_py(py)))
})
}
}