nautilus_serialization/arrow/
bar.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::{Bar, BarType},
26    types::{Price, Quantity, fixed::PRECISION_BYTES},
27};
28
29use super::{
30    DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31    KEY_SIZE_PRECISION, extract_column, get_raw_quantity,
32};
33use crate::arrow::{
34    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
35};
36
37impl ArrowSchemaProvider for Bar {
38    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39        let fields = vec![
40            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45            Field::new("ts_event", DataType::UInt64, false),
46            Field::new("ts_init", DataType::UInt64, false),
47        ];
48
49        match metadata {
50            Some(metadata) => Schema::new_with_metadata(fields, metadata),
51            None => Schema::new(fields),
52        }
53    }
54}
55
56fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
57    let bar_type_str = metadata
58        .get(KEY_BAR_TYPE)
59        .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
60    let bar_type = BarType::from_str(bar_type_str)
61        .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
62
63    let price_precision = metadata
64        .get(KEY_PRICE_PRECISION)
65        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
66        .parse::<u8>()
67        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
68
69    let size_precision = metadata
70        .get(KEY_SIZE_PRECISION)
71        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
72        .parse::<u8>()
73        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
74
75    Ok((bar_type, price_precision, size_precision))
76}
77
78impl EncodeToRecordBatch for Bar {
79    fn encode_batch(
80        metadata: &HashMap<String, String>,
81        data: &[Self],
82    ) -> Result<RecordBatch, ArrowError> {
83        let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84        let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85        let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86        let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
87        let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88        let mut ts_event_builder = UInt64Array::builder(data.len());
89        let mut ts_init_builder = UInt64Array::builder(data.len());
90
91        for bar in data {
92            open_builder
93                .append_value(bar.open.raw.to_le_bytes())
94                .unwrap();
95            high_builder
96                .append_value(bar.high.raw.to_le_bytes())
97                .unwrap();
98            low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
99            close_builder
100                .append_value(bar.close.raw.to_le_bytes())
101                .unwrap();
102            volume_builder
103                .append_value(bar.volume.raw.to_le_bytes())
104                .unwrap();
105            ts_event_builder.append_value(bar.ts_event.as_u64());
106            ts_init_builder.append_value(bar.ts_init.as_u64());
107        }
108
109        let open_array = open_builder.finish();
110        let high_array = high_builder.finish();
111        let low_array = low_builder.finish();
112        let close_array = close_builder.finish();
113        let volume_array = volume_builder.finish();
114        let ts_event_array = ts_event_builder.finish();
115        let ts_init_array = ts_init_builder.finish();
116
117        RecordBatch::try_new(
118            Self::get_schema(Some(metadata.clone())).into(),
119            vec![
120                Arc::new(open_array),
121                Arc::new(high_array),
122                Arc::new(low_array),
123                Arc::new(close_array),
124                Arc::new(volume_array),
125                Arc::new(ts_event_array),
126                Arc::new(ts_init_array),
127            ],
128        )
129    }
130
131    fn metadata(&self) -> HashMap<String, String> {
132        Bar::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
133    }
134}
135
136impl DecodeFromRecordBatch for Bar {
137    fn decode_batch(
138        metadata: &HashMap<String, String>,
139        record_batch: RecordBatch,
140    ) -> Result<Vec<Self>, EncodingError> {
141        let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
142        let cols = record_batch.columns();
143
144        let open_values = extract_column::<FixedSizeBinaryArray>(
145            cols,
146            "open",
147            0,
148            DataType::FixedSizeBinary(PRECISION_BYTES),
149        )?;
150        let high_values = extract_column::<FixedSizeBinaryArray>(
151            cols,
152            "high",
153            1,
154            DataType::FixedSizeBinary(PRECISION_BYTES),
155        )?;
156        let low_values = extract_column::<FixedSizeBinaryArray>(
157            cols,
158            "low",
159            2,
160            DataType::FixedSizeBinary(PRECISION_BYTES),
161        )?;
162        let close_values = extract_column::<FixedSizeBinaryArray>(
163            cols,
164            "close",
165            3,
166            DataType::FixedSizeBinary(PRECISION_BYTES),
167        )?;
168        let volume_values = extract_column::<FixedSizeBinaryArray>(
169            cols,
170            "volume",
171            4,
172            DataType::FixedSizeBinary(PRECISION_BYTES),
173        )?;
174        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
175        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
176
177        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
178            .map(|i| {
179                let open = Price::from_raw(get_raw_price(open_values.value(i)), price_precision);
180                let high = Price::from_raw(get_raw_price(high_values.value(i)), price_precision);
181                let low = Price::from_raw(get_raw_price(low_values.value(i)), price_precision);
182                let close = Price::from_raw(get_raw_price(close_values.value(i)), price_precision);
183                let volume =
184                    Quantity::from_raw(get_raw_quantity(volume_values.value(i)), size_precision);
185                let ts_event = ts_event_values.value(i).into();
186                let ts_init = ts_init_values.value(i).into();
187
188                Ok(Self {
189                    bar_type,
190                    open,
191                    high,
192                    low,
193                    close,
194                    volume,
195                    ts_event,
196                    ts_init,
197                })
198            })
199            .collect();
200
201        result
202    }
203}
204
205impl DecodeDataFromRecordBatch for Bar {
206    fn decode_data_batch(
207        metadata: &HashMap<String, String>,
208        record_batch: RecordBatch,
209    ) -> Result<Vec<Data>, EncodingError> {
210        let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
211        Ok(bars.into_iter().map(Data::from).collect())
212    }
213}
214
215////////////////////////////////////////////////////////////////////////////////
216// Tests
217////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema built with metadata must equal a hand-built schema with the
    /// same fields and the same metadata.
    #[rstest]
    fn test_get_schema() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);
        let schema = Bar::get_schema(Some(metadata.clone()));
        let expected_fields = vec![
            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];
        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must list every column name with its type rendered as a
    /// string.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = Bar::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        let mut expected_map = HashMap::new();
        for name in ["open", "high", "low", "close", "volume"] {
            expected_map.insert(name.to_string(), fixed_size_binary.clone());
        }
        for name in ["ts_event", "ts_init"] {
            expected_map.insert(name.to_string(), "UInt64".to_string());
        }

        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two bars must yield seven columns of two rows each, holding the
    /// bars' raw values in input order.
    #[rstest]
    fn test_encode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1.into(),
            3.into(),
        );
        let bar2 = Bar::new(
            bar_type,
            Price::from("100.00"),
            Price::from("100.10"),
            Price::from("100.00"),
            Price::from("100.10"),
            Quantity::from(1110),
            2.into(),
            4.into(),
        );

        let data = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing so that a wrong count fails
        // with a clear assertion message rather than an out-of-bounds panic.
        assert_eq!(columns.len(), 7);

        let open_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let high_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let low_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let volume_values = columns[4]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(open_values.len(), 2);
        assert_eq!(
            get_raw_price(open_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(open_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(high_values.len(), 2);
        assert_eq!(
            get_raw_price(high_values.value(0)),
            (102.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(high_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(low_values.len(), 2);
        assert_eq!(
            get_raw_price(low_values.value(0)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(low_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_values.len(), 2);
        assert_eq!(
            get_raw_price(close_values.value(0)),
            (101.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            (1100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            (1110.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch must yield one bar per row, carrying the bar
    /// type and precisions from the metadata and the timestamps from the batch.
    #[rstest]
    fn test_decode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let open = FixedSizeBinaryArray::from(vec![
            &(100_100_000_000 as PriceRaw).to_le_bytes(),
            &(10_000_000_000 as PriceRaw).to_le_bytes(),
        ]);
        let high = FixedSizeBinaryArray::from(vec![
            &(102_000_000_000 as PriceRaw).to_le_bytes(),
            &(10_000_000_000 as PriceRaw).to_le_bytes(),
        ]);
        let low = FixedSizeBinaryArray::from(vec![
            &(100_000_000_000 as PriceRaw).to_le_bytes(),
            &(10_000_000_000 as PriceRaw).to_le_bytes(),
        ]);
        let close = FixedSizeBinaryArray::from(vec![
            &(101_000_000_000 as PriceRaw).to_le_bytes(),
            &(10_010_000_000 as PriceRaw).to_le_bytes(),
        ]);
        let volume = FixedSizeBinaryArray::from(vec![
            &(11_000_000_000 as QuantityRaw).to_le_bytes(),
            &(10_000_000_000 as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Beyond the row count, check that the metadata-derived fields and the
        // timestamp columns end up on the right rows.
        let expected_timestamps = [(1_u64, 3_u64), (2, 4)];
        for (bar, (expected_event, expected_init)) in decoded_data.iter().zip(expected_timestamps) {
            assert_eq!(bar.bar_type, bar_type);
            assert_eq!(bar.open.precision, 2);
            assert_eq!(bar.volume.precision, 0);
            assert_eq!(bar.ts_event.as_u64(), expected_event);
            assert_eq!(bar.ts_init.as_u64(), expected_init);
        }
    }
}