nautilus_serialization/arrow/
trade.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{
20        FixedSizeBinaryArray, FixedSizeBinaryBuilder, StringArray, StringBuilder, StringViewArray,
21        UInt8Array, UInt64Array,
22    },
23    datatypes::{DataType, Field, Schema},
24    error::ArrowError,
25    record_batch::RecordBatch,
26};
27use nautilus_model::{
28    data::TradeTick,
29    enums::AggressorSide,
30    identifiers::{InstrumentId, TradeId},
31    types::{Price, Quantity, fixed::PRECISION_BYTES},
32};
33
34use super::{
35    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
36    KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
37};
38use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
39
40impl ArrowSchemaProvider for TradeTick {
41    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42        let fields = vec![
43            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45            Field::new("aggressor_side", DataType::UInt8, false),
46            Field::new("trade_id", DataType::Utf8, false),
47            Field::new("ts_event", DataType::UInt64, false),
48            Field::new("ts_init", DataType::UInt64, false),
49        ];
50
51        match metadata {
52            Some(metadata) => Schema::new_with_metadata(fields, metadata),
53            None => Schema::new(fields),
54        }
55    }
56}
57
58fn parse_metadata(
59    metadata: &HashMap<String, String>,
60) -> Result<(InstrumentId, u8, u8), EncodingError> {
61    let instrument_id_str = metadata
62        .get(KEY_INSTRUMENT_ID)
63        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
64    let instrument_id = InstrumentId::from_str(instrument_id_str)
65        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
66
67    let price_precision = metadata
68        .get(KEY_PRICE_PRECISION)
69        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
70        .parse::<u8>()
71        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
72
73    let size_precision = metadata
74        .get(KEY_SIZE_PRECISION)
75        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
76        .parse::<u8>()
77        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
78
79    Ok((instrument_id, price_precision, size_precision))
80}
81
82impl EncodeToRecordBatch for TradeTick {
83    fn encode_batch(
84        metadata: &HashMap<String, String>,
85        data: &[Self],
86    ) -> Result<RecordBatch, ArrowError> {
87        let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88        let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89
90        let mut aggressor_side_builder = UInt8Array::builder(data.len());
91        let mut trade_id_builder = StringBuilder::new();
92        let mut ts_event_builder = UInt64Array::builder(data.len());
93        let mut ts_init_builder = UInt64Array::builder(data.len());
94
95        for tick in data {
96            price_builder
97                .append_value(tick.price.raw.to_le_bytes())
98                .unwrap();
99            size_builder
100                .append_value(tick.size.raw.to_le_bytes())
101                .unwrap();
102            aggressor_side_builder.append_value(tick.aggressor_side as u8);
103            trade_id_builder.append_value(tick.trade_id.to_string());
104            ts_event_builder.append_value(tick.ts_event.as_u64());
105            ts_init_builder.append_value(tick.ts_init.as_u64());
106        }
107
108        let price_array = Arc::new(price_builder.finish());
109        let size_array = Arc::new(size_builder.finish());
110        let aggressor_side_array = Arc::new(aggressor_side_builder.finish());
111        let trade_id_array = Arc::new(trade_id_builder.finish());
112        let ts_event_array = Arc::new(ts_event_builder.finish());
113        let ts_init_array = Arc::new(ts_init_builder.finish());
114
115        RecordBatch::try_new(
116            Self::get_schema(Some(metadata.clone())).into(),
117            vec![
118                price_array,
119                size_array,
120                aggressor_side_array,
121                trade_id_array,
122                ts_event_array,
123                ts_init_array,
124            ],
125        )
126    }
127
128    fn metadata(&self) -> HashMap<String, String> {
129        Self::get_metadata(
130            &self.instrument_id,
131            self.price.precision,
132            self.size.precision,
133        )
134    }
135}
136
137impl DecodeFromRecordBatch for TradeTick {
138    fn decode_batch(
139        metadata: &HashMap<String, String>,
140        record_batch: RecordBatch,
141    ) -> Result<Vec<Self>, EncodingError> {
142        let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
143        let cols = record_batch.columns();
144
145        let price_values = extract_column::<FixedSizeBinaryArray>(
146            cols,
147            "price",
148            0,
149            DataType::FixedSizeBinary(PRECISION_BYTES),
150        )?;
151
152        let size_values = extract_column::<FixedSizeBinaryArray>(
153            cols,
154            "size",
155            1,
156            DataType::FixedSizeBinary(PRECISION_BYTES),
157        )?;
158        let aggressor_side_values =
159            extract_column::<UInt8Array>(cols, "aggressor_side", 2, DataType::UInt8)?;
160        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
161        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
162
163        // Datafusion reads trade_ids as StringView
164        let trade_id_values: Vec<TradeId> = if record_batch
165            .schema()
166            .field_with_name("trade_id")?
167            .data_type()
168            == &DataType::Utf8View
169        {
170            extract_column::<StringViewArray>(cols, "trade_id", 3, DataType::Utf8View)?
171                .iter()
172                .map(|id| TradeId::from(id.unwrap()))
173                .collect()
174        } else {
175            extract_column::<StringArray>(cols, "trade_id", 3, DataType::Utf8)?
176                .iter()
177                .map(|id| TradeId::from(id.unwrap()))
178                .collect()
179        };
180
181        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182            .map(|i| {
183                // Use corrected raw values to handle floating-point precision errors in stored data
184                let price = Price::from_raw(
185                    get_corrected_raw_price(price_values.value(i), price_precision),
186                    price_precision,
187                );
188                let size = Quantity::from_raw(
189                    get_corrected_raw_quantity(size_values.value(i), size_precision),
190                    size_precision,
191                );
192                let aggressor_side_value = aggressor_side_values.value(i);
193                let aggressor_side = AggressorSide::from_repr(aggressor_side_value as usize)
194                    .ok_or_else(|| {
195                        EncodingError::ParseError(
196                            stringify!(AggressorSide),
197                            format!("Invalid enum value, was {aggressor_side_value}"),
198                        )
199                    })?;
200                let trade_id = trade_id_values[i];
201                let ts_event = ts_event_values.value(i).into();
202                let ts_init = ts_init_values.value(i).into();
203
204                Ok(Self {
205                    instrument_id,
206                    price,
207                    size,
208                    aggressor_side,
209                    trade_id,
210                    ts_event,
211                    ts_init,
212                })
213            })
214            .collect();
215
216        result
217    }
218}
219
220impl DecodeDataFromRecordBatch for TradeTick {
221    fn decode_data_batch(
222        metadata: &HashMap<String, String>,
223        record_batch: RecordBatch,
224    ) -> Result<Vec<Data>, EncodingError> {
225        let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
226        Ok(ticks.into_iter().map(Data::from).collect())
227    }
228}
229
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Array, FixedSizeBinaryArray, UInt8Array, UInt64Array},
        record_batch::RecordBatch,
    };
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Builds the two-tick fixture shared by the encoding tests (price precision 2,
    /// size precision 0).
    fn sample_ticks(instrument_id: InstrumentId) -> Vec<TradeTick> {
        vec![
            TradeTick {
                instrument_id,
                price: Price::from("100.10"),
                size: Quantity::from(1000),
                aggressor_side: AggressorSide::Buyer,
                trade_id: TradeId::new("1"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            TradeTick {
                instrument_id,
                price: Price::from("100.50"),
                size: Quantity::from(500),
                aggressor_side: AggressorSide::Seller,
                trade_id: TradeId::new("2"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ]
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
        let schema = TradeTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("aggressor_side", DataType::UInt8, false),
            Field::new("trade_id", DataType::Utf8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = TradeTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let precision_bytes = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("price".to_string(), precision_bytes.clone());
        expected_map.insert("size".to_string(), precision_bytes);
        expected_map.insert("aggressor_side".to_string(), "UInt8".to_string());
        expected_map.insert("trade_id".to_string(), "Utf8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_trade_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
        let data = sample_ticks(instrument_id);

        let record_batch = TradeTick::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 6);

        let price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (100.50 * FIXED_SCALAR) as PriceRaw
        );

        let size_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(size_values.value(1)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );

        let aggressor_side_values = columns[2].as_any().downcast_ref::<UInt8Array>().unwrap();
        let trade_id_values = columns[3].as_any().downcast_ref::<StringArray>().unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(aggressor_side_values.len(), 2);
        assert_eq!(aggressor_side_values.value(0), 1);
        assert_eq!(aggressor_side_values.value(1), 2);
        assert_eq!(trade_id_values.len(), 2);
        assert_eq!(trade_id_values.value(0), "1");
        assert_eq!(trade_id_values.value(1), "2");
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);

        let raw_price1 = (100.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (101.00 * FIXED_SCALAR) as PriceRaw;
        let price =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);

        let raw_size1 = (1000.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_size2 = (900.0 * FIXED_SCALAR) as QuantityRaw;
        let size =
            FixedSizeBinaryArray::from(vec![&raw_size1.to_le_bytes(), &raw_size2.to_le_bytes()]);

        // Discriminants as written by `encode_batch`: 1 = Buyer, 2 = Seller
        let aggressor_side = UInt8Array::from(vec![1, 2]);
        let trade_id = StringArray::from(vec!["1", "2"]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            TradeTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(price),
                Arc::new(size),
                Arc::new(aggressor_side),
                Arc::new(trade_id),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = TradeTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].price, Price::from_raw(raw_price1, 2));
        assert_eq!(decoded_data[1].price, Price::from_raw(raw_price2, 2));
        assert_eq!(decoded_data[0].size, Quantity::from_raw(raw_size1, 0));
        assert_eq!(decoded_data[1].size, Quantity::from_raw(raw_size2, 0));
        assert_eq!(decoded_data[0].aggressor_side, AggressorSide::Buyer);
        assert_eq!(decoded_data[1].aggressor_side, AggressorSide::Seller);
        assert_eq!(decoded_data[0].trade_id, TradeId::new("1"));
        assert_eq!(decoded_data[1].trade_id, TradeId::new("2"));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
        let data = sample_ticks(instrument_id);

        let record_batch = TradeTick::encode_batch(&metadata, &data).unwrap();
        let decoded_data = TradeTick::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data, data);
    }
}