nautilus_serialization/arrow/
quote.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::QuoteTick,
26    identifiers::InstrumentId,
27    types::{Price, Quantity, fixed::PRECISION_BYTES},
28};
29
30use super::{
31    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32    KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for QuoteTick {
37    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38        let fields = vec![
39            Field::new(
40                "bid_price",
41                DataType::FixedSizeBinary(PRECISION_BYTES),
42                false,
43            ),
44            Field::new(
45                "ask_price",
46                DataType::FixedSizeBinary(PRECISION_BYTES),
47                false,
48            ),
49            Field::new(
50                "bid_size",
51                DataType::FixedSizeBinary(PRECISION_BYTES),
52                false,
53            ),
54            Field::new(
55                "ask_size",
56                DataType::FixedSizeBinary(PRECISION_BYTES),
57                false,
58            ),
59            Field::new("ts_event", DataType::UInt64, false),
60            Field::new("ts_init", DataType::UInt64, false),
61        ];
62
63        match metadata {
64            Some(metadata) => Schema::new_with_metadata(fields, metadata),
65            None => Schema::new(fields),
66        }
67    }
68}
69
70fn parse_metadata(
71    metadata: &HashMap<String, String>,
72) -> Result<(InstrumentId, u8, u8), EncodingError> {
73    let instrument_id_str = metadata
74        .get(KEY_INSTRUMENT_ID)
75        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
76    let instrument_id = InstrumentId::from_str(instrument_id_str)
77        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
78
79    let price_precision = metadata
80        .get(KEY_PRICE_PRECISION)
81        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
82        .parse::<u8>()
83        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
84
85    let size_precision = metadata
86        .get(KEY_SIZE_PRECISION)
87        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
88        .parse::<u8>()
89        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
90
91    Ok((instrument_id, price_precision, size_precision))
92}
93
94impl EncodeToRecordBatch for QuoteTick {
95    fn encode_batch(
96        metadata: &HashMap<String, String>,
97        data: &[Self],
98    ) -> Result<RecordBatch, ArrowError> {
99        let mut bid_price_builder =
100            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
101        let mut ask_price_builder =
102            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
103        let mut bid_size_builder =
104            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
105        let mut ask_size_builder =
106            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
107        let mut ts_event_builder = UInt64Array::builder(data.len());
108        let mut ts_init_builder = UInt64Array::builder(data.len());
109
110        for quote in data {
111            bid_price_builder
112                .append_value(quote.bid_price.raw.to_le_bytes())
113                .unwrap();
114            ask_price_builder
115                .append_value(quote.ask_price.raw.to_le_bytes())
116                .unwrap();
117            bid_size_builder
118                .append_value(quote.bid_size.raw.to_le_bytes())
119                .unwrap();
120            ask_size_builder
121                .append_value(quote.ask_size.raw.to_le_bytes())
122                .unwrap();
123            ts_event_builder.append_value(quote.ts_event.as_u64());
124            ts_init_builder.append_value(quote.ts_init.as_u64());
125        }
126
127        RecordBatch::try_new(
128            Self::get_schema(Some(metadata.clone())).into(),
129            vec![
130                Arc::new(bid_price_builder.finish()),
131                Arc::new(ask_price_builder.finish()),
132                Arc::new(bid_size_builder.finish()),
133                Arc::new(ask_size_builder.finish()),
134                Arc::new(ts_event_builder.finish()),
135                Arc::new(ts_init_builder.finish()),
136            ],
137        )
138    }
139
140    fn metadata(&self) -> HashMap<String, String> {
141        Self::get_metadata(
142            &self.instrument_id,
143            self.bid_price.precision,
144            self.bid_size.precision,
145        )
146    }
147}
148
149impl DecodeFromRecordBatch for QuoteTick {
150    fn decode_batch(
151        metadata: &HashMap<String, String>,
152        record_batch: RecordBatch,
153    ) -> Result<Vec<Self>, EncodingError> {
154        let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
155        let cols = record_batch.columns();
156
157        let bid_price_values = extract_column::<FixedSizeBinaryArray>(
158            cols,
159            "bid_price",
160            0,
161            DataType::FixedSizeBinary(PRECISION_BYTES),
162        )?;
163        let ask_price_values = extract_column::<FixedSizeBinaryArray>(
164            cols,
165            "ask_price",
166            1,
167            DataType::FixedSizeBinary(PRECISION_BYTES),
168        )?;
169        let bid_size_values = extract_column::<FixedSizeBinaryArray>(
170            cols,
171            "bid_size",
172            2,
173            DataType::FixedSizeBinary(PRECISION_BYTES),
174        )?;
175        let ask_size_values = extract_column::<FixedSizeBinaryArray>(
176            cols,
177            "ask_size",
178            3,
179            DataType::FixedSizeBinary(PRECISION_BYTES),
180        )?;
181        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
182        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
183
184        if bid_price_values.value_length() != PRECISION_BYTES {
185            return Err(EncodingError::ParseError(
186                "bid_price",
187                format!(
188                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
189                    bid_price_values.value_length()
190                ),
191            ));
192        }
193        if ask_price_values.value_length() != PRECISION_BYTES {
194            return Err(EncodingError::ParseError(
195                "ask_price",
196                format!(
197                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
198                    ask_price_values.value_length()
199                ),
200            ));
201        }
202
203        // Use corrected raw values to handle floating-point precision errors in stored data
204        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
205            .map(|row| {
206                Ok(Self {
207                    instrument_id,
208                    bid_price: Price::from_raw(
209                        get_corrected_raw_price(bid_price_values.value(row), price_precision),
210                        price_precision,
211                    ),
212                    ask_price: Price::from_raw(
213                        get_corrected_raw_price(ask_price_values.value(row), price_precision),
214                        price_precision,
215                    ),
216                    bid_size: Quantity::from_raw(
217                        get_corrected_raw_quantity(bid_size_values.value(row), size_precision),
218                        size_precision,
219                    ),
220                    ask_size: Quantity::from_raw(
221                        get_corrected_raw_quantity(ask_size_values.value(row), size_precision),
222                        size_precision,
223                    ),
224                    ts_event: ts_event_values.value(row).into(),
225                    ts_init: ts_init_values.value(row).into(),
226                })
227            })
228            .collect();
229
230        result
231    }
232}
233
234impl DecodeDataFromRecordBatch for QuoteTick {
235    fn decode_data_batch(
236        metadata: &HashMap<String, String>,
237        record_batch: RecordBatch,
238    ) -> Result<Vec<Data>, EncodingError> {
239        let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
240        Ok(ticks.into_iter().map(Data::from).collect())
241    }
242}
243
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Two sample quotes (price precision 2, size precision 0) shared by the tests.
    fn sample_quotes(instrument_id: InstrumentId) -> Vec<QuoteTick> {
        vec![
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.10"),
                ask_price: Price::from("101.50"),
                bid_size: Quantity::from(1000),
                ask_size: Quantity::from(500),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.75"),
                ask_price: Price::from("100.20"),
                bid_size: Quantity::from(750),
                ask_size: Quantity::from(300),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ]
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("bid_price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ask_price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("bid_size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ask_size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let arrow_schema = QuoteTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("bid_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("bid_size".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_size".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(arrow_schema, expected_map);
    }

    #[rstest]
    fn test_encode_quote_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();

        // Verify the encoded data
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 6);

        let bid_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_price_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(
            get_raw_price(bid_price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(bid_price_values.value(1)),
            (100.75 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(0)),
            (101.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(1)),
            (100.20 * FIXED_SCALAR) as PriceRaw
        );

        let bid_size_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(bid_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(bid_size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(bid_size_values.value(1)),
            (750.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ask_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(ask_size_values.value(0)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(ask_size_values.value(1)),
            (300.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let raw_bid1 = (100.00 * FIXED_SCALAR) as PriceRaw;
        let raw_bid2 = (99.00 * FIXED_SCALAR) as PriceRaw;
        let raw_ask1 = (101.00 * FIXED_SCALAR) as PriceRaw;
        let raw_ask2 = (100.00 * FIXED_SCALAR) as PriceRaw;

        // Sizes are quantities, so encode them with the quantity raw type.
        let raw_bid_size1 = (100.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_bid_size2 = (90.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_ask_size1 = (110.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_ask_size2 = (100.0 * FIXED_SCALAR) as QuantityRaw;

        let bid_price =
            FixedSizeBinaryArray::from(vec![&raw_bid1.to_le_bytes(), &raw_bid2.to_le_bytes()]);
        let ask_price =
            FixedSizeBinaryArray::from(vec![&raw_ask1.to_le_bytes(), &raw_ask2.to_le_bytes()]);
        let bid_size = FixedSizeBinaryArray::from(vec![
            &raw_bid_size1.to_le_bytes(),
            &raw_bid_size2.to_le_bytes(),
        ]);
        let ask_size = FixedSizeBinaryArray::from(vec![
            &raw_ask_size1.to_le_bytes(),
            &raw_ask_size2.to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            QuoteTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(bid_price),
                Arc::new(ask_price),
                Arc::new(bid_size),
                Arc::new(ask_size),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = QuoteTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Verify decoded values
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].bid_price, Price::from_raw(raw_bid1, 2));
        assert_eq!(decoded_data[0].ask_price, Price::from_raw(raw_ask1, 2));
        assert_eq!(decoded_data[1].bid_price, Price::from_raw(raw_bid2, 2));
        assert_eq!(decoded_data[1].ask_price, Price::from_raw(raw_ask2, 2));
        assert_eq!(decoded_data[0].bid_size, Quantity::from_raw(raw_bid_size1, 0));
        assert_eq!(decoded_data[0].ask_size, Quantity::from_raw(raw_ask_size1, 0));
        assert_eq!(decoded_data[1].bid_size, Quantity::from_raw(raw_bid_size2, 0));
        assert_eq!(decoded_data[1].ask_size, Quantity::from_raw(raw_ask_size2, 0));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();
        let decoded = QuoteTick::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), data.len());
        for (actual, expected) in decoded.iter().zip(&data) {
            assert_eq!(actual.instrument_id, expected.instrument_id);
            assert_eq!(actual.bid_price, expected.bid_price);
            assert_eq!(actual.ask_price, expected.ask_price);
            assert_eq!(actual.bid_size, expected.bid_size);
            assert_eq!(actual.ask_size, expected.ask_size);
            assert_eq!(actual.ts_event, expected.ts_event);
            assert_eq!(actual.ts_init, expected.ts_init);
        }
    }
}