nautilus_serialization/arrow/
trade.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{
20        FixedSizeBinaryArray, FixedSizeBinaryBuilder, StringArray, StringBuilder, StringViewArray,
21        UInt8Array, UInt64Array,
22    },
23    datatypes::{DataType, Field, Schema},
24    error::ArrowError,
25    record_batch::RecordBatch,
26};
27use nautilus_model::{
28    data::TradeTick,
29    enums::AggressorSide,
30    identifiers::{InstrumentId, TradeId},
31    types::{Price, Quantity, fixed::PRECISION_BYTES},
32};
33
34use super::{
35    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
36    KEY_SIZE_PRECISION, extract_column, get_raw_price, get_raw_quantity,
37};
38use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
39
40impl ArrowSchemaProvider for TradeTick {
41    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42        let fields = vec![
43            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45            Field::new("aggressor_side", DataType::UInt8, false),
46            Field::new("trade_id", DataType::Utf8, false),
47            Field::new("ts_event", DataType::UInt64, false),
48            Field::new("ts_init", DataType::UInt64, false),
49        ];
50
51        match metadata {
52            Some(metadata) => Schema::new_with_metadata(fields, metadata),
53            None => Schema::new(fields),
54        }
55    }
56}
57
58fn parse_metadata(
59    metadata: &HashMap<String, String>,
60) -> Result<(InstrumentId, u8, u8), EncodingError> {
61    let instrument_id_str = metadata
62        .get(KEY_INSTRUMENT_ID)
63        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
64    let instrument_id = InstrumentId::from_str(instrument_id_str)
65        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
66
67    let price_precision = metadata
68        .get(KEY_PRICE_PRECISION)
69        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
70        .parse::<u8>()
71        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
72
73    let size_precision = metadata
74        .get(KEY_SIZE_PRECISION)
75        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
76        .parse::<u8>()
77        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
78
79    Ok((instrument_id, price_precision, size_precision))
80}
81
82impl EncodeToRecordBatch for TradeTick {
83    fn encode_batch(
84        metadata: &HashMap<String, String>,
85        data: &[Self],
86    ) -> Result<RecordBatch, ArrowError> {
87        let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88        let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89
90        let mut aggressor_side_builder = UInt8Array::builder(data.len());
91        let mut trade_id_builder = StringBuilder::new();
92        let mut ts_event_builder = UInt64Array::builder(data.len());
93        let mut ts_init_builder = UInt64Array::builder(data.len());
94
95        for tick in data {
96            price_builder
97                .append_value(tick.price.raw.to_le_bytes())
98                .unwrap();
99            size_builder
100                .append_value(tick.size.raw.to_le_bytes())
101                .unwrap();
102            aggressor_side_builder.append_value(tick.aggressor_side as u8);
103            trade_id_builder.append_value(tick.trade_id.to_string());
104            ts_event_builder.append_value(tick.ts_event.as_u64());
105            ts_init_builder.append_value(tick.ts_init.as_u64());
106        }
107
108        let price_array = Arc::new(price_builder.finish());
109        let size_array = Arc::new(size_builder.finish());
110        let aggressor_side_array = Arc::new(aggressor_side_builder.finish());
111        let trade_id_array = Arc::new(trade_id_builder.finish());
112        let ts_event_array = Arc::new(ts_event_builder.finish());
113        let ts_init_array = Arc::new(ts_init_builder.finish());
114
115        RecordBatch::try_new(
116            Self::get_schema(Some(metadata.clone())).into(),
117            vec![
118                price_array,
119                size_array,
120                aggressor_side_array,
121                trade_id_array,
122                ts_event_array,
123                ts_init_array,
124            ],
125        )
126    }
127
128    fn metadata(&self) -> HashMap<String, String> {
129        TradeTick::get_metadata(
130            &self.instrument_id,
131            self.price.precision,
132            self.size.precision,
133        )
134    }
135}
136
137impl DecodeFromRecordBatch for TradeTick {
138    fn decode_batch(
139        metadata: &HashMap<String, String>,
140        record_batch: RecordBatch,
141    ) -> Result<Vec<Self>, EncodingError> {
142        let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
143        let cols = record_batch.columns();
144
145        let price_values = extract_column::<FixedSizeBinaryArray>(
146            cols,
147            "price",
148            0,
149            DataType::FixedSizeBinary(PRECISION_BYTES),
150        )?;
151
152        let size_values = extract_column::<FixedSizeBinaryArray>(
153            cols,
154            "size",
155            1,
156            DataType::FixedSizeBinary(PRECISION_BYTES),
157        )?;
158        let aggressor_side_values =
159            extract_column::<UInt8Array>(cols, "aggressor_side", 2, DataType::UInt8)?;
160        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
161        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
162
163        // Datafusion reads trade_ids as StringView
164        let trade_id_values: Vec<TradeId> = if record_batch
165            .schema()
166            .field_with_name("trade_id")?
167            .data_type()
168            == &DataType::Utf8View
169        {
170            extract_column::<StringViewArray>(cols, "trade_id", 3, DataType::Utf8View)?
171                .iter()
172                .map(|id| TradeId::from(id.unwrap()))
173                .collect()
174        } else {
175            extract_column::<StringArray>(cols, "trade_id", 3, DataType::Utf8)?
176                .iter()
177                .map(|id| TradeId::from(id.unwrap()))
178                .collect()
179        };
180
181        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182            .map(|i| {
183                let price = Price::from_raw(get_raw_price(price_values.value(i)), price_precision);
184
185                let size =
186                    Quantity::from_raw(get_raw_quantity(size_values.value(i)), size_precision);
187                let aggressor_side_value = aggressor_side_values.value(i);
188                let aggressor_side = AggressorSide::from_repr(aggressor_side_value as usize)
189                    .ok_or_else(|| {
190                        EncodingError::ParseError(
191                            stringify!(AggressorSide),
192                            format!("Invalid enum value, was {aggressor_side_value}"),
193                        )
194                    })?;
195                let trade_id = trade_id_values[i];
196                let ts_event = ts_event_values.value(i).into();
197                let ts_init = ts_init_values.value(i).into();
198
199                Ok(Self {
200                    instrument_id,
201                    price,
202                    size,
203                    aggressor_side,
204                    trade_id,
205                    ts_event,
206                    ts_init,
207                })
208            })
209            .collect();
210
211        result
212    }
213}
214
215impl DecodeDataFromRecordBatch for TradeTick {
216    fn decode_data_batch(
217        metadata: &HashMap<String, String>,
218        record_batch: RecordBatch,
219    ) -> Result<Vec<Data>, EncodingError> {
220        let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
221        Ok(ticks.into_iter().map(Data::from).collect())
222    }
223}
224
225////////////////////////////////////////////////////////////////////////////////
226// Tests
227////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Array, FixedSizeBinaryArray, UInt8Array, UInt64Array},
        record_batch::RecordBatch,
    };
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Downcasts column `index` of `batch` to the concrete array type `T`,
    /// panicking if the column is missing or has a different type.
    fn typed_column<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch.columns()[index]
            .as_any()
            .downcast_ref::<T>()
            .unwrap()
    }

    /// The schema must expose the six trade columns and carry the given metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
        let schema = TradeTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("aggressor_side", DataType::UInt8, false),
            Field::new("trade_id", DataType::Utf8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must pair every column name with its Arrow type name.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = TradeTick::get_schema_map();

        let fixed_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        let expected_map: HashMap<String, String> = [
            ("price", fixed_binary.as_str()),
            ("size", fixed_binary.as_str()),
            ("aggressor_side", "UInt8"),
            ("trade_id", "Utf8"),
            ("ts_event", "UInt64"),
            ("ts_init", "UInt64"),
        ]
        .into_iter()
        .map(|(column, type_name)| (column.to_string(), type_name.to_string()))
        .collect();

        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two ticks must produce six columns holding the expected raw values.
    #[rstest]
    fn test_encode_trade_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);

        let data = vec![
            TradeTick {
                instrument_id,
                price: Price::from("100.10"),
                size: Quantity::from(1000),
                aggressor_side: AggressorSide::Buyer,
                trade_id: TradeId::new("1"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            TradeTick {
                instrument_id,
                price: Price::from("100.50"),
                size: Quantity::from(500),
                aggressor_side: AggressorSide::Seller,
                trade_id: TradeId::new("2"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = TradeTick::encode_batch(&metadata, &data).unwrap();

        let price_col = typed_column::<FixedSizeBinaryArray>(&record_batch, 0);
        assert_eq!(
            get_raw_price(price_col.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_col.value(1)),
            (100.50 * FIXED_SCALAR) as PriceRaw
        );

        let size_col = typed_column::<FixedSizeBinaryArray>(&record_batch, 1);
        assert_eq!(
            get_raw_quantity(size_col.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(size_col.value(1)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );

        let side_col = typed_column::<UInt8Array>(&record_batch, 2);
        let trade_id_col = typed_column::<StringArray>(&record_batch, 3);
        let ts_event_col = typed_column::<UInt64Array>(&record_batch, 4);
        let ts_init_col = typed_column::<UInt64Array>(&record_batch, 5);

        assert_eq!(record_batch.columns().len(), 6);

        assert_eq!(size_col.len(), 2);
        assert_eq!(
            get_raw_quantity(size_col.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(size_col.value(1)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );

        assert_eq!(side_col.len(), 2);
        assert_eq!(side_col.value(0), 1);
        assert_eq!(side_col.value(1), 2);

        assert_eq!(trade_id_col.len(), 2);
        assert_eq!(trade_id_col.value(0), "1");
        assert_eq!(trade_id_col.value(1), "2");

        assert_eq!(ts_event_col.len(), 2);
        assert_eq!(ts_event_col.value(0), 1);
        assert_eq!(ts_event_col.value(1), 2);

        assert_eq!(ts_init_col.len(), 2);
        assert_eq!(ts_init_col.value(0), 3);
        assert_eq!(ts_init_col.value(1), 4);
    }

    /// Decoding a hand-built batch must yield one tick per row with the raw prices intact.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);

        let first_price = (1_000_000_000_000 as PriceRaw).to_le_bytes();
        let second_price = (1_010_000_000_000 as PriceRaw).to_le_bytes();
        let price = FixedSizeBinaryArray::from(vec![&first_price, &second_price]);

        let first_size = ((1000.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes();
        let second_size = ((900.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes();
        let size = FixedSizeBinaryArray::from(vec![&first_size, &second_size]);

        // Raw `AggressorSide` values. The encode test in this module shows Buyer
        // encodes as 1 and Seller as 2, so these rows are not BUY/SELL; 0 is
        // presumably the no-aggressor variant (TODO confirm against the enum).
        let aggressor_side = UInt8Array::from(vec![0, 1]);
        let trade_id = StringArray::from(vec!["1", "2"]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            TradeTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(price),
                Arc::new(size),
                Arc::new(aggressor_side),
                Arc::new(trade_id),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let ticks = TradeTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(ticks.len(), 2);
        assert_eq!(ticks[0].price, Price::from_raw(1_000_000_000_000, 2));
        assert_eq!(ticks[1].price, Price::from_raw(1_010_000_000_000, 2));
    }
}