1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{BookOrder, OrderBookDelta},
26 enums::{BookAction, FromU8, OrderSide},
27 identifiers::InstrumentId,
28 types::{
29 Price, Quantity, fixed::PRECISION_BYTES, price::PRICE_UNDEF, quantity::QUANTITY_UNDEF,
30 },
31};
32
33use super::{
34 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
35 KEY_SIZE_PRECISION, extract_column,
36};
37use crate::arrow::{
38 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
39 get_raw_quantity,
40};
41
42impl ArrowSchemaProvider for OrderBookDelta {
43 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
44 let fields = vec![
45 Field::new("action", DataType::UInt8, false),
46 Field::new("side", DataType::UInt8, false),
47 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
48 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
49 Field::new("order_id", DataType::UInt64, false),
50 Field::new("flags", DataType::UInt8, false),
51 Field::new("sequence", DataType::UInt64, false),
52 Field::new("ts_event", DataType::UInt64, false),
53 Field::new("ts_init", DataType::UInt64, false),
54 ];
55
56 match metadata {
57 Some(metadata) => Schema::new_with_metadata(fields, metadata),
58 None => Schema::new(fields),
59 }
60 }
61}
62
63fn parse_metadata(
64 metadata: &HashMap<String, String>,
65) -> Result<(InstrumentId, u8, u8), EncodingError> {
66 let instrument_id_str = metadata
67 .get(KEY_INSTRUMENT_ID)
68 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
69 let instrument_id = InstrumentId::from_str(instrument_id_str)
70 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
71
72 let price_precision = metadata
73 .get(KEY_PRICE_PRECISION)
74 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
75 .parse::<u8>()
76 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
77
78 let size_precision = metadata
79 .get(KEY_SIZE_PRECISION)
80 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
81 .parse::<u8>()
82 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
83
84 Ok((instrument_id, price_precision, size_precision))
85}
86
87impl EncodeToRecordBatch for OrderBookDelta {
88 fn encode_batch(
89 metadata: &HashMap<String, String>,
90 data: &[Self],
91 ) -> Result<RecordBatch, ArrowError> {
92 let mut action_builder = UInt8Array::builder(data.len());
93 let mut side_builder = UInt8Array::builder(data.len());
94 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
95 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
96 let mut order_id_builder = UInt64Array::builder(data.len());
97 let mut flags_builder = UInt8Array::builder(data.len());
98 let mut sequence_builder = UInt64Array::builder(data.len());
99 let mut ts_event_builder = UInt64Array::builder(data.len());
100 let mut ts_init_builder = UInt64Array::builder(data.len());
101
102 for delta in data {
103 action_builder.append_value(delta.action as u8);
104 side_builder.append_value(delta.order.side as u8);
105 price_builder
106 .append_value(delta.order.price.raw.to_le_bytes())
107 .unwrap();
108 size_builder
109 .append_value(delta.order.size.raw.to_le_bytes())
110 .unwrap();
111 order_id_builder.append_value(delta.order.order_id);
112 flags_builder.append_value(delta.flags);
113 sequence_builder.append_value(delta.sequence);
114 ts_event_builder.append_value(delta.ts_event.as_u64());
115 ts_init_builder.append_value(delta.ts_init.as_u64());
116 }
117
118 let action_array = action_builder.finish();
119 let side_array = side_builder.finish();
120 let price_array = price_builder.finish();
121 let size_array = size_builder.finish();
122 let order_id_array = order_id_builder.finish();
123 let flags_array = flags_builder.finish();
124 let sequence_array = sequence_builder.finish();
125 let ts_event_array = ts_event_builder.finish();
126 let ts_init_array = ts_init_builder.finish();
127
128 RecordBatch::try_new(
129 Self::get_schema(Some(metadata.clone())).into(),
130 vec![
131 Arc::new(action_array),
132 Arc::new(side_array),
133 Arc::new(price_array),
134 Arc::new(size_array),
135 Arc::new(order_id_array),
136 Arc::new(flags_array),
137 Arc::new(sequence_array),
138 Arc::new(ts_event_array),
139 Arc::new(ts_init_array),
140 ],
141 )
142 }
143
144 fn metadata(&self) -> HashMap<String, String> {
145 Self::get_metadata(
146 &self.instrument_id,
147 self.order.price.precision,
148 self.order.size.precision,
149 )
150 }
151
152 fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
156 let delta = chunk
157 .first()
158 .expect("Chunk should have at least one element to encode");
159
160 if delta.order.price.precision == 0
161 && delta.order.size.precision == 0
162 && let Some(delta) = chunk.get(1)
163 {
164 return EncodeToRecordBatch::metadata(delta);
165 }
166
167 EncodeToRecordBatch::metadata(delta)
168 }
169}
170
171impl DecodeFromRecordBatch for OrderBookDelta {
172 fn decode_batch(
173 metadata: &HashMap<String, String>,
174 record_batch: RecordBatch,
175 ) -> Result<Vec<Self>, EncodingError> {
176 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
177 let cols = record_batch.columns();
178
179 let action_values = extract_column::<UInt8Array>(cols, "action", 0, DataType::UInt8)?;
180 let side_values = extract_column::<UInt8Array>(cols, "side", 1, DataType::UInt8)?;
181 let price_values = extract_column::<FixedSizeBinaryArray>(
182 cols,
183 "price",
184 2,
185 DataType::FixedSizeBinary(PRECISION_BYTES),
186 )?;
187 let size_values = extract_column::<FixedSizeBinaryArray>(
188 cols,
189 "size",
190 3,
191 DataType::FixedSizeBinary(PRECISION_BYTES),
192 )?;
193 let order_id_values = extract_column::<UInt64Array>(cols, "order_id", 4, DataType::UInt64)?;
194 let flags_values = extract_column::<UInt8Array>(cols, "flags", 5, DataType::UInt8)?;
195 let sequence_values = extract_column::<UInt64Array>(cols, "sequence", 6, DataType::UInt64)?;
196 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
197 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 8, DataType::UInt64)?;
198
199 if price_values.value_length() != PRECISION_BYTES {
200 return Err(EncodingError::ParseError(
201 "price",
202 format!(
203 "Invalid value length: expected {PRECISION_BYTES}, found {}",
204 price_values.value_length()
205 ),
206 ));
207 }
208
209 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
210 .map(|i| {
211 let action_value = action_values.value(i);
212 let action = BookAction::from_u8(action_value).ok_or_else(|| {
213 EncodingError::ParseError(
214 stringify!(BookAction),
215 format!("Invalid enum value, was {action_value}"),
216 )
217 })?;
218 let side_value = side_values.value(i);
219 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
220 EncodingError::ParseError(
221 stringify!(OrderSide),
222 format!("Invalid enum value, was {side_value}"),
223 )
224 })?;
225 let raw_price = get_raw_price(price_values.value(i));
226 let price_prec = if raw_price == PRICE_UNDEF {
227 0
228 } else {
229 price_precision
230 };
231 let price = Price::from_raw(raw_price, price_prec);
232
233 let raw_size = get_raw_quantity(size_values.value(i));
234 let size_prec = if raw_size == QUANTITY_UNDEF {
235 0
236 } else {
237 size_precision
238 };
239 let size = Quantity::from_raw(raw_size, size_prec);
240 let order_id = order_id_values.value(i);
241 let flags = flags_values.value(i);
242 let sequence = sequence_values.value(i);
243 let ts_event = ts_event_values.value(i).into();
244 let ts_init = ts_init_values.value(i).into();
245
246 Ok(Self {
247 instrument_id,
248 action,
249 order: BookOrder {
250 side,
251 price,
252 size,
253 order_id,
254 },
255 flags,
256 sequence,
257 ts_event,
258 ts_init,
259 })
260 })
261 .collect();
262
263 result
264 }
265}
266
267impl DecodeDataFromRecordBatch for OrderBookDelta {
268 fn decode_data_batch(
269 metadata: &HashMap<String, String>,
270 record_batch: RecordBatch,
271 ) -> Result<Vec<Data>, EncodingError> {
272 let deltas: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
273 Ok(deltas.into_iter().map(Data::from).collect())
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{
        fixed::FIXED_SCALAR,
        price::{PRICE_UNDEF, PriceRaw},
        quantity::QUANTITY_UNDEF,
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema must list the nine delta columns in order and carry the metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDelta::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("action", DataType::UInt8, false),
            Field::new("side", DataType::UInt8, false),
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("order_id", DataType::UInt64, false),
            Field::new("flags", DataType::UInt8, false),
            Field::new("sequence", DataType::UInt64, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must report the expected type name for every column.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDelta::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        assert_eq!(schema_map.get("action").unwrap(), "UInt8");
        assert_eq!(schema_map.get("side").unwrap(), "UInt8");
        assert_eq!(*schema_map.get("price").unwrap(), fixed_size_binary);
        assert_eq!(*schema_map.get("size").unwrap(), fixed_size_binary);
        assert_eq!(schema_map.get("order_id").unwrap(), "UInt64");
        assert_eq!(schema_map.get("flags").unwrap(), "UInt8");
        assert_eq!(schema_map.get("sequence").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_event").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_init").unwrap(), "UInt64");
    }

    /// Encoding two deltas must yield nine columns holding the expected raw values.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let delta1 = OrderBookDelta {
            instrument_id,
            action: BookAction::Add,
            order: BookOrder {
                side: OrderSide::Buy,
                price: Price::from("100.10"),
                size: Quantity::from(100),
                order_id: 1,
            },
            flags: 0,
            sequence: 1,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let delta2 = OrderBookDelta {
            instrument_id,
            action: BookAction::Update,
            order: BookOrder {
                side: OrderSide::Sell,
                price: Price::from("101.20"),
                size: Quantity::from(200),
                order_id: 2,
            },
            flags: 1,
            sequence: 2,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![delta1, delta2];
        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing so a mismatch fails with a
        // clear assertion instead of an out-of-bounds panic.
        assert_eq!(columns.len(), 9);

        let action_values = columns[0].as_any().downcast_ref::<UInt8Array>().unwrap();
        let side_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let price_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let order_id_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let flags_values = columns[5].as_any().downcast_ref::<UInt8Array>().unwrap();
        let sequence_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_event_values = columns[7].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[8].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(action_values.len(), 2);
        assert_eq!(action_values.value(0), 1);
        assert_eq!(action_values.value(1), 2);
        assert_eq!(side_values.len(), 2);
        assert_eq!(side_values.value(0), 1);
        assert_eq!(side_values.value(1), 2);

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (101.20 * FIXED_SCALAR) as PriceRaw
        );

        // The size column holds raw quantities, so decode it with the quantity
        // helper (the same one the decoder uses) rather than the price helper.
        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(size_values.value(0)),
            Quantity::from(100).raw
        );
        assert_eq!(
            get_raw_quantity(size_values.value(1)),
            Quantity::from(200).raw
        );

        assert_eq!(order_id_values.len(), 2);
        assert_eq!(order_id_values.value(0), 1);
        assert_eq!(order_id_values.value(1), 2);
        assert_eq!(flags_values.len(), 2);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(flags_values.value(1), 1);
        assert_eq!(sequence_values.len(), 2);
        assert_eq!(sequence_values.value(0), 1);
        assert_eq!(sequence_values.value(1), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch must reproduce every field of every row.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let action = UInt8Array::from(vec![1, 2]);
        let side = UInt8Array::from(vec![1, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &((101.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((101.20 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &((10000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((9000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![1, 2]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        let first = &decoded_data[0];
        assert_eq!(first.instrument_id, instrument_id);
        assert_eq!(first.action as u8, 1);
        assert_eq!(first.order.side as u8, 1);
        assert_eq!(first.order.price.raw, (101.10 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(first.order.price.precision, 2);
        assert_eq!(first.order.size.raw, Quantity::from(10000).raw);
        assert_eq!(first.order.size.precision, 0);
        assert_eq!(first.order.order_id, 1);
        assert_eq!(first.flags, 0);
        assert_eq!(first.sequence, 1);
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded_data[1];
        assert_eq!(second.instrument_id, instrument_id);
        assert_eq!(second.action as u8, 2);
        assert_eq!(second.order.side as u8, 1);
        assert_eq!(second.order.price.raw, (101.20 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(second.order.price.precision, 2);
        assert_eq!(second.order.size.raw, Quantity::from(9000).raw);
        assert_eq!(second.order.size.precision, 0);
        assert_eq!(second.order.order_id, 2);
        assert_eq!(second.flags, 0);
        assert_eq!(second.sequence, 2);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }

    /// Rows carrying the undefined price/size sentinels must decode with
    /// precision 0, while ordinary rows keep the precision from the metadata.
    #[rstest]
    fn test_decode_batch_with_undef_values() {
        let instrument_id = InstrumentId::from("PLTR.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        // Row 0 carries the undefined sentinels; row 1 is an ordinary delta.
        // NOTE(review): action 4 / side 0 presumably map to a clear action with
        // no order side - confirm against the enum definitions.
        let action = UInt8Array::from(vec![4, 1]);
        let side = UInt8Array::from(vec![0, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &PRICE_UNDEF.to_le_bytes(),
            &((100.50 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &QUANTITY_UNDEF.to_le_bytes(),
            &((1000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![0, 1]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].order.price.raw, PRICE_UNDEF);
        assert_eq!(decoded_data[0].order.price.precision, 0);
        assert_eq!(decoded_data[0].order.size.raw, QUANTITY_UNDEF);
        assert_eq!(decoded_data[0].order.size.precision, 0);
        assert_eq!(decoded_data[1].order.price.precision, 2);
        assert_eq!(decoded_data[1].order.size.precision, 0);
    }
}