nautilus_serialization/arrow/bar.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::{Bar, BarType},
26    types::{Price, Quantity, fixed::PRECISION_BYTES},
27};
28
29use super::{
30    DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31    KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37        let fields = vec![
38            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43            Field::new("ts_event", DataType::UInt64, false),
44            Field::new("ts_init", DataType::UInt64, false),
45        ];
46
47        match metadata {
48            Some(metadata) => Schema::new_with_metadata(fields, metadata),
49            None => Schema::new(fields),
50        }
51    }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55    let bar_type_str = metadata
56        .get(KEY_BAR_TYPE)
57        .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58    let bar_type = BarType::from_str(bar_type_str)
59        .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61    let price_precision = metadata
62        .get(KEY_PRICE_PRECISION)
63        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64        .parse::<u8>()
65        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67    let size_precision = metadata
68        .get(KEY_SIZE_PRECISION)
69        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70        .parse::<u8>()
71        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73    Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77    fn encode_batch(
78        metadata: &HashMap<String, String>,
79        data: &[Self],
80    ) -> Result<RecordBatch, ArrowError> {
81        let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82        let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83        let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84        let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85        let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86        let mut ts_event_builder = UInt64Array::builder(data.len());
87        let mut ts_init_builder = UInt64Array::builder(data.len());
88
89        for bar in data {
90            open_builder
91                .append_value(bar.open.raw.to_le_bytes())
92                .unwrap();
93            high_builder
94                .append_value(bar.high.raw.to_le_bytes())
95                .unwrap();
96            low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97            close_builder
98                .append_value(bar.close.raw.to_le_bytes())
99                .unwrap();
100            volume_builder
101                .append_value(bar.volume.raw.to_le_bytes())
102                .unwrap();
103            ts_event_builder.append_value(bar.ts_event.as_u64());
104            ts_init_builder.append_value(bar.ts_init.as_u64());
105        }
106
107        let open_array = open_builder.finish();
108        let high_array = high_builder.finish();
109        let low_array = low_builder.finish();
110        let close_array = close_builder.finish();
111        let volume_array = volume_builder.finish();
112        let ts_event_array = ts_event_builder.finish();
113        let ts_init_array = ts_init_builder.finish();
114
115        RecordBatch::try_new(
116            Self::get_schema(Some(metadata.clone())).into(),
117            vec![
118                Arc::new(open_array),
119                Arc::new(high_array),
120                Arc::new(low_array),
121                Arc::new(close_array),
122                Arc::new(volume_array),
123                Arc::new(ts_event_array),
124                Arc::new(ts_init_array),
125            ],
126        )
127    }
128
129    fn metadata(&self) -> HashMap<String, String> {
130        Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131    }
132}
133
134impl DecodeFromRecordBatch for Bar {
135    fn decode_batch(
136        metadata: &HashMap<String, String>,
137        record_batch: RecordBatch,
138    ) -> Result<Vec<Self>, EncodingError> {
139        let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140        let cols = record_batch.columns();
141
142        let open_values = extract_column::<FixedSizeBinaryArray>(
143            cols,
144            "open",
145            0,
146            DataType::FixedSizeBinary(PRECISION_BYTES),
147        )?;
148        let high_values = extract_column::<FixedSizeBinaryArray>(
149            cols,
150            "high",
151            1,
152            DataType::FixedSizeBinary(PRECISION_BYTES),
153        )?;
154        let low_values = extract_column::<FixedSizeBinaryArray>(
155            cols,
156            "low",
157            2,
158            DataType::FixedSizeBinary(PRECISION_BYTES),
159        )?;
160        let close_values = extract_column::<FixedSizeBinaryArray>(
161            cols,
162            "close",
163            3,
164            DataType::FixedSizeBinary(PRECISION_BYTES),
165        )?;
166        let volume_values = extract_column::<FixedSizeBinaryArray>(
167            cols,
168            "volume",
169            4,
170            DataType::FixedSizeBinary(PRECISION_BYTES),
171        )?;
172        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
176            .map(|i| {
177                // Use corrected raw values to handle floating-point precision errors in stored data
178                let open = Price::from_raw(
179                    get_corrected_raw_price(open_values.value(i), price_precision),
180                    price_precision,
181                );
182                let high = Price::from_raw(
183                    get_corrected_raw_price(high_values.value(i), price_precision),
184                    price_precision,
185                );
186                let low = Price::from_raw(
187                    get_corrected_raw_price(low_values.value(i), price_precision),
188                    price_precision,
189                );
190                let close = Price::from_raw(
191                    get_corrected_raw_price(close_values.value(i), price_precision),
192                    price_precision,
193                );
194                let volume = Quantity::from_raw(
195                    get_corrected_raw_quantity(volume_values.value(i), size_precision),
196                    size_precision,
197                );
198                let ts_event = ts_event_values.value(i).into();
199                let ts_init = ts_init_values.value(i).into();
200
201                Ok(Self {
202                    bar_type,
203                    open,
204                    high,
205                    low,
206                    close,
207                    volume,
208                    ts_event,
209                    ts_init,
210                })
211            })
212            .collect();
213
214        result
215    }
216}
217
218impl DecodeDataFromRecordBatch for Bar {
219    fn decode_data_batch(
220        metadata: &HashMap<String, String>,
221        record_batch: RecordBatch,
222    ) -> Result<Vec<Data>, EncodingError> {
223        let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
224        Ok(bars.into_iter().map(Data::from).collect())
225    }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    #[rstest]
    fn test_get_schema() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);
        let schema = Bar::get_schema(Some(metadata.clone()));
        let expected_fields = vec![
            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];
        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = Bar::get_schema_map();
        let mut expected_map = HashMap::new();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("open".to_string(), fixed_size_binary.clone());
        expected_map.insert("high".to_string(), fixed_size_binary.clone());
        expected_map.insert("low".to_string(), fixed_size_binary.clone());
        expected_map.insert("close".to_string(), fixed_size_binary.clone());
        expected_map.insert("volume".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1.into(),
            3.into(),
        );
        let bar2 = Bar::new(
            bar_type,
            Price::from("100.00"),
            Price::from("100.10"),
            Price::from("100.00"),
            Price::from("100.10"),
            Quantity::from(1110),
            2.into(),
            4.into(),
        );

        let data = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let open_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let high_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let low_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let volume_values = columns[4]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 7);
        assert_eq!(open_values.len(), 2);
        assert_eq!(
            get_raw_price(open_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(open_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(high_values.len(), 2);
        assert_eq!(
            get_raw_price(high_values.value(0)),
            (102.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(high_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(low_values.len(), 2);
        assert_eq!(
            get_raw_price(low_values.value(0)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(low_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_values.len(), 2);
        assert_eq!(
            get_raw_price(close_values.value(0)),
            (101.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            (1100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            (1110.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let open = FixedSizeBinaryArray::from(vec![
            &((100.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let high = FixedSizeBinaryArray::from(vec![
            &((102.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let low = FixedSizeBinaryArray::from(vec![
            &((100.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let close = FixedSizeBinaryArray::from(vec![
            &((101.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.01 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((11.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((10.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Verify every decoded field, not just the row count. Prices and sizes are
        // compared against values built from exact decimal strings, which checks the
        // floating-point correction applied during decoding.
        let first = &decoded_data[0];
        assert_eq!(first.bar_type, bar_type);
        assert_eq!(first.open, Price::from("100.10"));
        assert_eq!(first.high, Price::from("102.00"));
        assert_eq!(first.low, Price::from("100.00"));
        assert_eq!(first.close, Price::from("101.00"));
        assert_eq!(first.volume, Quantity::from(11));
        assert_eq!(first.open.precision, 2);
        assert_eq!(first.volume.precision, 0);
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded_data[1];
        assert_eq!(second.bar_type, bar_type);
        assert_eq!(second.open, Price::from("10.00"));
        assert_eq!(second.high, Price::from("10.00"));
        assert_eq!(second.low, Price::from("10.00"));
        assert_eq!(second.close, Price::from("10.01"));
        assert_eq!(second.volume, Quantity::from(10));
        assert_eq!(second.close.precision, 2);
        assert_eq!(second.volume.precision, 0);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }
}
424}