nautilus_serialization/arrow/
mark_price.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::prices::MarkPriceUpdate,
26    identifiers::InstrumentId,
27    types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32    extract_column,
33};
34use crate::arrow::{
35    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
36};
37
38impl ArrowSchemaProvider for MarkPriceUpdate {
39    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40        let fields = vec![
41            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42            Field::new("ts_event", DataType::UInt64, false),
43            Field::new("ts_init", DataType::UInt64, false),
44        ];
45
46        match metadata {
47            Some(metadata) => Schema::new_with_metadata(fields, metadata),
48            None => Schema::new(fields),
49        }
50    }
51}
52
53fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
54    let instrument_id_str = metadata
55        .get(KEY_INSTRUMENT_ID)
56        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
57    let instrument_id = InstrumentId::from_str(instrument_id_str)
58        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
59
60    let price_precision = metadata
61        .get(KEY_PRICE_PRECISION)
62        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
63        .parse::<u8>()
64        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
65
66    Ok((instrument_id, price_precision))
67}
68
69impl EncodeToRecordBatch for MarkPriceUpdate {
70    fn encode_batch(
71        metadata: &HashMap<String, String>,
72        data: &[Self],
73    ) -> Result<RecordBatch, ArrowError> {
74        let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
75        let mut ts_event_builder = UInt64Array::builder(data.len());
76        let mut ts_init_builder = UInt64Array::builder(data.len());
77
78        for update in data {
79            value_builder
80                .append_value(update.value.raw.to_le_bytes())
81                .unwrap();
82            ts_event_builder.append_value(update.ts_event.as_u64());
83            ts_init_builder.append_value(update.ts_init.as_u64());
84        }
85
86        RecordBatch::try_new(
87            Self::get_schema(Some(metadata.clone())).into(),
88            vec![
89                Arc::new(value_builder.finish()),
90                Arc::new(ts_event_builder.finish()),
91                Arc::new(ts_init_builder.finish()),
92            ],
93        )
94    }
95
96    fn metadata(&self) -> HashMap<String, String> {
97        let mut metadata = HashMap::new();
98        metadata.insert(
99            KEY_INSTRUMENT_ID.to_string(),
100            self.instrument_id.to_string(),
101        );
102        metadata.insert(
103            KEY_PRICE_PRECISION.to_string(),
104            self.value.precision.to_string(),
105        );
106        metadata
107    }
108}
109
110impl DecodeFromRecordBatch for MarkPriceUpdate {
111    fn decode_batch(
112        metadata: &HashMap<String, String>,
113        record_batch: RecordBatch,
114    ) -> Result<Vec<Self>, EncodingError> {
115        let (instrument_id, price_precision) = parse_metadata(metadata)?;
116        let cols = record_batch.columns();
117
118        let value_values = extract_column::<FixedSizeBinaryArray>(
119            cols,
120            "value",
121            0,
122            DataType::FixedSizeBinary(PRECISION_BYTES),
123        )?;
124        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
125        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
126
127        if value_values.value_length() != PRECISION_BYTES {
128            return Err(EncodingError::ParseError(
129                "value",
130                format!(
131                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
132                    value_values.value_length()
133                ),
134            ));
135        }
136
137        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
138            .map(|row| {
139                Ok(Self {
140                    instrument_id,
141                    value: Price::from_raw(get_raw_price(value_values.value(row)), price_precision),
142                    ts_event: ts_event_values.value(row).into(),
143                    ts_init: ts_init_values.value(row).into(),
144                })
145            })
146            .collect();
147
148        result
149    }
150}
151
152impl DecodeDataFromRecordBatch for MarkPriceUpdate {
153    fn decode_data_batch(
154        metadata: &HashMap<String, String>,
155        record_batch: RecordBatch,
156    ) -> Result<Vec<Data>, EncodingError> {
157        let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
158        Ok(updates.into_iter().map(Data::from).collect())
159    }
160}
161
162////////////////////////////////////////////////////////////////////////////////
163// Tests
164////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::price::PriceRaw;
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = MarkPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = MarkPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let update1 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50300.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &data).unwrap();

        // Check the column count before indexing so a short batch fails the assertion
        // rather than panicking on an out-of-bounds index.
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 3);

        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50200.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(50300.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let value = FixedSizeBinaryArray::from(vec![
            &(5020000 as PriceRaw).to_le_bytes(),
            &(5030000 as PriceRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            MarkPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(5020000, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(5030000, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let data = vec![MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        }];

        // Use the metadata the update itself reports, as a real writer would.
        let metadata = EncodeToRecordBatch::metadata(&data[0]);
        assert_eq!(
            metadata.get(KEY_INSTRUMENT_ID),
            Some(&instrument_id.to_string())
        );
        assert_eq!(metadata.get(KEY_PRICE_PRECISION), Some(&"2".to_string()));

        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let decoded = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].instrument_id, instrument_id);
        assert_eq!(decoded[0].value, data[0].value);
        assert_eq!(decoded[0].ts_event.as_u64(), 1);
        assert_eq!(decoded[0].ts_init.as_u64(), 3);
    }

    #[rstest]
    fn test_decode_batch_missing_price_precision() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let full_metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let record_batch = MarkPriceUpdate::encode_batch(&full_metadata, &[]).unwrap();

        let partial_metadata =
            HashMap::from([(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string())]);
        let result = MarkPriceUpdate::decode_batch(&partial_metadata, record_batch);

        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }

    #[rstest]
    fn test_decode_batch_invalid_price_precision() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let valid_metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let record_batch = MarkPriceUpdate::encode_batch(&valid_metadata, &[]).unwrap();

        let invalid_metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "not-a-number".to_string()),
        ]);
        let result = MarkPriceUpdate::decode_batch(&invalid_metadata, record_batch);

        assert!(matches!(result, Err(EncodingError::ParseError(_, _))));
    }
}