1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{
20 FixedSizeBinaryArray, FixedSizeBinaryBuilder, StringArray, StringBuilder, StringViewArray,
21 UInt8Array, UInt64Array,
22 },
23 datatypes::{DataType, Field, Schema},
24 error::ArrowError,
25 record_batch::RecordBatch,
26};
27use nautilus_model::{
28 data::TradeTick,
29 enums::AggressorSide,
30 identifiers::{InstrumentId, TradeId},
31 types::{Price, Quantity, fixed::PRECISION_BYTES},
32};
33
34use super::{
35 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
36 KEY_SIZE_PRECISION, extract_column, get_raw_price, get_raw_quantity,
37};
38use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
39
40impl ArrowSchemaProvider for TradeTick {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45 Field::new("aggressor_side", DataType::UInt8, false),
46 Field::new("trade_id", DataType::Utf8, false),
47 Field::new("ts_event", DataType::UInt64, false),
48 Field::new("ts_init", DataType::UInt64, false),
49 ];
50
51 match metadata {
52 Some(metadata) => Schema::new_with_metadata(fields, metadata),
53 None => Schema::new(fields),
54 }
55 }
56}
57
58fn parse_metadata(
59 metadata: &HashMap<String, String>,
60) -> Result<(InstrumentId, u8, u8), EncodingError> {
61 let instrument_id_str = metadata
62 .get(KEY_INSTRUMENT_ID)
63 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
64 let instrument_id = InstrumentId::from_str(instrument_id_str)
65 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
66
67 let price_precision = metadata
68 .get(KEY_PRICE_PRECISION)
69 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
70 .parse::<u8>()
71 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
72
73 let size_precision = metadata
74 .get(KEY_SIZE_PRECISION)
75 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
76 .parse::<u8>()
77 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
78
79 Ok((instrument_id, price_precision, size_precision))
80}
81
82impl EncodeToRecordBatch for TradeTick {
83 fn encode_batch(
84 metadata: &HashMap<String, String>,
85 data: &[Self],
86 ) -> Result<RecordBatch, ArrowError> {
87 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89
90 let mut aggressor_side_builder = UInt8Array::builder(data.len());
91 let mut trade_id_builder = StringBuilder::new();
92 let mut ts_event_builder = UInt64Array::builder(data.len());
93 let mut ts_init_builder = UInt64Array::builder(data.len());
94
95 for tick in data {
96 price_builder
97 .append_value(tick.price.raw.to_le_bytes())
98 .unwrap();
99 size_builder
100 .append_value(tick.size.raw.to_le_bytes())
101 .unwrap();
102 aggressor_side_builder.append_value(tick.aggressor_side as u8);
103 trade_id_builder.append_value(tick.trade_id.to_string());
104 ts_event_builder.append_value(tick.ts_event.as_u64());
105 ts_init_builder.append_value(tick.ts_init.as_u64());
106 }
107
108 let price_array = Arc::new(price_builder.finish());
109 let size_array = Arc::new(size_builder.finish());
110 let aggressor_side_array = Arc::new(aggressor_side_builder.finish());
111 let trade_id_array = Arc::new(trade_id_builder.finish());
112 let ts_event_array = Arc::new(ts_event_builder.finish());
113 let ts_init_array = Arc::new(ts_init_builder.finish());
114
115 RecordBatch::try_new(
116 Self::get_schema(Some(metadata.clone())).into(),
117 vec![
118 price_array,
119 size_array,
120 aggressor_side_array,
121 trade_id_array,
122 ts_event_array,
123 ts_init_array,
124 ],
125 )
126 }
127
128 fn metadata(&self) -> HashMap<String, String> {
129 TradeTick::get_metadata(
130 &self.instrument_id,
131 self.price.precision,
132 self.size.precision,
133 )
134 }
135}
136
137impl DecodeFromRecordBatch for TradeTick {
138 fn decode_batch(
139 metadata: &HashMap<String, String>,
140 record_batch: RecordBatch,
141 ) -> Result<Vec<Self>, EncodingError> {
142 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
143 let cols = record_batch.columns();
144
145 let price_values = extract_column::<FixedSizeBinaryArray>(
146 cols,
147 "price",
148 0,
149 DataType::FixedSizeBinary(PRECISION_BYTES),
150 )?;
151
152 let size_values = extract_column::<FixedSizeBinaryArray>(
153 cols,
154 "size",
155 1,
156 DataType::FixedSizeBinary(PRECISION_BYTES),
157 )?;
158 let aggressor_side_values =
159 extract_column::<UInt8Array>(cols, "aggressor_side", 2, DataType::UInt8)?;
160 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
161 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
162
163 let trade_id_values: Vec<TradeId> = if record_batch
165 .schema()
166 .field_with_name("trade_id")?
167 .data_type()
168 == &DataType::Utf8View
169 {
170 extract_column::<StringViewArray>(cols, "trade_id", 3, DataType::Utf8View)?
171 .iter()
172 .map(|id| TradeId::from(id.unwrap()))
173 .collect()
174 } else {
175 extract_column::<StringArray>(cols, "trade_id", 3, DataType::Utf8)?
176 .iter()
177 .map(|id| TradeId::from(id.unwrap()))
178 .collect()
179 };
180
181 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182 .map(|i| {
183 let price = Price::from_raw(get_raw_price(price_values.value(i)), price_precision);
184
185 let size =
186 Quantity::from_raw(get_raw_quantity(size_values.value(i)), size_precision);
187 let aggressor_side_value = aggressor_side_values.value(i);
188 let aggressor_side = AggressorSide::from_repr(aggressor_side_value as usize)
189 .ok_or_else(|| {
190 EncodingError::ParseError(
191 stringify!(AggressorSide),
192 format!("Invalid enum value, was {aggressor_side_value}"),
193 )
194 })?;
195 let trade_id = trade_id_values[i];
196 let ts_event = ts_event_values.value(i).into();
197 let ts_init = ts_init_values.value(i).into();
198
199 Ok(Self {
200 instrument_id,
201 price,
202 size,
203 aggressor_side,
204 trade_id,
205 ts_event,
206 ts_init,
207 })
208 })
209 .collect();
210
211 result
212 }
213}
214
215impl DecodeDataFromRecordBatch for TradeTick {
216 fn decode_data_batch(
217 metadata: &HashMap<String, String>,
218 record_batch: RecordBatch,
219 ) -> Result<Vec<Data>, EncodingError> {
220 let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
221 Ok(ticks.into_iter().map(Data::from).collect())
222 }
223}
224
225#[cfg(test)]
229mod tests {
230 use std::sync::Arc;
231
232 use arrow::{
233 array::{Array, FixedSizeBinaryArray, UInt8Array, UInt64Array},
234 record_batch::RecordBatch,
235 };
236 use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
237 use rstest::rstest;
238
239 use super::*;
240 use crate::arrow::{get_raw_price, get_raw_quantity};
241
242 #[rstest]
243 fn test_get_schema() {
244 let instrument_id = InstrumentId::from("AAPL.XNAS");
245 let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
246 let schema = TradeTick::get_schema(Some(metadata.clone()));
247
248 let mut expected_fields = Vec::with_capacity(6);
249
250 expected_fields.push(Field::new(
251 "price",
252 DataType::FixedSizeBinary(PRECISION_BYTES),
253 false,
254 ));
255
256 expected_fields.extend(vec![
257 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
258 Field::new("aggressor_side", DataType::UInt8, false),
259 Field::new("trade_id", DataType::Utf8, false),
260 Field::new("ts_event", DataType::UInt64, false),
261 Field::new("ts_init", DataType::UInt64, false),
262 ]);
263
264 let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
265 assert_eq!(schema, expected_schema);
266 }
267
268 #[rstest]
269 fn test_get_schema_map() {
270 let schema_map = TradeTick::get_schema_map();
271 let mut expected_map = HashMap::new();
272
273 let precision_bytes = format!("FixedSizeBinary({PRECISION_BYTES})");
274 expected_map.insert("price".to_string(), precision_bytes.clone());
275 expected_map.insert("size".to_string(), precision_bytes);
276 expected_map.insert("aggressor_side".to_string(), "UInt8".to_string());
277 expected_map.insert("trade_id".to_string(), "Utf8".to_string());
278 expected_map.insert("ts_event".to_string(), "UInt64".to_string());
279 expected_map.insert("ts_init".to_string(), "UInt64".to_string());
280 assert_eq!(schema_map, expected_map);
281 }
282
283 #[rstest]
284 fn test_encode_trade_tick() {
285 let instrument_id = InstrumentId::from("AAPL.XNAS");
286 let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
287
288 let tick1 = TradeTick {
289 instrument_id,
290 price: Price::from("100.10"),
291 size: Quantity::from(1000),
292 aggressor_side: AggressorSide::Buyer,
293 trade_id: TradeId::new("1"),
294 ts_event: 1.into(),
295 ts_init: 3.into(),
296 };
297
298 let tick2 = TradeTick {
299 instrument_id,
300 price: Price::from("100.50"),
301 size: Quantity::from(500),
302 aggressor_side: AggressorSide::Seller,
303 trade_id: TradeId::new("2"),
304 ts_event: 2.into(),
305 ts_init: 4.into(),
306 };
307
308 let data = vec![tick1, tick2];
309 let record_batch = TradeTick::encode_batch(&metadata, &data).unwrap();
310 let columns = record_batch.columns();
311
312 let price_values = columns[0]
313 .as_any()
314 .downcast_ref::<FixedSizeBinaryArray>()
315 .unwrap();
316 assert_eq!(
317 get_raw_price(price_values.value(0)),
318 (100.10 * FIXED_SCALAR) as PriceRaw
319 );
320 assert_eq!(
321 get_raw_price(price_values.value(1)),
322 (100.50 * FIXED_SCALAR) as PriceRaw
323 );
324
325 let size_values = columns[1]
326 .as_any()
327 .downcast_ref::<FixedSizeBinaryArray>()
328 .unwrap();
329 assert_eq!(
330 get_raw_quantity(size_values.value(0)),
331 (1000.0 * FIXED_SCALAR) as QuantityRaw
332 );
333 assert_eq!(
334 get_raw_quantity(size_values.value(1)),
335 (500.0 * FIXED_SCALAR) as QuantityRaw
336 );
337
338 let aggressor_side_values = columns[2].as_any().downcast_ref::<UInt8Array>().unwrap();
339 let trade_id_values = columns[3].as_any().downcast_ref::<StringArray>().unwrap();
340 let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
341 let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
342
343 assert_eq!(columns.len(), 6);
344 assert_eq!(size_values.len(), 2);
345 assert_eq!(
346 get_raw_quantity(size_values.value(0)),
347 (1000.0 * FIXED_SCALAR) as QuantityRaw
348 );
349 assert_eq!(
350 get_raw_quantity(size_values.value(1)),
351 (500.0 * FIXED_SCALAR) as QuantityRaw
352 );
353 assert_eq!(aggressor_side_values.len(), 2);
354 assert_eq!(aggressor_side_values.value(0), 1);
355 assert_eq!(aggressor_side_values.value(1), 2);
356 assert_eq!(trade_id_values.len(), 2);
357 assert_eq!(trade_id_values.value(0), "1");
358 assert_eq!(trade_id_values.value(1), "2");
359 assert_eq!(ts_event_values.len(), 2);
360 assert_eq!(ts_event_values.value(0), 1);
361 assert_eq!(ts_event_values.value(1), 2);
362 assert_eq!(ts_init_values.len(), 2);
363 assert_eq!(ts_init_values.value(0), 3);
364 assert_eq!(ts_init_values.value(1), 4);
365 }
366
367 #[rstest]
368 fn test_decode_batch() {
369 let instrument_id = InstrumentId::from("AAPL.XNAS");
370 let metadata = TradeTick::get_metadata(&instrument_id, 2, 0);
371
372 let price = FixedSizeBinaryArray::from(vec![
373 &(1_000_000_000_000 as PriceRaw).to_le_bytes(),
374 &(1_010_000_000_000 as PriceRaw).to_le_bytes(),
375 ]);
376
377 let size = FixedSizeBinaryArray::from(vec![
378 &((1000.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
379 &((900.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
380 ]);
381 let aggressor_side = UInt8Array::from(vec![0, 1]); let trade_id = StringArray::from(vec!["1", "2"]);
383 let ts_event = UInt64Array::from(vec![1, 2]);
384 let ts_init = UInt64Array::from(vec![3, 4]);
385
386 let record_batch = RecordBatch::try_new(
387 TradeTick::get_schema(Some(metadata.clone())).into(),
388 vec![
389 Arc::new(price),
390 Arc::new(size),
391 Arc::new(aggressor_side),
392 Arc::new(trade_id),
393 Arc::new(ts_event),
394 Arc::new(ts_init),
395 ],
396 )
397 .unwrap();
398
399 let decoded_data = TradeTick::decode_batch(&metadata, record_batch).unwrap();
400 assert_eq!(decoded_data.len(), 2);
401 assert_eq!(decoded_data[0].price, Price::from_raw(1_000_000_000_000, 2));
402 assert_eq!(decoded_data[1].price, Price::from_raw(1_010_000_000_000, 2));
403 }
404}