1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{
20 FixedSizeBinaryArray, FixedSizeBinaryBuilder, StringArray, StringBuilder, StringViewArray,
21 UInt8Array, UInt64Array,
22 },
23 datatypes::{DataType, Field, Schema},
24 error::ArrowError,
25 record_batch::RecordBatch,
26};
27use nautilus_model::{
28 data::TradeTick,
29 enums::AggressorSide,
30 identifiers::{InstrumentId, TradeId},
31 types::{Price, Quantity, fixed::PRECISION_BYTES},
32};
33
34use super::{
35 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
36 KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
37};
38use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
39
40impl ArrowSchemaProvider for TradeTick {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45 Field::new("aggressor_side", DataType::UInt8, false),
46 Field::new("trade_id", DataType::Utf8, false),
47 Field::new("ts_event", DataType::UInt64, false),
48 Field::new("ts_init", DataType::UInt64, false),
49 ];
50
51 match metadata {
52 Some(metadata) => Schema::new_with_metadata(fields, metadata),
53 None => Schema::new(fields),
54 }
55 }
56}
57
58fn parse_metadata(
59 metadata: &HashMap<String, String>,
60) -> Result<(InstrumentId, u8, u8), EncodingError> {
61 let instrument_id_str = metadata
62 .get(KEY_INSTRUMENT_ID)
63 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
64 let instrument_id = InstrumentId::from_str(instrument_id_str)
65 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
66
67 let price_precision = metadata
68 .get(KEY_PRICE_PRECISION)
69 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
70 .parse::<u8>()
71 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
72
73 let size_precision = metadata
74 .get(KEY_SIZE_PRECISION)
75 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
76 .parse::<u8>()
77 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
78
79 Ok((instrument_id, price_precision, size_precision))
80}
81
82impl EncodeToRecordBatch for TradeTick {
83 fn encode_batch(
84 metadata: &HashMap<String, String>,
85 data: &[Self],
86 ) -> Result<RecordBatch, ArrowError> {
87 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89
90 let mut aggressor_side_builder = UInt8Array::builder(data.len());
91 let mut trade_id_builder = StringBuilder::new();
92 let mut ts_event_builder = UInt64Array::builder(data.len());
93 let mut ts_init_builder = UInt64Array::builder(data.len());
94
95 for tick in data {
96 price_builder
97 .append_value(tick.price.raw.to_le_bytes())
98 .unwrap();
99 size_builder
100 .append_value(tick.size.raw.to_le_bytes())
101 .unwrap();
102 aggressor_side_builder.append_value(tick.aggressor_side as u8);
103 trade_id_builder.append_value(tick.trade_id.to_string());
104 ts_event_builder.append_value(tick.ts_event.as_u64());
105 ts_init_builder.append_value(tick.ts_init.as_u64());
106 }
107
108 let price_array = Arc::new(price_builder.finish());
109 let size_array = Arc::new(size_builder.finish());
110 let aggressor_side_array = Arc::new(aggressor_side_builder.finish());
111 let trade_id_array = Arc::new(trade_id_builder.finish());
112 let ts_event_array = Arc::new(ts_event_builder.finish());
113 let ts_init_array = Arc::new(ts_init_builder.finish());
114
115 RecordBatch::try_new(
116 Self::get_schema(Some(metadata.clone())).into(),
117 vec![
118 price_array,
119 size_array,
120 aggressor_side_array,
121 trade_id_array,
122 ts_event_array,
123 ts_init_array,
124 ],
125 )
126 }
127
128 fn metadata(&self) -> HashMap<String, String> {
129 Self::get_metadata(
130 &self.instrument_id,
131 self.price.precision,
132 self.size.precision,
133 )
134 }
135}
136
137impl DecodeFromRecordBatch for TradeTick {
138 fn decode_batch(
139 metadata: &HashMap<String, String>,
140 record_batch: RecordBatch,
141 ) -> Result<Vec<Self>, EncodingError> {
142 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
143 let cols = record_batch.columns();
144
145 let price_values = extract_column::<FixedSizeBinaryArray>(
146 cols,
147 "price",
148 0,
149 DataType::FixedSizeBinary(PRECISION_BYTES),
150 )?;
151
152 let size_values = extract_column::<FixedSizeBinaryArray>(
153 cols,
154 "size",
155 1,
156 DataType::FixedSizeBinary(PRECISION_BYTES),
157 )?;
158 let aggressor_side_values =
159 extract_column::<UInt8Array>(cols, "aggressor_side", 2, DataType::UInt8)?;
160 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
161 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
162
163 let trade_id_values: Vec<TradeId> = if record_batch
165 .schema()
166 .field_with_name("trade_id")?
167 .data_type()
168 == &DataType::Utf8View
169 {
170 extract_column::<StringViewArray>(cols, "trade_id", 3, DataType::Utf8View)?
171 .iter()
172 .map(|id| TradeId::from(id.unwrap()))
173 .collect()
174 } else {
175 extract_column::<StringArray>(cols, "trade_id", 3, DataType::Utf8)?
176 .iter()
177 .map(|id| TradeId::from(id.unwrap()))
178 .collect()
179 };
180
181 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182 .map(|i| {
183 let price = Price::from_raw(
185 get_corrected_raw_price(price_values.value(i), price_precision),
186 price_precision,
187 );
188 let size = Quantity::from_raw(
189 get_corrected_raw_quantity(size_values.value(i), size_precision),
190 size_precision,
191 );
192 let aggressor_side_value = aggressor_side_values.value(i);
193 let aggressor_side = AggressorSide::from_repr(aggressor_side_value as usize)
194 .ok_or_else(|| {
195 EncodingError::ParseError(
196 stringify!(AggressorSide),
197 format!("Invalid enum value, was {aggressor_side_value}"),
198 )
199 })?;
200 let trade_id = trade_id_values[i];
201 let ts_event = ts_event_values.value(i).into();
202 let ts_init = ts_init_values.value(i).into();
203
204 Ok(Self {
205 instrument_id,
206 price,
207 size,
208 aggressor_side,
209 trade_id,
210 ts_event,
211 ts_init,
212 })
213 })
214 .collect();
215
216 result
217 }
218}
219
220impl DecodeDataFromRecordBatch for TradeTick {
221 fn decode_data_batch(
222 metadata: &HashMap<String, String>,
223 record_batch: RecordBatch,
224 ) -> Result<Vec<Data>, EncodingError> {
225 let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
226 Ok(ticks.into_iter().map(Data::from).collect())
227 }
228}
229
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Array, FixedSizeBinaryArray, UInt8Array, UInt64Array},
        record_batch::RecordBatch,
    };
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// The schema must list the six tick columns in order and carry the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
        let schema = TradeTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("aggressor_side", DataType::UInt8, false),
            Field::new("trade_id", DataType::Utf8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must pair every column name with its data type rendered as a string.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = TradeTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let precision_bytes = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("price".to_string(), precision_bytes.clone());
        expected_map.insert("size".to_string(), precision_bytes);
        expected_map.insert("aggressor_side".to_string(), "UInt8".to_string());
        expected_map.insert("trade_id".to_string(), "Utf8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two ticks must produce six columns holding the tick values row by row.
    #[rstest]
    fn test_encode_trade_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);

        let tick1 = TradeTick {
            instrument_id,
            price: Price::from("100.10"),
            size: Quantity::from(1000),
            aggressor_side: AggressorSide::Buyer,
            trade_id: TradeId::new("1"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let tick2 = TradeTick {
            instrument_id,
            price: Price::from("100.50"),
            size: Quantity::from(500),
            aggressor_side: AggressorSide::Seller,
            trade_id: TradeId::new("2"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![tick1, tick2];
        let record_batch = TradeTick::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        // Check the column count before indexing so a mismatch fails as an assertion.
        assert_eq!(columns.len(), 6);

        let price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let aggressor_side_values = columns[2].as_any().downcast_ref::<UInt8Array>().unwrap();
        let trade_id_values = columns[3].as_any().downcast_ref::<StringArray>().unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (100.50 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(size_values.value(1)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );

        assert_eq!(aggressor_side_values.len(), 2);
        assert_eq!(aggressor_side_values.value(0), 1);
        assert_eq!(aggressor_side_values.value(1), 2);
        assert_eq!(trade_id_values.len(), 2);
        assert_eq!(trade_id_values.value(0), "1");
        assert_eq!(trade_id_values.value(1), "2");
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch must reproduce every tick field, not only the price.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);

        let raw_price1 = (100.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (101.00 * FIXED_SCALAR) as PriceRaw;
        let price =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);

        let raw_size1 = (1000.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_size2 = (900.0 * FIXED_SCALAR) as QuantityRaw;
        let size =
            FixedSizeBinaryArray::from(vec![&raw_size1.to_le_bytes(), &raw_size2.to_le_bytes()]);

        let aggressor_side = UInt8Array::from(vec![0, 1]);
        let trade_id = StringArray::from(vec!["1", "2"]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            TradeTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(price),
                Arc::new(size),
                Arc::new(aggressor_side),
                Arc::new(trade_id),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = TradeTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].price, Price::from_raw(raw_price1, 2));
        assert_eq!(decoded_data[1].price, Price::from_raw(raw_price2, 2));
        assert_eq!(decoded_data[0].size, Quantity::from_raw(raw_size1, 0));
        assert_eq!(decoded_data[1].size, Quantity::from_raw(raw_size2, 0));
        // The encode test establishes that `Buyer` is stored as 1.
        assert_eq!(decoded_data[1].aggressor_side, AggressorSide::Buyer);
        assert_eq!(decoded_data[0].trade_id, TradeId::new("1"));
        assert_eq!(decoded_data[1].trade_id, TradeId::new("2"));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }
}