Skip to main content

nautilus_serialization/arrow/
bar.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::{Bar, BarType},
26    types::fixed::PRECISION_BYTES,
27};
28
29use super::{
30    DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31    KEY_SIZE_PRECISION, decode_price, decode_quantity, extract_column, validate_precision_bytes,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37        let fields = vec![
38            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43            Field::new("ts_event", DataType::UInt64, false),
44            Field::new("ts_init", DataType::UInt64, false),
45        ];
46
47        match metadata {
48            Some(metadata) => Schema::new_with_metadata(fields, metadata),
49            None => Schema::new(fields),
50        }
51    }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55    let bar_type_str = metadata
56        .get(KEY_BAR_TYPE)
57        .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58    let bar_type = BarType::from_str(bar_type_str)
59        .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61    let price_precision = metadata
62        .get(KEY_PRICE_PRECISION)
63        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64        .parse::<u8>()
65        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67    let size_precision = metadata
68        .get(KEY_SIZE_PRECISION)
69        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70        .parse::<u8>()
71        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73    Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77    fn encode_batch(
78        metadata: &HashMap<String, String>,
79        data: &[Self],
80    ) -> Result<RecordBatch, ArrowError> {
81        let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82        let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83        let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84        let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85        let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86        let mut ts_event_builder = UInt64Array::builder(data.len());
87        let mut ts_init_builder = UInt64Array::builder(data.len());
88
89        for bar in data {
90            open_builder
91                .append_value(bar.open.raw.to_le_bytes())
92                .unwrap();
93            high_builder
94                .append_value(bar.high.raw.to_le_bytes())
95                .unwrap();
96            low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97            close_builder
98                .append_value(bar.close.raw.to_le_bytes())
99                .unwrap();
100            volume_builder
101                .append_value(bar.volume.raw.to_le_bytes())
102                .unwrap();
103            ts_event_builder.append_value(bar.ts_event.as_u64());
104            ts_init_builder.append_value(bar.ts_init.as_u64());
105        }
106
107        let open_array = open_builder.finish();
108        let high_array = high_builder.finish();
109        let low_array = low_builder.finish();
110        let close_array = close_builder.finish();
111        let volume_array = volume_builder.finish();
112        let ts_event_array = ts_event_builder.finish();
113        let ts_init_array = ts_init_builder.finish();
114
115        RecordBatch::try_new(
116            Self::get_schema(Some(metadata.clone())).into(),
117            vec![
118                Arc::new(open_array),
119                Arc::new(high_array),
120                Arc::new(low_array),
121                Arc::new(close_array),
122                Arc::new(volume_array),
123                Arc::new(ts_event_array),
124                Arc::new(ts_init_array),
125            ],
126        )
127    }
128
129    fn metadata(&self) -> HashMap<String, String> {
130        Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131    }
132}
133
134impl DecodeFromRecordBatch for Bar {
135    fn decode_batch(
136        metadata: &HashMap<String, String>,
137        record_batch: RecordBatch,
138    ) -> Result<Vec<Self>, EncodingError> {
139        let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140        let cols = record_batch.columns();
141
142        let open_values = extract_column::<FixedSizeBinaryArray>(
143            cols,
144            "open",
145            0,
146            DataType::FixedSizeBinary(PRECISION_BYTES),
147        )?;
148        let high_values = extract_column::<FixedSizeBinaryArray>(
149            cols,
150            "high",
151            1,
152            DataType::FixedSizeBinary(PRECISION_BYTES),
153        )?;
154        let low_values = extract_column::<FixedSizeBinaryArray>(
155            cols,
156            "low",
157            2,
158            DataType::FixedSizeBinary(PRECISION_BYTES),
159        )?;
160        let close_values = extract_column::<FixedSizeBinaryArray>(
161            cols,
162            "close",
163            3,
164            DataType::FixedSizeBinary(PRECISION_BYTES),
165        )?;
166        let volume_values = extract_column::<FixedSizeBinaryArray>(
167            cols,
168            "volume",
169            4,
170            DataType::FixedSizeBinary(PRECISION_BYTES),
171        )?;
172        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175        validate_precision_bytes(open_values, "open")?;
176        validate_precision_bytes(high_values, "high")?;
177        validate_precision_bytes(low_values, "low")?;
178        validate_precision_bytes(close_values, "close")?;
179        validate_precision_bytes(volume_values, "volume")?;
180
181        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182            .map(|i| {
183                let open = decode_price(open_values.value(i), price_precision, "open", i)?;
184                let high = decode_price(high_values.value(i), price_precision, "high", i)?;
185                let low = decode_price(low_values.value(i), price_precision, "low", i)?;
186                let close = decode_price(close_values.value(i), price_precision, "close", i)?;
187                let volume = decode_quantity(volume_values.value(i), size_precision, "volume", i)?;
188                let ts_event = ts_event_values.value(i).into();
189                let ts_init = ts_init_values.value(i).into();
190
191                Ok(Self {
192                    bar_type,
193                    open,
194                    high,
195                    low,
196                    close,
197                    volume,
198                    ts_event,
199                    ts_init,
200                })
201            })
202            .collect();
203
204        result
205    }
206}
207
208impl DecodeDataFromRecordBatch for Bar {
209    fn decode_data_batch(
210        metadata: &HashMap<String, String>,
211        record_batch: RecordBatch,
212    ) -> Result<Vec<Data>, EncodingError> {
213        let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
214        Ok(bars.into_iter().map(Data::from).collect())
215    }
216}
217
218#[cfg(test)]
219mod tests {
220    use std::sync::Arc;
221
222    use arrow::{array::Array, record_batch::RecordBatch};
223    use nautilus_model::types::{
224        Price, Quantity, fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw,
225    };
226    use rstest::rstest;
227
228    use super::*;
229    use crate::arrow::{get_raw_price, get_raw_quantity};
230
231    #[rstest]
232    fn test_get_schema() {
233        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
234        let metadata = Bar::get_metadata(&bar_type, 2, 0);
235        let schema = Bar::get_schema(Some(metadata.clone()));
236        let expected_fields = vec![
237            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
238            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
239            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
240            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
241            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
242            Field::new("ts_event", DataType::UInt64, false),
243            Field::new("ts_init", DataType::UInt64, false),
244        ];
245        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
246        assert_eq!(schema, expected_schema);
247    }
248
249    #[rstest]
250    fn test_get_schema_map() {
251        let schema_map = Bar::get_schema_map();
252        let mut expected_map = HashMap::new();
253        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
254        expected_map.insert("open".to_string(), fixed_size_binary.clone());
255        expected_map.insert("high".to_string(), fixed_size_binary.clone());
256        expected_map.insert("low".to_string(), fixed_size_binary.clone());
257        expected_map.insert("close".to_string(), fixed_size_binary.clone());
258        expected_map.insert("volume".to_string(), fixed_size_binary);
259        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
260        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
261        assert_eq!(schema_map, expected_map);
262    }
263
264    #[rstest]
265    fn test_encode_batch() {
266        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
267        let metadata = Bar::get_metadata(&bar_type, 2, 0);
268
269        let bar1 = Bar::new(
270            bar_type,
271            Price::from("100.10"),
272            Price::from("102.00"),
273            Price::from("100.00"),
274            Price::from("101.00"),
275            Quantity::from(1100),
276            1.into(),
277            3.into(),
278        );
279        let bar2 = Bar::new(
280            bar_type,
281            Price::from("100.00"),
282            Price::from("100.10"),
283            Price::from("100.00"),
284            Price::from("100.10"),
285            Quantity::from(1110),
286            2.into(),
287            4.into(),
288        );
289
290        let data = vec![bar1, bar2];
291        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();
292
293        let columns = record_batch.columns();
294        let open_values = columns[0]
295            .as_any()
296            .downcast_ref::<FixedSizeBinaryArray>()
297            .unwrap();
298        let high_values = columns[1]
299            .as_any()
300            .downcast_ref::<FixedSizeBinaryArray>()
301            .unwrap();
302        let low_values = columns[2]
303            .as_any()
304            .downcast_ref::<FixedSizeBinaryArray>()
305            .unwrap();
306        let close_values = columns[3]
307            .as_any()
308            .downcast_ref::<FixedSizeBinaryArray>()
309            .unwrap();
310        let volume_values = columns[4]
311            .as_any()
312            .downcast_ref::<FixedSizeBinaryArray>()
313            .unwrap();
314        let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
315        let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
316
317        assert_eq!(columns.len(), 7);
318        assert_eq!(open_values.len(), 2);
319        assert_eq!(
320            get_raw_price(open_values.value(0)),
321            (100.10 * FIXED_SCALAR) as PriceRaw
322        );
323        assert_eq!(
324            get_raw_price(open_values.value(1)),
325            (100.00 * FIXED_SCALAR) as PriceRaw
326        );
327        assert_eq!(high_values.len(), 2);
328        assert_eq!(
329            get_raw_price(high_values.value(0)),
330            (102.00 * FIXED_SCALAR) as PriceRaw
331        );
332        assert_eq!(
333            get_raw_price(high_values.value(1)),
334            (100.10 * FIXED_SCALAR) as PriceRaw
335        );
336        assert_eq!(low_values.len(), 2);
337        assert_eq!(
338            get_raw_price(low_values.value(0)),
339            (100.00 * FIXED_SCALAR) as PriceRaw
340        );
341        assert_eq!(
342            get_raw_price(low_values.value(1)),
343            (100.00 * FIXED_SCALAR) as PriceRaw
344        );
345        assert_eq!(close_values.len(), 2);
346        assert_eq!(
347            get_raw_price(close_values.value(0)),
348            (101.00 * FIXED_SCALAR) as PriceRaw
349        );
350        assert_eq!(
351            get_raw_price(close_values.value(1)),
352            (100.10 * FIXED_SCALAR) as PriceRaw
353        );
354        assert_eq!(volume_values.len(), 2);
355        assert_eq!(
356            get_raw_quantity(volume_values.value(0)),
357            (1100.0 * FIXED_SCALAR) as QuantityRaw
358        );
359        assert_eq!(
360            get_raw_quantity(volume_values.value(1)),
361            (1110.0 * FIXED_SCALAR) as QuantityRaw
362        );
363        assert_eq!(ts_event_values.len(), 2);
364        assert_eq!(ts_event_values.value(0), 1);
365        assert_eq!(ts_event_values.value(1), 2);
366        assert_eq!(ts_init_values.len(), 2);
367        assert_eq!(ts_init_values.value(0), 3);
368        assert_eq!(ts_init_values.value(1), 4);
369    }
370
371    #[rstest]
372    fn test_decode_batch() {
373        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
374        let metadata = Bar::get_metadata(&bar_type, 2, 0);
375
376        let open = FixedSizeBinaryArray::from(vec![
377            &((100.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
378            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
379        ]);
380        let high = FixedSizeBinaryArray::from(vec![
381            &((102.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
382            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
383        ]);
384        let low = FixedSizeBinaryArray::from(vec![
385            &((100.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
386            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
387        ]);
388        let close = FixedSizeBinaryArray::from(vec![
389            &((101.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
390            &((10.01 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
391        ]);
392        let volume = FixedSizeBinaryArray::from(vec![
393            &((11.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
394            &((10.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
395        ]);
396        let ts_event = UInt64Array::from(vec![1, 2]);
397        let ts_init = UInt64Array::from(vec![3, 4]);
398
399        let record_batch = RecordBatch::try_new(
400            Bar::get_schema(Some(metadata.clone())).into(),
401            vec![
402                Arc::new(open),
403                Arc::new(high),
404                Arc::new(low),
405                Arc::new(close),
406                Arc::new(volume),
407                Arc::new(ts_event),
408                Arc::new(ts_init),
409            ],
410        )
411        .unwrap();
412
413        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
414        assert_eq!(decoded_data.len(), 2);
415    }
416
417    #[rstest]
418    fn test_decode_batch_invalid_price_returns_error() {
419        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
420        let metadata = Bar::get_metadata(&bar_type, 2, 0);
421
422        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
423        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;
424
425        let open = FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]);
426        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
427        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
428        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
429        let volume = FixedSizeBinaryArray::from(vec![
430            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
431        ]);
432        let ts_event = UInt64Array::from(vec![1]);
433        let ts_init = UInt64Array::from(vec![2]);
434
435        let record_batch = RecordBatch::try_new(
436            Bar::get_schema(Some(metadata.clone())).into(),
437            vec![
438                Arc::new(open),
439                Arc::new(high),
440                Arc::new(low),
441                Arc::new(close),
442                Arc::new(volume),
443                Arc::new(ts_event),
444                Arc::new(ts_init),
445            ],
446        )
447        .unwrap();
448
449        let result = Bar::decode_batch(&metadata, record_batch);
450        assert!(result.is_err());
451        let err = result.unwrap_err();
452        assert!(
453            err.to_string().contains("open") && err.to_string().contains("row 0"),
454            "Expected open error at row 0, was: {err}"
455        );
456    }
457
458    #[rstest]
459    fn test_decode_batch_missing_bar_type_returns_error() {
460        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
461        let mut metadata = Bar::get_metadata(&bar_type, 2, 0);
462
463        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;
464        let open = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
465        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
466        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
467        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
468        let volume = FixedSizeBinaryArray::from(vec![
469            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
470        ]);
471        let ts_event = UInt64Array::from(vec![1]);
472        let ts_init = UInt64Array::from(vec![2]);
473
474        let record_batch = RecordBatch::try_new(
475            Bar::get_schema(Some(metadata.clone())).into(),
476            vec![
477                Arc::new(open),
478                Arc::new(high),
479                Arc::new(low),
480                Arc::new(close),
481                Arc::new(volume),
482                Arc::new(ts_event),
483                Arc::new(ts_init),
484            ],
485        )
486        .unwrap();
487
488        metadata.remove(KEY_BAR_TYPE);
489
490        let result = Bar::decode_batch(&metadata, record_batch);
491        assert!(result.is_err());
492        let err = result.unwrap_err();
493        assert!(
494            err.to_string().contains("bar_type"),
495            "Expected missing bar_type error, was: {err}"
496        );
497    }
498
499    #[rstest]
500    fn test_encode_decode_round_trip() {
501        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
502        let metadata = Bar::get_metadata(&bar_type, 2, 0);
503
504        let bar1 = Bar::new(
505            bar_type,
506            Price::from("100.10"),
507            Price::from("102.00"),
508            Price::from("100.00"),
509            Price::from("101.00"),
510            Quantity::from(1100),
511            1_000_000_000.into(),
512            1_000_000_001.into(),
513        );
514
515        let bar2 = Bar::new(
516            bar_type,
517            Price::from("101.00"),
518            Price::from("103.00"),
519            Price::from("100.50"),
520            Price::from("102.50"),
521            Quantity::from(2200),
522            2_000_000_000.into(),
523            2_000_000_001.into(),
524        );
525
526        let original = vec![bar1, bar2];
527        let record_batch = Bar::encode_batch(&metadata, &original).unwrap();
528        let decoded = Bar::decode_batch(&metadata, record_batch).unwrap();
529
530        assert_eq!(decoded.len(), original.len());
531        for (orig, dec) in original.iter().zip(decoded.iter()) {
532            assert_eq!(dec.bar_type, orig.bar_type);
533            assert_eq!(dec.open, orig.open);
534            assert_eq!(dec.high, orig.high);
535            assert_eq!(dec.low, orig.low);
536            assert_eq!(dec.close, orig.close);
537            assert_eq!(dec.volume, orig.volume);
538            assert_eq!(dec.ts_event, orig.ts_event);
539            assert_eq!(dec.ts_init, orig.ts_init);
540        }
541    }
542}