nautilus_serialization/arrow/mod.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! Defines the Apache Arrow schema for Nautilus types.

pub mod bar;
pub mod delta;
pub mod depth;
pub mod quote;
pub mod trade;

use std::{
    collections::HashMap,
    io::{self, Write},
};

use arrow::{
    array::{Array, ArrayRef},
    datatypes::{DataType, Schema},
    error::ArrowError,
    ipc::writer::StreamWriter,
    record_batch::RecordBatch,
};
use nautilus_model::{
    data::{
        Data, bar::Bar, delta::OrderBookDelta, depth::OrderBookDepth10, quote::QuoteTick,
        trade::TradeTick,
    },
    types::{price::PriceRaw, quantity::QuantityRaw},
};
use pyo3::prelude::*;

// Define metadata key constants
const KEY_BAR_TYPE: &str = "bar_type";
pub const KEY_INSTRUMENT_ID: &str = "instrument_id";
const KEY_PRICE_PRECISION: &str = "price_precision";
const KEY_SIZE_PRECISION: &str = "size_precision";

/// Errors that can occur when streaming Arrow record batches.
#[derive(thiserror::Error, Debug)]
pub enum DataStreamingError {
    #[error("Arrow error: {0}")]
    ArrowError(#[from] arrow::error::ArrowError),
    #[error("I/O error: {0}")]
    IoError(#[from] io::Error),
    #[error("Python error: {0}")]
    PythonError(#[from] PyErr),
}

/// Errors that can occur when encoding or decoding Arrow record batches.
#[derive(thiserror::Error, Debug)]
pub enum EncodingError {
    #[error("Empty data")]
    EmptyData,
    #[error("Missing metadata key: `{0}`")]
    MissingMetadata(&'static str),
    #[error("Missing data column: `{0}` at index {1}")]
    MissingColumn(&'static str, usize),
    #[error("Error parsing `{0}`: {1}")]
    ParseError(&'static str, String),
    #[error("Invalid column type `{0}` at index {1}: expected {2}, found {3}")]
    InvalidColumnType(&'static str, usize, DataType, DataType),
    #[error("Arrow error: {0}")]
    ArrowError(#[from] arrow::error::ArrowError),
}

/// Decodes a `PriceRaw` from little-endian bytes.
///
/// # Panics
///
/// Panics if `bytes` is not exactly `size_of::<PriceRaw>()` bytes long.
#[inline]
fn get_raw_price(bytes: &[u8]) -> PriceRaw {
    PriceRaw::from_le_bytes(bytes.try_into().unwrap())
}

/// Decodes a `QuantityRaw` from little-endian bytes.
///
/// # Panics
///
/// Panics if `bytes` is not exactly `size_of::<QuantityRaw>()` bytes long.
#[inline]
fn get_raw_quantity(bytes: &[u8]) -> QuantityRaw {
    QuantityRaw::from_le_bytes(bytes.try_into().unwrap())
}

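// A minimal round-trip sketch (test-only) for the raw decoders: a raw value
// encoded to little-endian bytes should decode back unchanged, whether the raw
// types are 64-bit or 128-bit (high-precision builds).
#[cfg(test)]
mod raw_decoding_example {
    use super::*;

    #[test]
    fn raw_values_round_trip_through_le_bytes() {
        let price_raw: PriceRaw = 1_000_000_000;
        assert_eq!(get_raw_price(&price_raw.to_le_bytes()), price_raw);

        let qty_raw: QuantityRaw = 42;
        assert_eq!(get_raw_quantity(&qty_raw.to_le_bytes()), qty_raw);
    }
}
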
/// Provides an Arrow schema, optionally parameterized by metadata.
pub trait ArrowSchemaProvider {
    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema;

    /// Returns a map of field names to their Arrow data types (as debug-formatted strings).
    #[must_use]
    fn get_schema_map() -> HashMap<String, String> {
        let schema = Self::get_schema(None);
        let mut map = HashMap::new();
        for field in schema.fields() {
            let name = field.name().to_string();
            let data_type = format!("{:?}", field.data_type());
            map.insert(name, data_type);
        }
        map
    }
}

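// A minimal usage sketch (test-only): all Nautilus data schemas carry the
// `ts_event` and `ts_init` timestamp fields, so the derived schema map should
// contain them. `QuoteTick` is used here purely for illustration.
#[cfg(test)]
mod schema_map_example {
    use super::*;

    #[test]
    fn quote_schema_map_has_timestamp_fields() {
        let map = QuoteTick::get_schema_map();
        assert!(map.contains_key("ts_event"));
        assert!(map.contains_key("ts_init"));
    }
}
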
/// Encodes a chunk of elements into an Arrow [`RecordBatch`].
pub trait EncodeToRecordBatch
where
    Self: Sized + ArrowSchemaProvider,
{
    fn encode_batch(
        metadata: &HashMap<String, String>,
        data: &[Self],
    ) -> Result<RecordBatch, ArrowError>;

    /// Returns the schema metadata for this element.
    fn metadata(&self) -> HashMap<String, String>;

    /// Derives the chunk's metadata from its first element.
    ///
    /// # Panics
    ///
    /// Panics if `chunk` is empty.
    fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
        chunk
            .first()
            .map(|elem| elem.metadata())
            .expect("Chunk must have at least one element to encode")
    }
}

/// Decodes elements of the implementing type from an Arrow [`RecordBatch`].
pub trait DecodeFromRecordBatch
where
    Self: Sized + Into<Data> + ArrowSchemaProvider,
{
    fn decode_batch(
        metadata: &HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Self>, EncodingError>;
}

/// Decodes [`Data`] variants directly from an Arrow [`RecordBatch`].
pub trait DecodeDataFromRecordBatch
where
    Self: Sized + Into<Data> + ArrowSchemaProvider,
{
    fn decode_data_batch(
        metadata: &HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Data>, EncodingError>;
}

/// Writes Arrow record batches to an underlying output stream.
pub trait WriteStream {
    fn write(&mut self, record_batch: &RecordBatch) -> Result<(), DataStreamingError>;
}

impl<T: EncodeToRecordBatch + Write> WriteStream for T {
    fn write(&mut self, record_batch: &RecordBatch) -> Result<(), DataStreamingError> {
        // Each call emits a complete, self-describing Arrow IPC stream
        let mut writer = StreamWriter::try_new(self, &record_batch.schema())?;
        writer.write(record_batch)?;
        writer.finish()?;
        Ok(())
    }
}

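// A minimal sketch (test-only) of what `WriteStream::write` produces: an Arrow
// IPC stream written into an in-memory buffer. The single-column schema here
// is hypothetical, for illustration only.
#[cfg(test)]
mod write_stream_example {
    use std::sync::Arc;

    use arrow::{
        array::UInt64Array,
        datatypes::{Field, Schema},
    };

    use super::*;

    #[test]
    fn writes_ipc_stream_to_buffer() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts_event",
            DataType::UInt64,
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2, 3]))],
        )
        .unwrap();

        let mut buf = Vec::new();
        let mut writer = StreamWriter::try_new(&mut buf, &schema).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        assert!(!buf.is_empty());
    }
}
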
/// Extracts and downcasts the column at `column_index`, validating its data type.
pub fn extract_column<'a, T: Array + 'static>(
    cols: &'a [ArrayRef],
    column_key: &'static str,
    column_index: usize,
    expected_type: DataType,
) -> Result<&'a T, EncodingError> {
    let column_values = cols
        .get(column_index)
        .ok_or(EncodingError::MissingColumn(column_key, column_index))?;
    let downcasted_values =
        column_values
            .as_any()
            .downcast_ref::<T>()
            .ok_or(EncodingError::InvalidColumnType(
                column_key,
                column_index,
                expected_type,
                column_values.data_type().clone(),
            ))?;
    Ok(downcasted_values)
}

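// A minimal sketch (test-only) of `extract_column` behavior: a successful
// downcast, a missing column, and a type mismatch. The column names here are
// illustrative only.
#[cfg(test)]
mod extract_column_example {
    use std::sync::Arc;

    use arrow::array::{Int64Array, UInt64Array};

    use super::*;

    #[test]
    fn extracts_and_validates_columns() {
        let cols: Vec<ArrayRef> = vec![Arc::new(UInt64Array::from(vec![1_u64, 2, 3]))];

        // Matching array type downcasts successfully
        let col = extract_column::<UInt64Array>(&cols, "ts_event", 0, DataType::UInt64).unwrap();
        assert_eq!(col.len(), 3);

        // An out-of-bounds index reports a missing column
        assert!(matches!(
            extract_column::<UInt64Array>(&cols, "ts_init", 1, DataType::UInt64),
            Err(EncodingError::MissingColumn("ts_init", 1))
        ));

        // A mismatched array type reports an invalid column type
        assert!(matches!(
            extract_column::<Int64Array>(&cols, "ts_event", 0, DataType::Int64),
            Err(EncodingError::InvalidColumnType(..))
        ));
    }
}
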
/// Encodes a vector of [`OrderBookDelta`] into an Arrow [`RecordBatch`].
pub fn order_book_deltas_to_arrow_record_batch_bytes(
    data: Vec<OrderBookDelta>,
) -> Result<RecordBatch, EncodingError> {
    if data.is_empty() {
        return Err(EncodingError::EmptyData);
    }

    // Extract metadata from the chunk
    let metadata = OrderBookDelta::chunk_metadata(&data);
    OrderBookDelta::encode_batch(&metadata, &data).map_err(EncodingError::ArrowError)
}

/// Encodes a vector of [`OrderBookDepth10`] into an Arrow [`RecordBatch`].
pub fn order_book_depth10_to_arrow_record_batch_bytes(
    data: Vec<OrderBookDepth10>,
) -> Result<RecordBatch, EncodingError> {
    if data.is_empty() {
        return Err(EncodingError::EmptyData);
    }

    // Take the first element and extract its metadata
    // SAFETY: Unwrap is safe as `data` was checked to be non-empty above
    let first = data.first().unwrap();
    let metadata = first.metadata();
    OrderBookDepth10::encode_batch(&metadata, &data).map_err(EncodingError::ArrowError)
}

/// Encodes a vector of [`QuoteTick`] into an Arrow [`RecordBatch`].
pub fn quote_ticks_to_arrow_record_batch_bytes(
    data: Vec<QuoteTick>,
) -> Result<RecordBatch, EncodingError> {
    if data.is_empty() {
        return Err(EncodingError::EmptyData);
    }

    // Take the first element and extract its metadata
    // SAFETY: Unwrap is safe as `data` was checked to be non-empty above
    let first = data.first().unwrap();
    let metadata = first.metadata();
    QuoteTick::encode_batch(&metadata, &data).map_err(EncodingError::ArrowError)
}

/// Encodes a vector of [`TradeTick`] into an Arrow [`RecordBatch`].
pub fn trade_ticks_to_arrow_record_batch_bytes(
    data: Vec<TradeTick>,
) -> Result<RecordBatch, EncodingError> {
    if data.is_empty() {
        return Err(EncodingError::EmptyData);
    }

    // Take the first element and extract its metadata
    // SAFETY: Unwrap is safe as `data` was checked to be non-empty above
    let first = data.first().unwrap();
    let metadata = first.metadata();
    TradeTick::encode_batch(&metadata, &data).map_err(EncodingError::ArrowError)
}

/// Encodes a vector of [`Bar`] into an Arrow [`RecordBatch`].
pub fn bars_to_arrow_record_batch_bytes(data: Vec<Bar>) -> Result<RecordBatch, EncodingError> {
    if data.is_empty() {
        return Err(EncodingError::EmptyData);
    }

    // Take the first element and extract its metadata
    // SAFETY: Unwrap is safe as `data` was checked to be non-empty above
    let first = data.first().unwrap();
    let metadata = first.metadata();
    Bar::encode_batch(&metadata, &data).map_err(EncodingError::ArrowError)
}
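
// A minimal sketch (test-only) of the empty-input guard shared by the
// conversion functions above: encoding an empty vector fails fast with
// `EncodingError::EmptyData` rather than producing an empty record batch.
#[cfg(test)]
mod empty_data_guard_example {
    use super::*;

    #[test]
    fn empty_input_is_rejected() {
        assert!(matches!(
            quote_ticks_to_arrow_record_batch_bytes(vec![]),
            Err(EncodingError::EmptyData)
        ));
        assert!(matches!(
            bars_to_arrow_record_batch_bytes(vec![]),
            Err(EncodingError::EmptyData)
        ));
    }
}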