nautilus_serialization/arrow/
close.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::close::InstrumentClose,
26    enums::{FromU8, InstrumentCloseType},
27    identifiers::InstrumentId,
28    types::fixed::PRECISION_BYTES,
29};
30
31use super::{
32    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
33    extract_column,
34};
35use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
36
37impl ArrowSchemaProvider for InstrumentClose {
38    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39        let fields = vec![
40            Field::new(
41                "close_price",
42                DataType::FixedSizeBinary(PRECISION_BYTES),
43                false,
44            ),
45            Field::new("close_type", DataType::UInt8, false),
46            Field::new("ts_event", DataType::UInt64, false),
47            Field::new("ts_init", DataType::UInt64, false),
48        ];
49
50        match metadata {
51            Some(metadata) => Schema::new_with_metadata(fields, metadata),
52            None => Schema::new(fields),
53        }
54    }
55}
56
57fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
58    let instrument_id_str = metadata
59        .get(KEY_INSTRUMENT_ID)
60        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
61    let instrument_id = InstrumentId::from_str(instrument_id_str)
62        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
63
64    let price_precision = metadata
65        .get(KEY_PRICE_PRECISION)
66        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
67        .parse::<u8>()
68        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
69
70    Ok((instrument_id, price_precision))
71}
72
73impl EncodeToRecordBatch for InstrumentClose {
74    fn encode_batch(
75        metadata: &HashMap<String, String>,
76        data: &[Self],
77    ) -> Result<RecordBatch, ArrowError> {
78        let mut close_price_builder =
79            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
80        let mut close_type_builder = UInt8Array::builder(data.len());
81        let mut ts_event_builder = UInt64Array::builder(data.len());
82        let mut ts_init_builder = UInt64Array::builder(data.len());
83
84        for item in data {
85            close_price_builder
86                .append_value(item.close_price.raw.to_le_bytes())
87                .unwrap();
88            close_type_builder.append_value(item.close_type as u8);
89            ts_event_builder.append_value(item.ts_event.as_u64());
90            ts_init_builder.append_value(item.ts_init.as_u64());
91        }
92
93        RecordBatch::try_new(
94            Self::get_schema(Some(metadata.clone())).into(),
95            vec![
96                Arc::new(close_price_builder.finish()),
97                Arc::new(close_type_builder.finish()),
98                Arc::new(ts_event_builder.finish()),
99                Arc::new(ts_init_builder.finish()),
100            ],
101        )
102    }
103
104    fn metadata(&self) -> HashMap<String, String> {
105        Self::get_metadata(&self.instrument_id, self.close_price.precision)
106    }
107}
108
109impl DecodeFromRecordBatch for InstrumentClose {
110    fn decode_batch(
111        metadata: &HashMap<String, String>,
112        record_batch: RecordBatch,
113    ) -> Result<Vec<Self>, EncodingError> {
114        let (instrument_id, price_precision) = parse_metadata(metadata)?;
115        let cols = record_batch.columns();
116
117        let close_price_values = extract_column::<FixedSizeBinaryArray>(
118            cols,
119            "close_price",
120            0,
121            DataType::FixedSizeBinary(PRECISION_BYTES),
122        )?;
123        let close_type_values =
124            extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
125        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
126        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
127
128        // Validate value length
129        if close_price_values.value_length() != PRECISION_BYTES {
130            return Err(EncodingError::ParseError(
131                "close_price",
132                format!(
133                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
134                    close_price_values.value_length()
135                ),
136            ));
137        }
138
139        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
140            .map(|row| {
141                let close_price = decode_price(
142                    close_price_values.value(row),
143                    price_precision,
144                    "close_price",
145                    row,
146                )?;
147                let close_type_value = close_type_values.value(row);
148                let close_type =
149                    InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
150                        EncodingError::ParseError(
151                            stringify!(InstrumentCloseType),
152                            format!("Invalid enum value, was {close_type_value}"),
153                        )
154                    })?;
155                Ok(Self {
156                    instrument_id,
157                    close_price,
158                    close_type,
159                    ts_event: ts_event_values.value(row).into(),
160                    ts_init: ts_init_values.value(row).into(),
161                })
162            })
163            .collect();
164
165        result
166    }
167}
168
169impl DecodeDataFromRecordBatch for InstrumentClose {
170    fn decode_data_batch(
171        metadata: &HashMap<String, String>,
172        record_batch: RecordBatch,
173    ) -> Result<Vec<Data>, EncodingError> {
174        let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
175        Ok(items.into_iter().map(Data::from).collect())
176    }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Instrument shared by every test in this module.
    fn test_instrument_id() -> InstrumentId {
        InstrumentId::from("AAPL.XNAS")
    }

    /// Builds encode/decode metadata for `instrument_id` with a price precision of 2.
    fn test_metadata(instrument_id: InstrumentId) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    /// Assembles a `RecordBatch` in the `InstrumentClose` schema from raw column arrays.
    fn build_record_batch(
        metadata: &HashMap<String, String>,
        close_price: FixedSizeBinaryArray,
        close_type: UInt8Array,
        ts_event: UInt64Array,
        ts_init: UInt64Array,
    ) -> RecordBatch {
        RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap()
    }

    /// Builds a one-row batch with `ts_event = 1`, `ts_init = 2` and the given price/type.
    fn build_single_row_batch(
        metadata: &HashMap<String, String>,
        raw_price: PriceRaw,
        close_type: u8,
    ) -> RecordBatch {
        build_record_batch(
            metadata,
            FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]),
            UInt8Array::from(vec![close_type]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        )
    }

    #[rstest]
    fn test_get_schema() {
        let metadata = test_metadata(test_instrument_id());
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "close_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("close_type", DataType::UInt8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = InstrumentClose::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("close_price".to_string(), fixed_size_binary);
        expected_map.insert("close_type".to_string(), "UInt8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = test_instrument_id();
        let metadata = test_metadata(instrument_id);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing, so a schema regression fails this
        // assertion instead of panicking with an out-of-bounds index.
        assert_eq!(columns.len(), 4);

        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(close_price_values.len(), 2);
        assert_eq!(
            get_raw_price(close_price_values.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_price_values.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_type_values.len(), 2);
        assert_eq!(
            close_type_values.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_type_values.value(1),
            InstrumentCloseType::ContractExpired as u8
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = test_instrument_id();
        let metadata = test_metadata(instrument_id);

        let raw_price1 = (150.50 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (151.25 * FIXED_SCALAR) as PriceRaw;
        let record_batch = build_record_batch(
            &metadata,
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]),
            UInt8Array::from(vec![
                InstrumentCloseType::EndOfSession as u8,
                InstrumentCloseType::ContractExpired as u8,
            ]),
            UInt64Array::from(vec![1, 2]),
            UInt64Array::from(vec![3, 4]),
        );

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].close_price, Price::from_raw(raw_price1, 2));
        assert_eq!(
            decoded_data[0].close_type,
            InstrumentCloseType::EndOfSession
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].close_price, Price::from_raw(raw_price2, 2));
        assert_eq!(
            decoded_data[1].close_type,
            InstrumentCloseType::ContractExpired
        );
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_decode_batch_invalid_close_price_returns_error() {
        let metadata = test_metadata(test_instrument_id());

        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = build_single_row_batch(
            &metadata,
            invalid_price,
            InstrumentCloseType::EndOfSession as u8,
        );

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("close_price") && err.to_string().contains("row 0"),
            "Expected close_price error at row 0, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_invalid_close_type_returns_error() {
        let metadata = test_metadata(test_instrument_id());

        let raw_price = (150.50 * FIXED_SCALAR) as PriceRaw;
        let record_batch = build_single_row_batch(&metadata, raw_price, 99);

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("InstrumentCloseType"),
            "Expected InstrumentCloseType error, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let mut metadata = test_metadata(test_instrument_id());

        let raw_price = (150.50 * FIXED_SCALAR) as PriceRaw;
        let record_batch = build_single_row_batch(
            &metadata,
            raw_price,
            InstrumentCloseType::EndOfSession as u8,
        );

        metadata.remove(KEY_INSTRUMENT_ID);

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_price_precision_returns_error() {
        let mut metadata = test_metadata(test_instrument_id());

        let raw_price = (150.50 * FIXED_SCALAR) as PriceRaw;
        let record_batch = build_single_row_batch(
            &metadata,
            raw_price,
            InstrumentCloseType::EndOfSession as u8,
        );

        metadata.remove(KEY_PRICE_PRECISION);

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(&err, EncodingError::MissingMetadata(key) if *key == KEY_PRICE_PRECISION),
            "Expected missing price precision error, got: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_empty_batch() {
        let metadata = test_metadata(test_instrument_id());
        let data: Vec<InstrumentClose> = Vec::new();

        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();
        assert_eq!(record_batch.num_rows(), 0);

        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();
        assert!(decoded.is_empty());
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = test_instrument_id();
        let metadata = test_metadata(instrument_id);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1_000_000_000.into(),
            ts_init: 1_000_000_001.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2_000_000_000.into(),
            ts_init: 2_000_000_001.into(),
        };

        let original = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &original).unwrap();
        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec.instrument_id, orig.instrument_id);
            assert_eq!(dec.close_price, orig.close_price);
            assert_eq!(dec.close_type, orig.close_type);
            assert_eq!(dec.ts_event, orig.ts_event);
            assert_eq!(dec.ts_init, orig.ts_init);
        }
    }
}