nautilus_serialization/arrow/
close.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::close::InstrumentClose,
26    enums::{FromU8, InstrumentCloseType},
27    identifiers::InstrumentId,
28    types::{Price, fixed::PRECISION_BYTES},
29};
30
31use super::{
32    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33    extract_column,
34};
35use crate::arrow::{
36    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37};
38
39impl ArrowSchemaProvider for InstrumentClose {
40    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
41        let fields = vec![
42            Field::new(
43                "close_price",
44                DataType::FixedSizeBinary(PRECISION_BYTES),
45                false,
46            ),
47            Field::new("close_type", DataType::UInt8, false),
48            Field::new("ts_event", DataType::UInt64, false),
49            Field::new("ts_init", DataType::UInt64, false),
50        ];
51
52        match metadata {
53            Some(metadata) => Schema::new_with_metadata(fields, metadata),
54            None => Schema::new(fields),
55        }
56    }
57}
58
59fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
60    let instrument_id_str = metadata
61        .get(KEY_INSTRUMENT_ID)
62        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
63    let instrument_id = InstrumentId::from_str(instrument_id_str)
64        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
65
66    let price_precision = metadata
67        .get(KEY_PRICE_PRECISION)
68        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
69        .parse::<u8>()
70        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
71
72    Ok((instrument_id, price_precision))
73}
74
75impl EncodeToRecordBatch for InstrumentClose {
76    fn encode_batch(
77        metadata: &HashMap<String, String>,
78        data: &[Self],
79    ) -> Result<RecordBatch, ArrowError> {
80        let mut close_price_builder =
81            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82        let mut close_type_builder = UInt8Array::builder(data.len());
83        let mut ts_event_builder = UInt64Array::builder(data.len());
84        let mut ts_init_builder = UInt64Array::builder(data.len());
85
86        for item in data {
87            close_price_builder
88                .append_value(item.close_price.raw.to_le_bytes())
89                .unwrap();
90            close_type_builder.append_value(item.close_type as u8);
91            ts_event_builder.append_value(item.ts_event.as_u64());
92            ts_init_builder.append_value(item.ts_init.as_u64());
93        }
94
95        RecordBatch::try_new(
96            Self::get_schema(Some(metadata.clone())).into(),
97            vec![
98                Arc::new(close_price_builder.finish()),
99                Arc::new(close_type_builder.finish()),
100                Arc::new(ts_event_builder.finish()),
101                Arc::new(ts_init_builder.finish()),
102            ],
103        )
104    }
105
106    fn metadata(&self) -> HashMap<String, String> {
107        InstrumentClose::get_metadata(&self.instrument_id, self.close_price.precision)
108    }
109}
110
111impl DecodeFromRecordBatch for InstrumentClose {
112    fn decode_batch(
113        metadata: &HashMap<String, String>,
114        record_batch: RecordBatch,
115    ) -> Result<Vec<Self>, EncodingError> {
116        let (instrument_id, price_precision) = parse_metadata(metadata)?;
117        let cols = record_batch.columns();
118
119        let close_price_values = extract_column::<FixedSizeBinaryArray>(
120            cols,
121            "close_price",
122            0,
123            DataType::FixedSizeBinary(PRECISION_BYTES),
124        )?;
125        let close_type_values =
126            extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
127        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
128        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
129
130        assert_eq!(
131            close_price_values.value_length(),
132            PRECISION_BYTES,
133            "Price precision uses {PRECISION_BYTES} byte value"
134        );
135
136        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
137            .map(|row| {
138                let close_type_value = close_type_values.value(row);
139                let close_type =
140                    InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
141                        EncodingError::ParseError(
142                            stringify!(InstrumentCloseType),
143                            format!("Invalid enum value, was {close_type_value}"),
144                        )
145                    })?;
146
147                Ok(Self {
148                    instrument_id,
149                    close_price: Price::from_raw(
150                        get_raw_price(close_price_values.value(row)),
151                        price_precision,
152                    ),
153                    close_type,
154                    ts_event: ts_event_values.value(row).into(),
155                    ts_init: ts_init_values.value(row).into(),
156                })
157            })
158            .collect();
159
160        result
161    }
162}
163
164impl DecodeDataFromRecordBatch for InstrumentClose {
165    fn decode_data_batch(
166        metadata: &HashMap<String, String>,
167        record_batch: RecordBatch,
168    ) -> Result<Vec<Data>, EncodingError> {
169        let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
170        Ok(items.into_iter().map(Data::from).collect())
171    }
172}
173
174////////////////////////////////////////////////////////////////////////////////
175// Tests
176////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::{
        enums::InstrumentCloseType,
        types::{fixed::FIXED_SCALAR, price::PriceRaw},
    };
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema must list the four columns in order and carry the given metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let expected = Schema::new_with_metadata(
            vec![
                Field::new(
                    "close_price",
                    DataType::FixedSizeBinary(PRECISION_BYTES),
                    false,
                ),
                Field::new("close_type", DataType::UInt8, false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata.clone(),
        );

        let actual = InstrumentClose::get_schema(Some(metadata));
        assert_eq!(actual, expected);
    }

    /// The schema map must pair each column name with its data type name.
    #[rstest]
    fn test_get_schema_map() {
        let expected = HashMap::from([
            (
                "close_price".to_string(),
                format!("FixedSizeBinary({PRECISION_BYTES})"),
            ),
            ("close_type".to_string(), "UInt8".to_string()),
            ("ts_event".to_string(), "UInt64".to_string()),
            ("ts_init".to_string(), "UInt64".to_string()),
        ]);

        let actual = InstrumentClose::get_schema_map();
        assert_eq!(actual, expected);
    }

    /// Encoding two closes must yield four columns with two rows each, holding the
    /// raw price, the close type discriminant and both timestamps.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let closes = vec![
            InstrumentClose {
                instrument_id,
                close_price: Price::from("150.50"),
                close_type: InstrumentCloseType::EndOfSession,
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            InstrumentClose {
                instrument_id,
                close_price: Price::from("151.25"),
                close_type: InstrumentCloseType::ContractExpired,
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let batch = InstrumentClose::encode_batch(&metadata, &closes).unwrap();
        let columns = batch.columns();
        assert_eq!(columns.len(), 4);

        let prices = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_types = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let event_times = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let init_times = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(prices.len(), 2);
        assert_eq!(
            get_raw_price(prices.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(prices.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(close_types.len(), 2);
        assert_eq!(
            close_types.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_types.value(1),
            InstrumentCloseType::ContractExpired as u8
        );

        assert_eq!(event_times.len(), 2);
        assert_eq!(event_times.value(0), 1);
        assert_eq!(event_times.value(1), 2);

        assert_eq!(init_times.len(), 2);
        assert_eq!(init_times.value(0), 3);
        assert_eq!(init_times.value(1), 4);
    }

    /// Decoding a hand-built two-row batch must reproduce both closes field by field.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let first_price = (15050 as PriceRaw).to_le_bytes();
        let second_price = (15125 as PriceRaw).to_le_bytes();
        let prices = FixedSizeBinaryArray::from(vec![&first_price, &second_price]);
        let close_types = UInt8Array::from(vec![
            InstrumentCloseType::EndOfSession as u8,
            InstrumentCloseType::ContractExpired as u8,
        ]);
        let event_times = UInt64Array::from(vec![1, 2]);
        let init_times = UInt64Array::from(vec![3, 4]);

        let batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(prices),
                Arc::new(close_types),
                Arc::new(event_times),
                Arc::new(init_times),
            ],
        )
        .unwrap();

        let decoded = InstrumentClose::decode_batch(&metadata, batch).unwrap();
        assert_eq!(decoded.len(), 2);

        // Expected (price, close type, ts_event, ts_init) per row, in row order.
        let expected_rows = [
            (
                Price::from_raw(15050, 2),
                InstrumentCloseType::EndOfSession,
                1_u64,
                3_u64,
            ),
            (
                Price::from_raw(15125, 2),
                InstrumentCloseType::ContractExpired,
                2_u64,
                4_u64,
            ),
        ];

        for (close, (price, close_type, ts_event, ts_init)) in decoded.iter().zip(expected_rows) {
            assert_eq!(close.instrument_id, instrument_id);
            assert_eq!(close.close_price, price);
            assert_eq!(close.close_type, close_type);
            assert_eq!(close.ts_event.as_u64(), ts_event);
            assert_eq!(close.ts_init.as_u64(), ts_init);
        }
    }
}