nautilus_serialization/arrow/index_price.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::prices::IndexPriceUpdate,
26    identifiers::InstrumentId,
27    types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32    extract_column,
33};
34use crate::arrow::{
35    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
36};
37
38impl ArrowSchemaProvider for IndexPriceUpdate {
39    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40        let fields = vec![
41            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("ts_event", DataType::UInt64, false),
43            Field::new("ts_init", DataType::UInt64, false),
44        ];
45
46        match metadata {
47            Some(metadata) => Schema::new_with_metadata(fields, metadata),
48            None => Schema::new(fields),
49        }
50    }
51}
52
53fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
54    let instrument_id_str = metadata
55        .get(KEY_INSTRUMENT_ID)
56        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
57    let instrument_id = InstrumentId::from_str(instrument_id_str)
58        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
59
60    let price_precision = metadata
61        .get(KEY_PRICE_PRECISION)
62        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
63        .parse::<u8>()
64        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
65
66    Ok((instrument_id, price_precision))
67}
68
69impl EncodeToRecordBatch for IndexPriceUpdate {
70    fn encode_batch(
71        metadata: &HashMap<String, String>,
72        data: &[Self],
73    ) -> Result<RecordBatch, ArrowError> {
74        let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
75        let mut ts_event_builder = UInt64Array::builder(data.len());
76        let mut ts_init_builder = UInt64Array::builder(data.len());
77
78        for update in data {
79            value_builder
80                .append_value(update.value.raw.to_le_bytes())
81                .unwrap();
82            ts_event_builder.append_value(update.ts_event.as_u64());
83            ts_init_builder.append_value(update.ts_init.as_u64());
84        }
85
86        RecordBatch::try_new(
87            Self::get_schema(Some(metadata.clone())).into(),
88            vec![
89                Arc::new(value_builder.finish()),
90                Arc::new(ts_event_builder.finish()),
91                Arc::new(ts_init_builder.finish()),
92            ],
93        )
94    }
95
96    fn metadata(&self) -> HashMap<String, String> {
97        let mut metadata = HashMap::new();
98        metadata.insert(
99            KEY_INSTRUMENT_ID.to_string(),
100            self.instrument_id.to_string(),
101        );
102        metadata.insert(
103            KEY_PRICE_PRECISION.to_string(),
104            self.value.precision.to_string(),
105        );
106        metadata
107    }
108}
109
110impl DecodeFromRecordBatch for IndexPriceUpdate {
111    fn decode_batch(
112        metadata: &HashMap<String, String>,
113        record_batch: RecordBatch,
114    ) -> Result<Vec<Self>, EncodingError> {
115        let (instrument_id, price_precision) = parse_metadata(metadata)?;
116        let cols = record_batch.columns();
117
118        let value_values = extract_column::<FixedSizeBinaryArray>(
119            cols,
120            "value",
121            0,
122            DataType::FixedSizeBinary(PRECISION_BYTES),
123        )?;
124        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
125        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
126
127        assert_eq!(
128            value_values.value_length(),
129            PRECISION_BYTES,
130            "Price precision uses {PRECISION_BYTES} byte value"
131        );
132
133        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
134            .map(|row| {
135                Ok(Self {
136                    instrument_id,
137                    value: Price::from_raw(get_raw_price(value_values.value(row)), price_precision),
138                    ts_event: ts_event_values.value(row).into(),
139                    ts_init: ts_init_values.value(row).into(),
140                })
141            })
142            .collect();
143
144        result
145    }
146}
147
148impl DecodeDataFromRecordBatch for IndexPriceUpdate {
149    fn decode_data_batch(
150        metadata: &HashMap<String, String>,
151        record_batch: RecordBatch,
152    ) -> Result<Vec<Data>, EncodingError> {
153        let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
154        Ok(updates.into_iter().map(Data::from).collect())
155    }
156}
157
158////////////////////////////////////////////////////////////////////////////////
159// Tests
160////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = IndexPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let update1 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("51000.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 3);
        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            (50000.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            (51000.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let value = FixedSizeBinaryArray::from(vec![
            &(5000000 as PriceRaw).to_le_bytes(),
            &(5100000 as PriceRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(5000000, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(5100000, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    // Encoding followed by decoding must reproduce every field of the original updates.
    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let data = vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.25"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), data.len());
        for (actual, expected) in decoded.iter().zip(&data) {
            assert_eq!(actual.instrument_id, expected.instrument_id);
            assert_eq!(actual.value, expected.value);
            assert_eq!(actual.ts_event.as_u64(), expected.ts_event.as_u64());
            assert_eq!(actual.ts_init.as_u64(), expected.ts_init.as_u64());
        }
    }

    // An empty slice encodes to a zero-row batch that decodes back to no updates.
    #[rstest]
    fn test_encode_decode_empty_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &[]).unwrap();
        assert_eq!(record_batch.num_rows(), 0);

        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();
        assert!(decoded.is_empty());
    }

    #[rstest]
    fn test_parse_metadata_missing_instrument_id() {
        let metadata = HashMap::from([(KEY_PRICE_PRECISION.to_string(), "2".to_string())]);

        assert!(matches!(
            parse_metadata(&metadata),
            Err(EncodingError::MissingMetadata(_))
        ));
    }

    #[rstest]
    fn test_parse_metadata_missing_price_precision() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string())]);

        assert!(matches!(
            parse_metadata(&metadata),
            Err(EncodingError::MissingMetadata(_))
        ));
    }

    #[rstest]
    fn test_parse_metadata_invalid_price_precision() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "not-a-number".to_string()),
        ]);

        assert!(matches!(
            parse_metadata(&metadata),
            Err(EncodingError::ParseError(_, _))
        ));
    }
}