nautilus_serialization/arrow/
index_price.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::prices::IndexPriceUpdate,
26    identifiers::InstrumentId,
27    types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32    extract_column, get_corrected_raw_price,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for IndexPriceUpdate {
37    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38        let fields = vec![
39            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40            Field::new("ts_event", DataType::UInt64, false),
41            Field::new("ts_init", DataType::UInt64, false),
42        ];
43
44        match metadata {
45            Some(metadata) => Schema::new_with_metadata(fields, metadata),
46            None => Schema::new(fields),
47        }
48    }
49}
50
51fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
52    let instrument_id_str = metadata
53        .get(KEY_INSTRUMENT_ID)
54        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
55    let instrument_id = InstrumentId::from_str(instrument_id_str)
56        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
57
58    let price_precision = metadata
59        .get(KEY_PRICE_PRECISION)
60        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
61        .parse::<u8>()
62        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
63
64    Ok((instrument_id, price_precision))
65}
66
67impl EncodeToRecordBatch for IndexPriceUpdate {
68    fn encode_batch(
69        metadata: &HashMap<String, String>,
70        data: &[Self],
71    ) -> Result<RecordBatch, ArrowError> {
72        let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
73        let mut ts_event_builder = UInt64Array::builder(data.len());
74        let mut ts_init_builder = UInt64Array::builder(data.len());
75
76        for update in data {
77            value_builder
78                .append_value(update.value.raw.to_le_bytes())
79                .unwrap();
80            ts_event_builder.append_value(update.ts_event.as_u64());
81            ts_init_builder.append_value(update.ts_init.as_u64());
82        }
83
84        RecordBatch::try_new(
85            Self::get_schema(Some(metadata.clone())).into(),
86            vec![
87                Arc::new(value_builder.finish()),
88                Arc::new(ts_event_builder.finish()),
89                Arc::new(ts_init_builder.finish()),
90            ],
91        )
92    }
93
94    fn metadata(&self) -> HashMap<String, String> {
95        let mut metadata = HashMap::new();
96        metadata.insert(
97            KEY_INSTRUMENT_ID.to_string(),
98            self.instrument_id.to_string(),
99        );
100        metadata.insert(
101            KEY_PRICE_PRECISION.to_string(),
102            self.value.precision.to_string(),
103        );
104        metadata
105    }
106}
107
108impl DecodeFromRecordBatch for IndexPriceUpdate {
109    fn decode_batch(
110        metadata: &HashMap<String, String>,
111        record_batch: RecordBatch,
112    ) -> Result<Vec<Self>, EncodingError> {
113        let (instrument_id, price_precision) = parse_metadata(metadata)?;
114        let cols = record_batch.columns();
115
116        let value_values = extract_column::<FixedSizeBinaryArray>(
117            cols,
118            "value",
119            0,
120            DataType::FixedSizeBinary(PRECISION_BYTES),
121        )?;
122        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
123        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
124
125        if value_values.value_length() != PRECISION_BYTES {
126            return Err(EncodingError::ParseError(
127                "value",
128                format!(
129                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
130                    value_values.value_length()
131                ),
132            ));
133        }
134
135        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
136            .map(|row| {
137                Ok(Self {
138                    // Use corrected raw value to handle floating-point precision errors in stored data
139                    instrument_id,
140                    value: Price::from_raw(
141                        get_corrected_raw_price(value_values.value(row), price_precision),
142                        price_precision,
143                    ),
144                    ts_event: ts_event_values.value(row).into(),
145                    ts_init: ts_init_values.value(row).into(),
146                })
147            })
148            .collect();
149
150        result
151    }
152}
153
154impl DecodeDataFromRecordBatch for IndexPriceUpdate {
155    fn decode_data_batch(
156        metadata: &HashMap<String, String>,
157        record_batch: RecordBatch,
158    ) -> Result<Vec<Data>, EncodingError> {
159        let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
160        Ok(updates.into_iter().map(Data::from).collect())
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds the schema metadata shared by these tests: the given instrument
    /// ID and a price precision of 2.
    fn metadata_for(instrument_id: InstrumentId) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    /// The schema carries the three expected columns plus the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let metadata = metadata_for(InstrumentId::from("BTC-USDT.BINANCE"));
        let schema = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_schema = Schema::new_with_metadata(
            vec![
                Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata,
        );
        assert_eq!(schema, expected_schema);
    }

    /// The schema map lists every column name with its Arrow type name.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = IndexPriceUpdate::get_schema_map();

        let expected_map = HashMap::from([
            (
                "value".to_string(),
                format!("FixedSizeBinary({PRECISION_BYTES})"),
            ),
            ("ts_event".to_string(), "UInt64".to_string()),
            ("ts_init".to_string(), "UInt64".to_string()),
        ]);
        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two updates yields three columns holding the raw prices and timestamps.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let first = IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };
        let second = IndexPriceUpdate {
            instrument_id,
            value: Price::from("51000.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &[first, second]).unwrap();

        let columns = record_batch.columns();
        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 3);

        // Price column
        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50000.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(51000.00).to_string()).raw
        );

        // Event timestamps
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);

        // Init timestamps
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch reproduces the instrument ID, prices and timestamps.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let raw_price1 = (50.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (51.00 * FIXED_SCALAR) as PriceRaw;
        let bytes1 = raw_price1.to_le_bytes();
        let bytes2 = raw_price2.to_le_bytes();

        let value_column = FixedSizeBinaryArray::from(vec![&bytes1, &bytes2]);
        let ts_event_column = UInt64Array::from(vec![1, 2]);
        let ts_init_column = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(value_column),
                Arc::new(ts_event_column),
                Arc::new(ts_init_column),
            ],
        )
        .unwrap();

        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), 2);

        let first = &decoded[0];
        assert_eq!(first.instrument_id, instrument_id);
        assert_eq!(first.value, Price::from_raw(raw_price1, 2));
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded[1];
        assert_eq!(second.instrument_id, instrument_id);
        assert_eq!(second.value, Price::from_raw(raw_price2, 2));
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }
}