1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{
20 Array, FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt32Array, UInt64Array,
21 },
22 datatypes::{DataType, Field, Schema},
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_model::{
27 data::{
28 depth::{DEPTH10_LEN, OrderBookDepth10},
29 order::BookOrder,
30 },
31 enums::OrderSide,
32 identifiers::InstrumentId,
33 types::{Price, Quantity, fixed::PRECISION_BYTES},
34};
35
36use super::{
37 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
38 KEY_SIZE_PRECISION, extract_column,
39};
40use crate::arrow::{
41 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
42 get_raw_quantity,
43};
44
45fn get_field_data() -> Vec<(&'static str, DataType)> {
46 vec![
47 ("bid_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
48 ("ask_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
49 ("bid_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
50 ("ask_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
51 ("bid_count", DataType::UInt32),
52 ("ask_count", DataType::UInt32),
53 ]
54}
55
56impl ArrowSchemaProvider for OrderBookDepth10 {
57 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
58 let mut fields = Vec::new();
59 let field_data = get_field_data();
60
61 for (name, data_type) in field_data {
64 for i in 0..DEPTH10_LEN {
65 fields.push(Field::new(format!("{name}_{i}"), data_type.clone(), false));
66 }
67 }
68
69 fields.push(Field::new("flags", DataType::UInt8, false));
70 fields.push(Field::new("sequence", DataType::UInt64, false));
71 fields.push(Field::new("ts_event", DataType::UInt64, false));
72 fields.push(Field::new("ts_init", DataType::UInt64, false));
73
74 match metadata {
75 Some(metadata) => Schema::new_with_metadata(fields, metadata),
76 None => Schema::new(fields),
77 }
78 }
79}
80
81fn parse_metadata(
82 metadata: &HashMap<String, String>,
83) -> Result<(InstrumentId, u8, u8), EncodingError> {
84 let instrument_id_str = metadata
85 .get(KEY_INSTRUMENT_ID)
86 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
87 let instrument_id = InstrumentId::from_str(instrument_id_str)
88 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
89
90 let price_precision = metadata
91 .get(KEY_PRICE_PRECISION)
92 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
93 .parse::<u8>()
94 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
95
96 let size_precision = metadata
97 .get(KEY_SIZE_PRECISION)
98 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
99 .parse::<u8>()
100 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
101
102 Ok((instrument_id, price_precision, size_precision))
103}
104
105impl EncodeToRecordBatch for OrderBookDepth10 {
106 fn encode_batch(
107 metadata: &HashMap<String, String>,
108 data: &[Self],
109 ) -> Result<RecordBatch, ArrowError> {
110 let mut bid_price_builders = Vec::with_capacity(DEPTH10_LEN);
111 let mut ask_price_builders = Vec::with_capacity(DEPTH10_LEN);
112 let mut bid_size_builders = Vec::with_capacity(DEPTH10_LEN);
113 let mut ask_size_builders = Vec::with_capacity(DEPTH10_LEN);
114 let mut bid_count_builders = Vec::with_capacity(DEPTH10_LEN);
115 let mut ask_count_builders = Vec::with_capacity(DEPTH10_LEN);
116
117 for _ in 0..DEPTH10_LEN {
118 bid_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
119 data.len(),
120 PRECISION_BYTES,
121 ));
122 ask_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
123 data.len(),
124 PRECISION_BYTES,
125 ));
126 bid_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
127 data.len(),
128 PRECISION_BYTES,
129 ));
130 ask_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
131 data.len(),
132 PRECISION_BYTES,
133 ));
134 bid_count_builders.push(UInt32Array::builder(data.len()));
135 ask_count_builders.push(UInt32Array::builder(data.len()));
136 }
137
138 let mut flags_builder = UInt8Array::builder(data.len());
139 let mut sequence_builder = UInt64Array::builder(data.len());
140 let mut ts_event_builder = UInt64Array::builder(data.len());
141 let mut ts_init_builder = UInt64Array::builder(data.len());
142
143 for depth in data {
144 for i in 0..DEPTH10_LEN {
145 bid_price_builders[i]
146 .append_value(depth.bids[i].price.raw.to_le_bytes())
147 .unwrap();
148 ask_price_builders[i]
149 .append_value(depth.asks[i].price.raw.to_le_bytes())
150 .unwrap();
151 bid_size_builders[i]
152 .append_value(depth.bids[i].size.raw.to_le_bytes())
153 .unwrap();
154 ask_size_builders[i]
155 .append_value(depth.asks[i].size.raw.to_le_bytes())
156 .unwrap();
157 bid_count_builders[i].append_value(depth.bid_counts[i]);
158 ask_count_builders[i].append_value(depth.ask_counts[i]);
159 }
160
161 flags_builder.append_value(depth.flags);
162 sequence_builder.append_value(depth.sequence);
163 ts_event_builder.append_value(depth.ts_event.as_u64());
164 ts_init_builder.append_value(depth.ts_init.as_u64());
165 }
166
167 let bid_price_arrays = bid_price_builders
168 .into_iter()
169 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
170 .collect::<Vec<_>>();
171 let ask_price_arrays = ask_price_builders
172 .into_iter()
173 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
174 .collect::<Vec<_>>();
175 let bid_size_arrays = bid_size_builders
176 .into_iter()
177 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
178 .collect::<Vec<_>>();
179 let ask_size_arrays = ask_size_builders
180 .into_iter()
181 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
182 .collect::<Vec<_>>();
183 let bid_count_arrays = bid_count_builders
184 .into_iter()
185 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
186 .collect::<Vec<_>>();
187 let ask_count_arrays = ask_count_builders
188 .into_iter()
189 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
190 .collect::<Vec<_>>();
191
192 let flags_array = Arc::new(flags_builder.finish()) as Arc<dyn Array>;
193 let sequence_array = Arc::new(sequence_builder.finish()) as Arc<dyn Array>;
194 let ts_event_array = Arc::new(ts_event_builder.finish()) as Arc<dyn Array>;
195 let ts_init_array = Arc::new(ts_init_builder.finish()) as Arc<dyn Array>;
196
197 let mut columns = Vec::new();
198 columns.extend(bid_price_arrays);
199 columns.extend(ask_price_arrays);
200 columns.extend(bid_size_arrays);
201 columns.extend(ask_size_arrays);
202 columns.extend(bid_count_arrays);
203 columns.extend(ask_count_arrays);
204 columns.push(flags_array);
205 columns.push(sequence_array);
206 columns.push(ts_event_array);
207 columns.push(ts_init_array);
208
209 RecordBatch::try_new(Self::get_schema(Some(metadata.clone())).into(), columns)
210 }
211
212 fn metadata(&self) -> HashMap<String, String> {
213 OrderBookDepth10::get_metadata(
214 &self.instrument_id,
215 self.bids[0].price.precision,
216 self.bids[0].size.precision,
217 )
218 }
219}
220
221impl DecodeFromRecordBatch for OrderBookDepth10 {
222 fn decode_batch(
223 metadata: &HashMap<String, String>,
224 record_batch: RecordBatch,
225 ) -> Result<Vec<Self>, EncodingError> {
226 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
227 let cols = record_batch.columns();
228
229 let mut bid_prices = Vec::with_capacity(DEPTH10_LEN);
230 let mut ask_prices = Vec::with_capacity(DEPTH10_LEN);
231 let mut bid_sizes = Vec::with_capacity(DEPTH10_LEN);
232 let mut ask_sizes = Vec::with_capacity(DEPTH10_LEN);
233 let mut bid_counts = Vec::with_capacity(DEPTH10_LEN);
234 let mut ask_counts = Vec::with_capacity(DEPTH10_LEN);
235
236 macro_rules! extract_depth_column {
237 ($array:ty, $name:literal, $i:expr, $offset:expr, $type:expr) => {
238 extract_column::<$array>(cols, concat!($name, "_", stringify!($i)), $offset, $type)?
239 };
240 }
241
242 for i in 0..DEPTH10_LEN {
243 bid_prices.push(extract_depth_column!(
244 FixedSizeBinaryArray,
245 "bid_price",
246 i,
247 i,
248 DataType::FixedSizeBinary(PRECISION_BYTES)
249 ));
250 ask_prices.push(extract_depth_column!(
251 FixedSizeBinaryArray,
252 "ask_price",
253 i,
254 DEPTH10_LEN + i,
255 DataType::FixedSizeBinary(PRECISION_BYTES)
256 ));
257 bid_sizes.push(extract_depth_column!(
258 FixedSizeBinaryArray,
259 "bid_size",
260 i,
261 2 * DEPTH10_LEN + i,
262 DataType::FixedSizeBinary(PRECISION_BYTES)
263 ));
264 ask_sizes.push(extract_depth_column!(
265 FixedSizeBinaryArray,
266 "ask_size",
267 i,
268 3 * DEPTH10_LEN + i,
269 DataType::FixedSizeBinary(PRECISION_BYTES)
270 ));
271 bid_counts.push(extract_depth_column!(
272 UInt32Array,
273 "bid_count",
274 i,
275 4 * DEPTH10_LEN + i,
276 DataType::UInt32
277 ));
278 ask_counts.push(extract_depth_column!(
279 UInt32Array,
280 "ask_count",
281 i,
282 5 * DEPTH10_LEN + i,
283 DataType::UInt32
284 ));
285 }
286
287 for i in 0..DEPTH10_LEN {
288 if bid_prices[i].value_length() != PRECISION_BYTES {
289 return Err(EncodingError::ParseError(
290 "bid_price",
291 format!(
292 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
293 bid_prices[i].value_length()
294 ),
295 ));
296 }
297 if ask_prices[i].value_length() != PRECISION_BYTES {
298 return Err(EncodingError::ParseError(
299 "ask_price",
300 format!(
301 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
302 ask_prices[i].value_length()
303 ),
304 ));
305 }
306 if bid_sizes[i].value_length() != PRECISION_BYTES {
307 return Err(EncodingError::ParseError(
308 "bid_size",
309 format!(
310 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
311 bid_sizes[i].value_length()
312 ),
313 ));
314 }
315 if ask_sizes[i].value_length() != PRECISION_BYTES {
316 return Err(EncodingError::ParseError(
317 "ask_size",
318 format!(
319 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
320 ask_sizes[i].value_length()
321 ),
322 ));
323 }
324 }
325
326 let flags = extract_column::<UInt8Array>(cols, "flags", 6 * DEPTH10_LEN, DataType::UInt8)?;
327 let sequence =
328 extract_column::<UInt64Array>(cols, "sequence", 6 * DEPTH10_LEN + 1, DataType::UInt64)?;
329 let ts_event =
330 extract_column::<UInt64Array>(cols, "ts_event", 6 * DEPTH10_LEN + 2, DataType::UInt64)?;
331 let ts_init =
332 extract_column::<UInt64Array>(cols, "ts_init", 6 * DEPTH10_LEN + 3, DataType::UInt64)?;
333
334 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
336 .map(|row| {
337 let mut bids = [BookOrder::default(); DEPTH10_LEN];
338 let mut asks = [BookOrder::default(); DEPTH10_LEN];
339 let mut bid_count_arr = [0u32; DEPTH10_LEN];
340 let mut ask_count_arr = [0u32; DEPTH10_LEN];
341
342 for i in 0..DEPTH10_LEN {
343 bids[i] = BookOrder::new(
344 OrderSide::Buy,
345 Price::from_raw(get_raw_price(bid_prices[i].value(row)), price_precision),
346 Quantity::from_raw(
347 get_raw_quantity(bid_sizes[i].value(row)),
348 size_precision,
349 ),
350 0, );
352 asks[i] = BookOrder::new(
353 OrderSide::Sell,
354 Price::from_raw(get_raw_price(ask_prices[i].value(row)), price_precision),
355 Quantity::from_raw(
356 get_raw_quantity(ask_sizes[i].value(row)),
357 size_precision,
358 ),
359 0, );
361 bid_count_arr[i] = bid_counts[i].value(row);
362 ask_count_arr[i] = ask_counts[i].value(row);
363 }
364
365 Ok(Self {
366 instrument_id,
367 bids,
368 asks,
369 bid_counts: bid_count_arr,
370 ask_counts: ask_count_arr,
371 flags: flags.value(row),
372 sequence: sequence.value(row),
373 ts_event: ts_event.value(row).into(),
374 ts_init: ts_init.value(row).into(),
375 })
376 })
377 .collect();
378
379 result
380 }
381}
382
383impl DecodeDataFromRecordBatch for OrderBookDepth10 {
384 fn decode_data_batch(
385 metadata: &HashMap<String, String>,
386 record_batch: RecordBatch,
387 ) -> Result<Vec<Data>, EncodingError> {
388 let depths: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
389 Ok(depths.into_iter().map(Data::from).collect())
390 }
391}
392
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field};
    use nautilus_model::{
        data::stubs::stub_depth10,
        types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;

    /// Downcasts the column at `index` to the concrete array type `T`,
    /// panicking when the column has a different type.
    fn downcast_column<T: Array + 'static>(columns: &[Arc<dyn Array>], index: usize) -> &T {
        columns[index].as_any().downcast_ref::<T>().unwrap()
    }

    /// Downcasts the `DEPTH10_LEN` columns of the `group`-th per-level column
    /// group (0 = bid prices, 1 = ask prices, ...) to the array type `T`.
    fn downcast_group<T: Array + 'static>(columns: &[Arc<dyn Array>], group: usize) -> Vec<&T> {
        (0..DEPTH10_LEN)
            .map(|level| downcast_column::<T>(columns, group * DEPTH10_LEN + level))
            .collect()
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDepth10::get_schema(Some(metadata));

        // Per-level groups come first, in `get_field_data` order.
        let field_data = get_field_data();
        let group_count = field_data.len();
        for (group, (name, data_type)) in field_data.into_iter().enumerate() {
            for level in 0..DEPTH10_LEN {
                assert_eq!(
                    schema.field(group * DEPTH10_LEN + level),
                    &Field::new(format!("{name}_{level}"), data_type.clone(), false)
                );
            }
        }

        // The scalar columns follow the last per-level group.
        let scalar_fields = [
            ("flags", DataType::UInt8),
            ("sequence", DataType::UInt64),
            ("ts_event", DataType::UInt64),
            ("ts_init", DataType::UInt64),
        ];
        for (offset, (name, data_type)) in scalar_fields.into_iter().enumerate() {
            assert_eq!(
                schema.field(group_count * DEPTH10_LEN + offset),
                &Field::new(name, data_type, false)
            );
        }

        assert_eq!(schema.metadata()["instrument_id"], "AAPL.XNAS");
        assert_eq!(schema.metadata()["price_precision"], "2");
        assert_eq!(schema.metadata()["size_precision"], "0");
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDepth10::get_schema_map();

        for (name, data_type) in get_field_data() {
            let expected = format!("{data_type:?}");
            for level in 0..DEPTH10_LEN {
                let key = format!("{name}_{level}");
                assert_eq!(
                    schema_map.get(&key).map(String::as_str),
                    Some(expected.as_str())
                );
            }
        }

        assert_eq!(schema_map.get("flags").map(String::as_str), Some("UInt8"));
        assert_eq!(
            schema_map.get("sequence").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_event").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_init").map(String::as_str),
            Some("UInt64")
        );
    }

    #[rstest]
    fn test_encode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let price_precision = 2;
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, price_precision, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        assert_eq!(columns.len(), DEPTH10_LEN * 6 + 4);

        let bid_prices = downcast_group::<FixedSizeBinaryArray>(columns, 0);
        let expected_bid_prices: Vec<f64> =
            vec![99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0, 92.0, 91.0, 90.0];

        for (i, bid_price) in bid_prices.iter().enumerate() {
            assert_eq!(bid_price.len(), 1);
            assert_eq!(
                get_raw_price(bid_price.value(0)),
                (expected_bid_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(bid_price.value(0)), price_precision).as_f64(),
                expected_bid_prices[i]
            );
        }

        let ask_prices = downcast_group::<FixedSizeBinaryArray>(columns, 1);
        let expected_ask_prices: Vec<f64> = vec![
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0,
        ];

        for (i, ask_price) in ask_prices.iter().enumerate() {
            assert_eq!(ask_price.len(), 1);
            assert_eq!(
                get_raw_price(ask_price.value(0)),
                (expected_ask_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(ask_price.value(0)), price_precision).as_f64(),
                expected_ask_prices[i]
            );
        }

        let bid_sizes = downcast_group::<FixedSizeBinaryArray>(columns, 2);
        for (i, bid_size) in bid_sizes.iter().enumerate() {
            assert_eq!(bid_size.len(), 1);
            assert_eq!(
                get_raw_quantity(bid_size.value(0)),
                ((100.0 * FIXED_SCALAR * (i + 1) as f64) as QuantityRaw)
            );
        }

        let ask_sizes = downcast_group::<FixedSizeBinaryArray>(columns, 3);
        for (i, ask_size) in ask_sizes.iter().enumerate() {
            assert_eq!(ask_size.len(), 1);
            assert_eq!(
                get_raw_quantity(ask_size.value(0)),
                ((100.0 * FIXED_SCALAR * ((i + 1) as f64)) as QuantityRaw)
            );
        }

        for count_values in downcast_group::<UInt32Array>(columns, 4) {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        for count_values in downcast_group::<UInt32Array>(columns, 5) {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        let flags_values = downcast_column::<UInt8Array>(columns, 6 * DEPTH10_LEN);
        let sequence_values = downcast_column::<UInt64Array>(columns, 6 * DEPTH10_LEN + 1);
        let ts_event_values = downcast_column::<UInt64Array>(columns, 6 * DEPTH10_LEN + 2);
        let ts_init_values = downcast_column::<UInt64Array>(columns, 6 * DEPTH10_LEN + 3);

        assert_eq!(flags_values.len(), 1);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(sequence_values.len(), 1);
        assert_eq!(sequence_values.value(0), 0);
        assert_eq!(ts_event_values.len(), 1);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_init_values.len(), 1);
        assert_eq!(ts_init_values.value(0), 2);
    }

    #[rstest]
    fn test_decode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let decoded_data = OrderBookDepth10::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 1);

        // Whole-struct equality is deliberately not asserted: decoding
        // rebuilds every `BookOrder` with a constant 0 as its last
        // constructor argument, so only the encoded values are compared.
        let original = &data[0];
        let decoded = &decoded_data[0];

        assert_eq!(decoded.instrument_id, instrument_id);
        for level in 0..DEPTH10_LEN {
            assert_eq!(decoded.bids[level].price.raw, original.bids[level].price.raw);
            assert_eq!(decoded.asks[level].price.raw, original.asks[level].price.raw);
            assert_eq!(decoded.bids[level].size.raw, original.bids[level].size.raw);
            assert_eq!(decoded.asks[level].size.raw, original.asks[level].size.raw);
            assert_eq!(decoded.bids[level].price.precision, 2);
            assert_eq!(decoded.asks[level].size.precision, 0);
        }
        assert_eq!(decoded.bid_counts, original.bid_counts);
        assert_eq!(decoded.ask_counts, original.ask_counts);
        assert_eq!(decoded.flags, original.flags);
        assert_eq!(decoded.sequence, original.sequence);
        assert_eq!(decoded.ts_event.as_u64(), original.ts_event.as_u64());
        assert_eq!(decoded.ts_init.as_u64(), original.ts_init.as_u64());
    }
}