nautilus_serialization/arrow/
bar.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::{Bar, BarType},
26    types::fixed::PRECISION_BYTES,
27};
28
29use super::{
30    DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31    KEY_SIZE_PRECISION, decode_price, decode_quantity, extract_column,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37        let fields = vec![
38            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43            Field::new("ts_event", DataType::UInt64, false),
44            Field::new("ts_init", DataType::UInt64, false),
45        ];
46
47        match metadata {
48            Some(metadata) => Schema::new_with_metadata(fields, metadata),
49            None => Schema::new(fields),
50        }
51    }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55    let bar_type_str = metadata
56        .get(KEY_BAR_TYPE)
57        .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58    let bar_type = BarType::from_str(bar_type_str)
59        .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61    let price_precision = metadata
62        .get(KEY_PRICE_PRECISION)
63        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64        .parse::<u8>()
65        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67    let size_precision = metadata
68        .get(KEY_SIZE_PRECISION)
69        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70        .parse::<u8>()
71        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73    Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77    fn encode_batch(
78        metadata: &HashMap<String, String>,
79        data: &[Self],
80    ) -> Result<RecordBatch, ArrowError> {
81        let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82        let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83        let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84        let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85        let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86        let mut ts_event_builder = UInt64Array::builder(data.len());
87        let mut ts_init_builder = UInt64Array::builder(data.len());
88
89        for bar in data {
90            open_builder
91                .append_value(bar.open.raw.to_le_bytes())
92                .unwrap();
93            high_builder
94                .append_value(bar.high.raw.to_le_bytes())
95                .unwrap();
96            low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97            close_builder
98                .append_value(bar.close.raw.to_le_bytes())
99                .unwrap();
100            volume_builder
101                .append_value(bar.volume.raw.to_le_bytes())
102                .unwrap();
103            ts_event_builder.append_value(bar.ts_event.as_u64());
104            ts_init_builder.append_value(bar.ts_init.as_u64());
105        }
106
107        let open_array = open_builder.finish();
108        let high_array = high_builder.finish();
109        let low_array = low_builder.finish();
110        let close_array = close_builder.finish();
111        let volume_array = volume_builder.finish();
112        let ts_event_array = ts_event_builder.finish();
113        let ts_init_array = ts_init_builder.finish();
114
115        RecordBatch::try_new(
116            Self::get_schema(Some(metadata.clone())).into(),
117            vec![
118                Arc::new(open_array),
119                Arc::new(high_array),
120                Arc::new(low_array),
121                Arc::new(close_array),
122                Arc::new(volume_array),
123                Arc::new(ts_event_array),
124                Arc::new(ts_init_array),
125            ],
126        )
127    }
128
129    fn metadata(&self) -> HashMap<String, String> {
130        Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131    }
132}
133
134impl DecodeFromRecordBatch for Bar {
135    fn decode_batch(
136        metadata: &HashMap<String, String>,
137        record_batch: RecordBatch,
138    ) -> Result<Vec<Self>, EncodingError> {
139        let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140        let cols = record_batch.columns();
141
142        let open_values = extract_column::<FixedSizeBinaryArray>(
143            cols,
144            "open",
145            0,
146            DataType::FixedSizeBinary(PRECISION_BYTES),
147        )?;
148        let high_values = extract_column::<FixedSizeBinaryArray>(
149            cols,
150            "high",
151            1,
152            DataType::FixedSizeBinary(PRECISION_BYTES),
153        )?;
154        let low_values = extract_column::<FixedSizeBinaryArray>(
155            cols,
156            "low",
157            2,
158            DataType::FixedSizeBinary(PRECISION_BYTES),
159        )?;
160        let close_values = extract_column::<FixedSizeBinaryArray>(
161            cols,
162            "close",
163            3,
164            DataType::FixedSizeBinary(PRECISION_BYTES),
165        )?;
166        let volume_values = extract_column::<FixedSizeBinaryArray>(
167            cols,
168            "volume",
169            4,
170            DataType::FixedSizeBinary(PRECISION_BYTES),
171        )?;
172        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
176            .map(|i| {
177                let open = decode_price(open_values.value(i), price_precision, "open", i)?;
178                let high = decode_price(high_values.value(i), price_precision, "high", i)?;
179                let low = decode_price(low_values.value(i), price_precision, "low", i)?;
180                let close = decode_price(close_values.value(i), price_precision, "close", i)?;
181                let volume = decode_quantity(volume_values.value(i), size_precision, "volume", i)?;
182                let ts_event = ts_event_values.value(i).into();
183                let ts_init = ts_init_values.value(i).into();
184
185                Ok(Self {
186                    bar_type,
187                    open,
188                    high,
189                    low,
190                    close,
191                    volume,
192                    ts_event,
193                    ts_init,
194                })
195            })
196            .collect();
197
198        result
199    }
200}
201
202impl DecodeDataFromRecordBatch for Bar {
203    fn decode_data_batch(
204        metadata: &HashMap<String, String>,
205        record_batch: RecordBatch,
206    ) -> Result<Vec<Data>, EncodingError> {
207        let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
208        Ok(bars.into_iter().map(Data::from).collect())
209    }
210}
211
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{
        Price, Quantity, fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw,
    };
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// The schema carries the seven expected fields plus the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);
        let schema = Bar::get_schema(Some(metadata.clone()));
        let expected_fields = vec![
            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];
        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map lists each column name with its Arrow type rendered as text.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = Bar::get_schema_map();
        let mut expected_map = HashMap::new();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("open".to_string(), fixed_size_binary.clone());
        expected_map.insert("high".to_string(), fixed_size_binary.clone());
        expected_map.insert("low".to_string(), fixed_size_binary.clone());
        expected_map.insert("close".to_string(), fixed_size_binary.clone());
        expected_map.insert("volume".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two bars yields seven columns whose raw values match the inputs.
    #[rstest]
    fn test_encode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1.into(),
            3.into(),
        );
        let bar2 = Bar::new(
            bar_type,
            Price::from("100.00"),
            Price::from("100.10"),
            Price::from("100.00"),
            Price::from("100.10"),
            Quantity::from(1110),
            2.into(),
            4.into(),
        );

        let data = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let open_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let high_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let low_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let volume_values = columns[4]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 7);
        assert_eq!(open_values.len(), 2);
        assert_eq!(
            get_raw_price(open_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(open_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(high_values.len(), 2);
        assert_eq!(
            get_raw_price(high_values.value(0)),
            (102.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(high_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(low_values.len(), 2);
        assert_eq!(
            get_raw_price(low_values.value(0)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(low_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_values.len(), 2);
        assert_eq!(
            get_raw_price(close_values.value(0)),
            (101.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            (1100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            (1110.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch yields one bar per row, in row order, with the
    /// bar type taken from the metadata.
    #[rstest]
    fn test_decode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let open = FixedSizeBinaryArray::from(vec![
            &((100.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let high = FixedSizeBinaryArray::from(vec![
            &((102.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let low = FixedSizeBinaryArray::from(vec![
            &((100.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let close = FixedSizeBinaryArray::from(vec![
            &((101.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.01 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((11.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((10.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // A length check alone would not catch swapped columns or reordered rows,
        // so pin the decoded field values as well.
        let first = &decoded_data[0];
        assert_eq!(first.bar_type, bar_type);
        assert_eq!(first.open, Price::from("100.10"));
        assert_eq!(first.high, Price::from("102.00"));
        assert_eq!(first.low, Price::from("100.00"));
        assert_eq!(first.close, Price::from("101.00"));
        assert_eq!(first.volume, Quantity::from(11));
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded_data[1];
        assert_eq!(second.bar_type, bar_type);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }

    /// An out-of-range raw price makes decoding fail with an error naming the
    /// offending field and row.
    #[rstest]
    fn test_decode_batch_invalid_price_returns_error() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;

        let open = FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]);
        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("open") && err.to_string().contains("row 0"),
            "Expected open error at row 0, got: {err}"
        );
    }

    /// Decoding fails with a `bar_type` error when that metadata key is missing.
    #[rstest]
    fn test_decode_batch_missing_bar_type_returns_error() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let mut metadata = Bar::get_metadata(&bar_type, 2, 0);

        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;
        let open = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        // Remove the key only after the batch is built, so just the metadata
        // passed to the decoder lacks it.
        metadata.remove(KEY_BAR_TYPE);

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("bar_type"),
            "Expected missing bar_type error, got: {err}"
        );
    }

    /// Bars survive an encode/decode round trip with every field intact.
    #[rstest]
    fn test_encode_decode_round_trip() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1_000_000_000.into(),
            1_000_000_001.into(),
        );

        let bar2 = Bar::new(
            bar_type,
            Price::from("101.00"),
            Price::from("103.00"),
            Price::from("100.50"),
            Price::from("102.50"),
            Quantity::from(2200),
            2_000_000_000.into(),
            2_000_000_001.into(),
        );

        let original = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &original).unwrap();
        let decoded = Bar::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec.bar_type, orig.bar_type);
            assert_eq!(dec.open, orig.open);
            assert_eq!(dec.high, orig.high);
            assert_eq!(dec.low, orig.low);
            assert_eq!(dec.close, orig.close);
            assert_eq!(dec.volume, orig.volume);
            assert_eq!(dec.ts_event, orig.ts_event);
            assert_eq!(dec.ts_init, orig.ts_init);
        }
    }
}