1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{BookOrder, OrderBookDelta},
26 enums::{BookAction, FromU8, OrderSide},
27 identifiers::InstrumentId,
28 types::{
29 Price, Quantity, fixed::PRECISION_BYTES, price::PRICE_UNDEF, quantity::QUANTITY_UNDEF,
30 },
31};
32
33use super::{
34 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
35 KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
36 get_raw_price, get_raw_quantity,
37};
38use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
39
40impl ArrowSchemaProvider for OrderBookDelta {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("action", DataType::UInt8, false),
44 Field::new("side", DataType::UInt8, false),
45 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
46 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
47 Field::new("order_id", DataType::UInt64, false),
48 Field::new("flags", DataType::UInt8, false),
49 Field::new("sequence", DataType::UInt64, false),
50 Field::new("ts_event", DataType::UInt64, false),
51 Field::new("ts_init", DataType::UInt64, false),
52 ];
53
54 match metadata {
55 Some(metadata) => Schema::new_with_metadata(fields, metadata),
56 None => Schema::new(fields),
57 }
58 }
59}
60
61fn parse_metadata(
62 metadata: &HashMap<String, String>,
63) -> Result<(InstrumentId, u8, u8), EncodingError> {
64 let instrument_id_str = metadata
65 .get(KEY_INSTRUMENT_ID)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
67 let instrument_id = InstrumentId::from_str(instrument_id_str)
68 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
69
70 let price_precision = metadata
71 .get(KEY_PRICE_PRECISION)
72 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
73 .parse::<u8>()
74 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
75
76 let size_precision = metadata
77 .get(KEY_SIZE_PRECISION)
78 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
79 .parse::<u8>()
80 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
81
82 Ok((instrument_id, price_precision, size_precision))
83}
84
85impl EncodeToRecordBatch for OrderBookDelta {
86 fn encode_batch(
87 metadata: &HashMap<String, String>,
88 data: &[Self],
89 ) -> Result<RecordBatch, ArrowError> {
90 let mut action_builder = UInt8Array::builder(data.len());
91 let mut side_builder = UInt8Array::builder(data.len());
92 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
94 let mut order_id_builder = UInt64Array::builder(data.len());
95 let mut flags_builder = UInt8Array::builder(data.len());
96 let mut sequence_builder = UInt64Array::builder(data.len());
97 let mut ts_event_builder = UInt64Array::builder(data.len());
98 let mut ts_init_builder = UInt64Array::builder(data.len());
99
100 for delta in data {
101 action_builder.append_value(delta.action as u8);
102 side_builder.append_value(delta.order.side as u8);
103 price_builder
104 .append_value(delta.order.price.raw.to_le_bytes())
105 .unwrap();
106 size_builder
107 .append_value(delta.order.size.raw.to_le_bytes())
108 .unwrap();
109 order_id_builder.append_value(delta.order.order_id);
110 flags_builder.append_value(delta.flags);
111 sequence_builder.append_value(delta.sequence);
112 ts_event_builder.append_value(delta.ts_event.as_u64());
113 ts_init_builder.append_value(delta.ts_init.as_u64());
114 }
115
116 let action_array = action_builder.finish();
117 let side_array = side_builder.finish();
118 let price_array = price_builder.finish();
119 let size_array = size_builder.finish();
120 let order_id_array = order_id_builder.finish();
121 let flags_array = flags_builder.finish();
122 let sequence_array = sequence_builder.finish();
123 let ts_event_array = ts_event_builder.finish();
124 let ts_init_array = ts_init_builder.finish();
125
126 RecordBatch::try_new(
127 Self::get_schema(Some(metadata.clone())).into(),
128 vec![
129 Arc::new(action_array),
130 Arc::new(side_array),
131 Arc::new(price_array),
132 Arc::new(size_array),
133 Arc::new(order_id_array),
134 Arc::new(flags_array),
135 Arc::new(sequence_array),
136 Arc::new(ts_event_array),
137 Arc::new(ts_init_array),
138 ],
139 )
140 }
141
142 fn metadata(&self) -> HashMap<String, String> {
143 Self::get_metadata(
144 &self.instrument_id,
145 self.order.price.precision,
146 self.order.size.precision,
147 )
148 }
149
150 fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
154 let delta = chunk
155 .first()
156 .expect("Chunk should have at least one element to encode");
157
158 if delta.order.price.precision == 0
159 && delta.order.size.precision == 0
160 && let Some(delta) = chunk.get(1)
161 {
162 return EncodeToRecordBatch::metadata(delta);
163 }
164
165 EncodeToRecordBatch::metadata(delta)
166 }
167}
168
169impl DecodeFromRecordBatch for OrderBookDelta {
170 fn decode_batch(
171 metadata: &HashMap<String, String>,
172 record_batch: RecordBatch,
173 ) -> Result<Vec<Self>, EncodingError> {
174 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
175 let cols = record_batch.columns();
176
177 let action_values = extract_column::<UInt8Array>(cols, "action", 0, DataType::UInt8)?;
178 let side_values = extract_column::<UInt8Array>(cols, "side", 1, DataType::UInt8)?;
179 let price_values = extract_column::<FixedSizeBinaryArray>(
180 cols,
181 "price",
182 2,
183 DataType::FixedSizeBinary(PRECISION_BYTES),
184 )?;
185 let size_values = extract_column::<FixedSizeBinaryArray>(
186 cols,
187 "size",
188 3,
189 DataType::FixedSizeBinary(PRECISION_BYTES),
190 )?;
191 let order_id_values = extract_column::<UInt64Array>(cols, "order_id", 4, DataType::UInt64)?;
192 let flags_values = extract_column::<UInt8Array>(cols, "flags", 5, DataType::UInt8)?;
193 let sequence_values = extract_column::<UInt64Array>(cols, "sequence", 6, DataType::UInt64)?;
194 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
195 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 8, DataType::UInt64)?;
196
197 if price_values.value_length() != PRECISION_BYTES {
198 return Err(EncodingError::ParseError(
199 "price",
200 format!(
201 "Invalid value length: expected {PRECISION_BYTES}, found {}",
202 price_values.value_length()
203 ),
204 ));
205 }
206
207 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
208 .map(|i| {
209 let action_value = action_values.value(i);
210 let action = BookAction::from_u8(action_value).ok_or_else(|| {
211 EncodingError::ParseError(
212 stringify!(BookAction),
213 format!("Invalid enum value, was {action_value}"),
214 )
215 })?;
216 let side_value = side_values.value(i);
217 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
218 EncodingError::ParseError(
219 stringify!(OrderSide),
220 format!("Invalid enum value, was {side_value}"),
221 )
222 })?;
223 let raw_price = get_raw_price(price_values.value(i));
226 let price = if raw_price == PRICE_UNDEF {
227 Price::from_raw(raw_price, 0)
228 } else {
229 Price::from_raw(
230 get_corrected_raw_price(price_values.value(i), price_precision),
231 price_precision,
232 )
233 };
234
235 let raw_size = get_raw_quantity(size_values.value(i));
236 let size = if raw_size == QUANTITY_UNDEF {
237 Quantity::from_raw(raw_size, 0)
238 } else {
239 Quantity::from_raw(
240 get_corrected_raw_quantity(size_values.value(i), size_precision),
241 size_precision,
242 )
243 };
244 let order_id = order_id_values.value(i);
245 let flags = flags_values.value(i);
246 let sequence = sequence_values.value(i);
247 let ts_event = ts_event_values.value(i).into();
248 let ts_init = ts_init_values.value(i).into();
249
250 Ok(Self {
251 instrument_id,
252 action,
253 order: BookOrder {
254 side,
255 price,
256 size,
257 order_id,
258 },
259 flags,
260 sequence,
261 ts_event,
262 ts_init,
263 })
264 })
265 .collect();
266
267 result
268 }
269}
270
271impl DecodeDataFromRecordBatch for OrderBookDelta {
272 fn decode_data_batch(
273 metadata: &HashMap<String, String>,
274 record_batch: RecordBatch,
275 ) -> Result<Vec<Data>, EncodingError> {
276 let deltas: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
277 Ok(deltas.into_iter().map(Data::from).collect())
278 }
279}
280
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{
        fixed::FIXED_SCALAR,
        price::{PRICE_UNDEF, PriceRaw},
        quantity::{QUANTITY_UNDEF, QuantityRaw},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// The schema exposes the nine expected fields and carries the metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDelta::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("action", DataType::UInt8, false),
            Field::new("side", DataType::UInt8, false),
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("order_id", DataType::UInt64, false),
            Field::new("flags", DataType::UInt8, false),
            Field::new("sequence", DataType::UInt64, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map reports the Arrow type name for every column.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDelta::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        assert_eq!(schema_map.get("action").unwrap(), "UInt8");
        assert_eq!(schema_map.get("side").unwrap(), "UInt8");
        assert_eq!(*schema_map.get("price").unwrap(), fixed_size_binary);
        assert_eq!(*schema_map.get("size").unwrap(), fixed_size_binary);
        assert_eq!(schema_map.get("order_id").unwrap(), "UInt64");
        assert_eq!(schema_map.get("flags").unwrap(), "UInt8");
        assert_eq!(schema_map.get("sequence").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_event").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_init").unwrap(), "UInt64");
    }

    /// Two deltas encode into nine columns holding the expected raw values.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let delta1 = OrderBookDelta {
            instrument_id,
            action: BookAction::Add,
            order: BookOrder {
                side: OrderSide::Buy,
                price: Price::from("100.10"),
                size: Quantity::from(100),
                order_id: 1,
            },
            flags: 0,
            sequence: 1,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let delta2 = OrderBookDelta {
            instrument_id,
            action: BookAction::Update,
            order: BookOrder {
                side: OrderSide::Sell,
                price: Price::from("101.20"),
                size: Quantity::from(200),
                order_id: 2,
            },
            flags: 1,
            sequence: 2,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![delta1, delta2];
        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let action_values = columns[0].as_any().downcast_ref::<UInt8Array>().unwrap();
        let side_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let price_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let order_id_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let flags_values = columns[5].as_any().downcast_ref::<UInt8Array>().unwrap();
        let sequence_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_event_values = columns[7].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[8].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 9);
        assert_eq!(action_values.len(), 2);
        assert_eq!(action_values.value(0), 1);
        assert_eq!(action_values.value(1), 2);
        assert_eq!(side_values.len(), 2);
        assert_eq!(side_values.value(0), 1);
        assert_eq!(side_values.value(1), 2);

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (101.20 * FIXED_SCALAR) as PriceRaw
        );

        // Sizes are quantities, so read them back with the quantity helper
        // and compare against the quantity raw type.
        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(size_values.value(0)),
            (100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(size_values.value(1)),
            (200.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(order_id_values.len(), 2);
        assert_eq!(order_id_values.value(0), 1);
        assert_eq!(order_id_values.value(1), 2);
        assert_eq!(flags_values.len(), 2);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(flags_values.value(1), 1);
        assert_eq!(sequence_values.len(), 2);
        assert_eq!(sequence_values.value(0), 1);
        assert_eq!(sequence_values.value(1), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// A hand-built batch decodes into deltas carrying the column values and
    /// the precisions taken from the metadata.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let action = UInt8Array::from(vec![1, 2]);
        let side = UInt8Array::from(vec![1, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &((101.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((101.20 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &((10000.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((9000.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![1, 2]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Instrument ID and precisions come from the metadata.
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].order.price.precision, 2);
        assert_eq!(decoded_data[1].order.price.precision, 2);
        assert_eq!(decoded_data[0].order.size.precision, 0);
        assert_eq!(decoded_data[1].order.size.precision, 0);

        // Everything else comes straight from the columns.
        assert_eq!(decoded_data[0].action, BookAction::Add);
        assert_eq!(decoded_data[1].action, BookAction::Update);
        assert_eq!(decoded_data[0].order.side, OrderSide::Buy);
        assert_eq!(decoded_data[1].order.side, OrderSide::Buy);
        assert_eq!(decoded_data[0].order.order_id, 1);
        assert_eq!(decoded_data[1].order.order_id, 2);
        assert_eq!(decoded_data[0].flags, 0);
        assert_eq!(decoded_data[1].flags, 0);
        assert_eq!(decoded_data[0].sequence, 1);
        assert_eq!(decoded_data[1].sequence, 2);
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    /// Undefined price/size sentinels are preserved with precision 0, while
    /// regular rows in the same batch take the metadata precisions.
    #[rstest]
    fn test_decode_batch_with_undef_values() {
        let instrument_id = InstrumentId::from("PLTR.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        // NOTE(review): presumably 4 = Clear and side 0 = no order side for
        // the first row - confirm against the enum definitions.
        let action = UInt8Array::from(vec![4, 1]);
        let side = UInt8Array::from(vec![0, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &PRICE_UNDEF.to_le_bytes(),
            &((100.50 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &QUANTITY_UNDEF.to_le_bytes(),
            &((1000.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![0, 1]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].order.price.raw, PRICE_UNDEF);
        assert_eq!(decoded_data[0].order.price.precision, 0);
        assert_eq!(decoded_data[0].order.size.raw, QUANTITY_UNDEF);
        assert_eq!(decoded_data[0].order.size.precision, 0);
        assert_eq!(decoded_data[1].order.price.precision, 2);
        assert_eq!(decoded_data[1].order.size.precision, 0);
    }
}