// Source: nautilus_serialization/arrow/mark_price.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::prices::MarkPriceUpdate, identifiers::InstrumentId, types::fixed::PRECISION_BYTES,
26};
27
28use super::{
29    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
30    extract_column, validate_precision_bytes,
31};
32use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
33
34impl ArrowSchemaProvider for MarkPriceUpdate {
35    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
36        let fields = vec![
37            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
38            Field::new("ts_event", DataType::UInt64, false),
39            Field::new("ts_init", DataType::UInt64, false),
40        ];
41
42        match metadata {
43            Some(metadata) => Schema::new_with_metadata(fields, metadata),
44            None => Schema::new(fields),
45        }
46    }
47}
48
49fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
50    let instrument_id_str = metadata
51        .get(KEY_INSTRUMENT_ID)
52        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
53    let instrument_id = InstrumentId::from_str(instrument_id_str)
54        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
55
56    let price_precision = metadata
57        .get(KEY_PRICE_PRECISION)
58        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
59        .parse::<u8>()
60        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
61
62    Ok((instrument_id, price_precision))
63}
64
65impl EncodeToRecordBatch for MarkPriceUpdate {
66    fn encode_batch(
67        metadata: &HashMap<String, String>,
68        data: &[Self],
69    ) -> Result<RecordBatch, ArrowError> {
70        let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
71        let mut ts_event_builder = UInt64Array::builder(data.len());
72        let mut ts_init_builder = UInt64Array::builder(data.len());
73
74        for update in data {
75            value_builder
76                .append_value(update.value.raw.to_le_bytes())
77                .unwrap();
78            ts_event_builder.append_value(update.ts_event.as_u64());
79            ts_init_builder.append_value(update.ts_init.as_u64());
80        }
81
82        RecordBatch::try_new(
83            Self::get_schema(Some(metadata.clone())).into(),
84            vec![
85                Arc::new(value_builder.finish()),
86                Arc::new(ts_event_builder.finish()),
87                Arc::new(ts_init_builder.finish()),
88            ],
89        )
90    }
91
92    fn metadata(&self) -> HashMap<String, String> {
93        let mut metadata = HashMap::new();
94        metadata.insert(
95            KEY_INSTRUMENT_ID.to_string(),
96            self.instrument_id.to_string(),
97        );
98        metadata.insert(
99            KEY_PRICE_PRECISION.to_string(),
100            self.value.precision.to_string(),
101        );
102        metadata
103    }
104}
105
106impl DecodeFromRecordBatch for MarkPriceUpdate {
107    fn decode_batch(
108        metadata: &HashMap<String, String>,
109        record_batch: RecordBatch,
110    ) -> Result<Vec<Self>, EncodingError> {
111        let (instrument_id, price_precision) = parse_metadata(metadata)?;
112        let cols = record_batch.columns();
113
114        let value_values = extract_column::<FixedSizeBinaryArray>(
115            cols,
116            "value",
117            0,
118            DataType::FixedSizeBinary(PRECISION_BYTES),
119        )?;
120        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
121        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
122
123        validate_precision_bytes(value_values, "value")?;
124
125        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
126            .map(|row| {
127                let value = decode_price(value_values.value(row), price_precision, "value", row)?;
128                Ok(Self {
129                    instrument_id,
130                    value,
131                    ts_event: ts_event_values.value(row).into(),
132                    ts_init: ts_init_values.value(row).into(),
133                })
134            })
135            .collect();
136
137        result
138    }
139}
140
141impl DecodeDataFromRecordBatch for MarkPriceUpdate {
142    fn decode_data_batch(
143        metadata: &HashMap<String, String>,
144        record_batch: RecordBatch,
145    ) -> Result<Vec<Data>, EncodingError> {
146        let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
147        Ok(updates.into_iter().map(Data::from).collect())
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Instrument used throughout these tests.
    const INSTRUMENT: &str = "BTC-USDT.BINANCE";

    /// Returns batch metadata for [`INSTRUMENT`] with the given, still-unparsed price precision.
    fn metadata_with_precision(precision: &str) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), INSTRUMENT.to_string()),
            (KEY_PRICE_PRECISION.to_string(), precision.to_string()),
        ])
    }

    /// Builds a record batch directly from raw fixed-point prices and timestamps.
    ///
    /// This bypasses `encode_batch`, so decode tests can inject arbitrary raw values.
    fn raw_record_batch(
        metadata: &HashMap<String, String>,
        raw_prices: &[PriceRaw],
        ts_event: Vec<u64>,
        ts_init: Vec<u64>,
    ) -> RecordBatch {
        let price_bytes: Vec<_> = raw_prices.iter().map(|raw| raw.to_le_bytes()).collect();
        let value = FixedSizeBinaryArray::from(price_bytes.iter().collect::<Vec<_>>());

        RecordBatch::try_new(
            MarkPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(value),
                Arc::new(UInt64Array::from(ts_event)),
                Arc::new(UInt64Array::from(ts_init)),
            ],
        )
        .unwrap()
    }

    #[rstest]
    fn test_get_schema() {
        let metadata = metadata_with_precision("2");
        let schema = MarkPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = MarkPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_metadata_matches_decode_keys() {
        let update = MarkPriceUpdate {
            instrument_id: InstrumentId::from(INSTRUMENT),
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 2.into(),
        };

        // Fully qualified to avoid any ambiguity with other `metadata` methods in scope.
        let metadata = EncodeToRecordBatch::metadata(&update);

        assert_eq!(metadata, metadata_with_precision("2"));
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_with_precision("2");

        let update1 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50300.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing into it.
        assert_eq!(columns.len(), 3);

        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50200.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(50300.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_with_precision("2");

        let raw_price1 = (50.20 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (50.30 * FIXED_SCALAR) as PriceRaw;
        let record_batch =
            raw_record_batch(&metadata, &[raw_price1, raw_price2], vec![1, 2], vec![3, 4]);

        let decoded_data = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(raw_price1, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(raw_price2, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_decode_data_batch() {
        let metadata = metadata_with_precision("2");

        let raw_price1 = (50.20 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (50.30 * FIXED_SCALAR) as PriceRaw;
        let record_batch =
            raw_record_batch(&metadata, &[raw_price1, raw_price2], vec![1, 2], vec![3, 4]);

        let decoded = MarkPriceUpdate::decode_data_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), 2);
        let instrument_id = InstrumentId::from(INSTRUMENT);
        assert!(decoded.iter().all(|data| matches!(
            data,
            Data::MarkPriceUpdate(update) if update.instrument_id == instrument_id
        )));
    }

    #[rstest]
    fn test_decode_batch_invalid_value_returns_error() {
        let metadata = metadata_with_precision("2");

        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = raw_record_batch(&metadata, &[invalid_price], vec![1], vec![2]);

        let result = MarkPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("value") && err.to_string().contains("row 0"),
            "Expected value error at row 0, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let mut metadata = metadata_with_precision("2");

        let raw_price = (50.20 * FIXED_SCALAR) as PriceRaw;
        let record_batch = raw_record_batch(&metadata, &[raw_price], vec![1], vec![2]);

        metadata.remove(KEY_INSTRUMENT_ID);

        let result = MarkPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_price_precision_returns_error() {
        let mut metadata = metadata_with_precision("2");

        let raw_price = (50.20 * FIXED_SCALAR) as PriceRaw;
        let record_batch = raw_record_batch(&metadata, &[raw_price], vec![1], vec![2]);

        metadata.remove(KEY_PRICE_PRECISION);

        let err = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap_err();
        assert!(
            matches!(&err, EncodingError::MissingMetadata(key) if *key == KEY_PRICE_PRECISION),
            "Expected missing price precision error, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_out_of_range_price_precision_returns_error() {
        // 256 does not fit in the `u8` precision parsed by `parse_metadata`.
        let metadata = metadata_with_precision("256");

        let raw_price = (50.20 * FIXED_SCALAR) as PriceRaw;
        let record_batch = raw_record_batch(&metadata, &[raw_price], vec![1], vec![2]);

        let err = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap_err();
        assert!(
            matches!(&err, EncodingError::ParseError(key, _) if *key == KEY_PRICE_PRECISION),
            "Expected price precision parse error, was: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_empty_batch() {
        let metadata = metadata_with_precision("2");

        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &[]).unwrap();
        assert_eq!(record_batch.num_rows(), 0);

        let decoded = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();
        assert!(decoded.is_empty());
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_with_precision("2");

        let update1 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1_000_000_000.into(),
            ts_init: 1_000_000_001.into(),
        };

        let update2 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50300.00"),
            ts_event: 2_000_000_000.into(),
            ts_init: 2_000_000_001.into(),
        };

        let original = vec![update1, update2];
        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &original).unwrap();
        let decoded = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec.instrument_id, orig.instrument_id);
            assert_eq!(dec.value, orig.value);
            assert_eq!(dec.ts_event, orig.ts_event);
            assert_eq!(dec.ts_init, orig.ts_init);
        }
    }
}