// nautilus_serialization/arrow/close.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::close::InstrumentClose,
26    enums::{FromU8, InstrumentCloseType},
27    identifiers::InstrumentId,
28    types::{Price, fixed::PRECISION_BYTES},
29};
30
31use super::{
32    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33    extract_column,
34};
35use crate::arrow::{
36    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37};
38
39impl ArrowSchemaProvider for InstrumentClose {
40    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
41        let fields = vec![
42            Field::new(
43                "close_price",
44                DataType::FixedSizeBinary(PRECISION_BYTES),
45                false,
46            ),
47            Field::new("close_type", DataType::UInt8, false),
48            Field::new("ts_event", DataType::UInt64, false),
49            Field::new("ts_init", DataType::UInt64, false),
50        ];
51
52        match metadata {
53            Some(metadata) => Schema::new_with_metadata(fields, metadata),
54            None => Schema::new(fields),
55        }
56    }
57}
58
59fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
60    let instrument_id_str = metadata
61        .get(KEY_INSTRUMENT_ID)
62        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
63    let instrument_id = InstrumentId::from_str(instrument_id_str)
64        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
65
66    let price_precision = metadata
67        .get(KEY_PRICE_PRECISION)
68        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
69        .parse::<u8>()
70        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
71
72    Ok((instrument_id, price_precision))
73}
74
75impl EncodeToRecordBatch for InstrumentClose {
76    fn encode_batch(
77        metadata: &HashMap<String, String>,
78        data: &[Self],
79    ) -> Result<RecordBatch, ArrowError> {
80        let mut close_price_builder =
81            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82        let mut close_type_builder = UInt8Array::builder(data.len());
83        let mut ts_event_builder = UInt64Array::builder(data.len());
84        let mut ts_init_builder = UInt64Array::builder(data.len());
85
86        for item in data {
87            close_price_builder
88                .append_value(item.close_price.raw.to_le_bytes())
89                .unwrap();
90            close_type_builder.append_value(item.close_type as u8);
91            ts_event_builder.append_value(item.ts_event.as_u64());
92            ts_init_builder.append_value(item.ts_init.as_u64());
93        }
94
95        RecordBatch::try_new(
96            Self::get_schema(Some(metadata.clone())).into(),
97            vec![
98                Arc::new(close_price_builder.finish()),
99                Arc::new(close_type_builder.finish()),
100                Arc::new(ts_event_builder.finish()),
101                Arc::new(ts_init_builder.finish()),
102            ],
103        )
104    }
105
106    fn metadata(&self) -> HashMap<String, String> {
107        InstrumentClose::get_metadata(&self.instrument_id, self.close_price.precision)
108    }
109}
110
111impl DecodeFromRecordBatch for InstrumentClose {
112    fn decode_batch(
113        metadata: &HashMap<String, String>,
114        record_batch: RecordBatch,
115    ) -> Result<Vec<Self>, EncodingError> {
116        let (instrument_id, price_precision) = parse_metadata(metadata)?;
117        let cols = record_batch.columns();
118
119        let close_price_values = extract_column::<FixedSizeBinaryArray>(
120            cols,
121            "close_price",
122            0,
123            DataType::FixedSizeBinary(PRECISION_BYTES),
124        )?;
125        let close_type_values =
126            extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
127        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
128        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
129
130        // Validate value length
131        if close_price_values.value_length() != PRECISION_BYTES {
132            return Err(EncodingError::ParseError(
133                "close_price",
134                format!(
135                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
136                    close_price_values.value_length()
137                ),
138            ));
139        }
140
141        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
142            .map(|row| {
143                let close_type_value = close_type_values.value(row);
144                let close_type =
145                    InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
146                        EncodingError::ParseError(
147                            stringify!(InstrumentCloseType),
148                            format!("Invalid enum value, was {close_type_value}"),
149                        )
150                    })?;
151
152                Ok(Self {
153                    instrument_id,
154                    close_price: Price::from_raw(
155                        get_raw_price(close_price_values.value(row)),
156                        price_precision,
157                    ),
158                    close_type,
159                    ts_event: ts_event_values.value(row).into(),
160                    ts_init: ts_init_values.value(row).into(),
161                })
162            })
163            .collect();
164
165        result
166    }
167}
168
169impl DecodeDataFromRecordBatch for InstrumentClose {
170    fn decode_data_batch(
171        metadata: &HashMap<String, String>,
172        record_batch: RecordBatch,
173    ) -> Result<Vec<Data>, EncodingError> {
174        let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
175        Ok(items.into_iter().map(Data::from).collect())
176    }
177}
178
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "close_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("close_type", DataType::UInt8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = InstrumentClose::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("close_price".to_string(), fixed_size_binary);
        expected_map.insert("close_type".to_string(), "UInt8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 4);
        assert_eq!(close_price_values.len(), 2);
        assert_eq!(
            get_raw_price(close_price_values.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_price_values.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_type_values.len(), 2);
        assert_eq!(
            close_type_values.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_type_values.value(1),
            InstrumentCloseType::ContractExpired as u8
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close_price = FixedSizeBinaryArray::from(vec![
            &(15050 as PriceRaw).to_le_bytes(),
            &(15125 as PriceRaw).to_le_bytes(),
        ]);
        let close_type = UInt8Array::from(vec![
            InstrumentCloseType::EndOfSession as u8,
            InstrumentCloseType::ContractExpired as u8,
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].close_price, Price::from_raw(15050, 2));
        assert_eq!(
            decoded_data[0].close_type,
            InstrumentCloseType::EndOfSession
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].close_price, Price::from_raw(15125, 2));
        assert_eq!(
            decoded_data[1].close_type,
            InstrumentCloseType::ContractExpired
        );
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    /// Encoding and then decoding with the same metadata must reproduce every field.
    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let data = vec![
            InstrumentClose {
                instrument_id,
                close_price: Price::from("150.50"),
                close_type: InstrumentCloseType::EndOfSession,
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            InstrumentClose {
                instrument_id,
                close_price: Price::from("151.25"),
                close_type: InstrumentCloseType::ContractExpired,
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();
        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), data.len());
        for (original, round_tripped) in data.iter().zip(&decoded) {
            assert_eq!(round_tripped.instrument_id, original.instrument_id);
            assert_eq!(round_tripped.close_price, original.close_price);
            assert_eq!(round_tripped.close_type, original.close_type);
            assert_eq!(round_tripped.ts_event.as_u64(), original.ts_event.as_u64());
            assert_eq!(round_tripped.ts_init.as_u64(), original.ts_init.as_u64());
        }
    }

    /// An unknown `close_type` discriminant must surface as a `ParseError`, not a panic.
    #[rstest]
    fn test_decode_batch_invalid_close_type() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close_price = FixedSizeBinaryArray::from(vec![&(15050 as PriceRaw).to_le_bytes()]);
        // `u8::MAX` is not a valid `InstrumentCloseType` discriminant.
        let close_type = UInt8Array::from(vec![u8::MAX]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(matches!(
            result,
            Err(EncodingError::ParseError(field, _)) if field == "InstrumentCloseType"
        ));
    }

    /// Decoding without the required metadata keys must fail with `MissingMetadata`.
    #[rstest]
    fn test_decode_batch_missing_metadata() {
        let close_price = FixedSizeBinaryArray::from(vec![&(15050 as PriceRaw).to_le_bytes()]);
        let close_type = UInt8Array::from(vec![InstrumentCloseType::EndOfSession as u8]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(None).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = InstrumentClose::decode_batch(&HashMap::new(), record_batch);
        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }
}