1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{
20 Array, FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt32Array, UInt64Array,
21 },
22 datatypes::{DataType, Field, Schema},
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_model::{
27 data::{
28 depth::{DEPTH10_LEN, OrderBookDepth10},
29 order::BookOrder,
30 },
31 enums::OrderSide,
32 identifiers::InstrumentId,
33 types::{Price, Quantity, fixed::PRECISION_BYTES},
34};
35
36use super::{
37 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
38 KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
39};
40use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
41
42fn get_field_data() -> Vec<(&'static str, DataType)> {
43 vec![
44 ("bid_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
45 ("ask_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
46 ("bid_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
47 ("ask_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
48 ("bid_count", DataType::UInt32),
49 ("ask_count", DataType::UInt32),
50 ]
51}
52
53impl ArrowSchemaProvider for OrderBookDepth10 {
54 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
55 let mut fields = Vec::new();
56 let field_data = get_field_data();
57
58 for (name, data_type) in field_data {
61 for i in 0..DEPTH10_LEN {
62 fields.push(Field::new(format!("{name}_{i}"), data_type.clone(), false));
63 }
64 }
65
66 fields.push(Field::new("flags", DataType::UInt8, false));
67 fields.push(Field::new("sequence", DataType::UInt64, false));
68 fields.push(Field::new("ts_event", DataType::UInt64, false));
69 fields.push(Field::new("ts_init", DataType::UInt64, false));
70
71 match metadata {
72 Some(metadata) => Schema::new_with_metadata(fields, metadata),
73 None => Schema::new(fields),
74 }
75 }
76}
77
78fn parse_metadata(
79 metadata: &HashMap<String, String>,
80) -> Result<(InstrumentId, u8, u8), EncodingError> {
81 let instrument_id_str = metadata
82 .get(KEY_INSTRUMENT_ID)
83 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
84 let instrument_id = InstrumentId::from_str(instrument_id_str)
85 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
86
87 let price_precision = metadata
88 .get(KEY_PRICE_PRECISION)
89 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
90 .parse::<u8>()
91 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
92
93 let size_precision = metadata
94 .get(KEY_SIZE_PRECISION)
95 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
96 .parse::<u8>()
97 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
98
99 Ok((instrument_id, price_precision, size_precision))
100}
101
102impl EncodeToRecordBatch for OrderBookDepth10 {
103 fn encode_batch(
104 metadata: &HashMap<String, String>,
105 data: &[Self],
106 ) -> Result<RecordBatch, ArrowError> {
107 let mut bid_price_builders = Vec::with_capacity(DEPTH10_LEN);
108 let mut ask_price_builders = Vec::with_capacity(DEPTH10_LEN);
109 let mut bid_size_builders = Vec::with_capacity(DEPTH10_LEN);
110 let mut ask_size_builders = Vec::with_capacity(DEPTH10_LEN);
111 let mut bid_count_builders = Vec::with_capacity(DEPTH10_LEN);
112 let mut ask_count_builders = Vec::with_capacity(DEPTH10_LEN);
113
114 for _ in 0..DEPTH10_LEN {
115 bid_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
116 data.len(),
117 PRECISION_BYTES,
118 ));
119 ask_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
120 data.len(),
121 PRECISION_BYTES,
122 ));
123 bid_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
124 data.len(),
125 PRECISION_BYTES,
126 ));
127 ask_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
128 data.len(),
129 PRECISION_BYTES,
130 ));
131 bid_count_builders.push(UInt32Array::builder(data.len()));
132 ask_count_builders.push(UInt32Array::builder(data.len()));
133 }
134
135 let mut flags_builder = UInt8Array::builder(data.len());
136 let mut sequence_builder = UInt64Array::builder(data.len());
137 let mut ts_event_builder = UInt64Array::builder(data.len());
138 let mut ts_init_builder = UInt64Array::builder(data.len());
139
140 for depth in data {
141 for i in 0..DEPTH10_LEN {
142 bid_price_builders[i]
143 .append_value(depth.bids[i].price.raw.to_le_bytes())
144 .unwrap();
145 ask_price_builders[i]
146 .append_value(depth.asks[i].price.raw.to_le_bytes())
147 .unwrap();
148 bid_size_builders[i]
149 .append_value(depth.bids[i].size.raw.to_le_bytes())
150 .unwrap();
151 ask_size_builders[i]
152 .append_value(depth.asks[i].size.raw.to_le_bytes())
153 .unwrap();
154 bid_count_builders[i].append_value(depth.bid_counts[i]);
155 ask_count_builders[i].append_value(depth.ask_counts[i]);
156 }
157
158 flags_builder.append_value(depth.flags);
159 sequence_builder.append_value(depth.sequence);
160 ts_event_builder.append_value(depth.ts_event.as_u64());
161 ts_init_builder.append_value(depth.ts_init.as_u64());
162 }
163
164 let bid_price_arrays = bid_price_builders
165 .into_iter()
166 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
167 .collect::<Vec<_>>();
168 let ask_price_arrays = ask_price_builders
169 .into_iter()
170 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
171 .collect::<Vec<_>>();
172 let bid_size_arrays = bid_size_builders
173 .into_iter()
174 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
175 .collect::<Vec<_>>();
176 let ask_size_arrays = ask_size_builders
177 .into_iter()
178 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
179 .collect::<Vec<_>>();
180 let bid_count_arrays = bid_count_builders
181 .into_iter()
182 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
183 .collect::<Vec<_>>();
184 let ask_count_arrays = ask_count_builders
185 .into_iter()
186 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
187 .collect::<Vec<_>>();
188
189 let flags_array = Arc::new(flags_builder.finish()) as Arc<dyn Array>;
190 let sequence_array = Arc::new(sequence_builder.finish()) as Arc<dyn Array>;
191 let ts_event_array = Arc::new(ts_event_builder.finish()) as Arc<dyn Array>;
192 let ts_init_array = Arc::new(ts_init_builder.finish()) as Arc<dyn Array>;
193
194 let mut columns = Vec::new();
195 columns.extend(bid_price_arrays);
196 columns.extend(ask_price_arrays);
197 columns.extend(bid_size_arrays);
198 columns.extend(ask_size_arrays);
199 columns.extend(bid_count_arrays);
200 columns.extend(ask_count_arrays);
201 columns.push(flags_array);
202 columns.push(sequence_array);
203 columns.push(ts_event_array);
204 columns.push(ts_init_array);
205
206 RecordBatch::try_new(Self::get_schema(Some(metadata.clone())).into(), columns)
207 }
208
209 fn metadata(&self) -> HashMap<String, String> {
210 Self::get_metadata(
211 &self.instrument_id,
212 self.bids[0].price.precision,
213 self.bids[0].size.precision,
214 )
215 }
216}
217
218impl DecodeFromRecordBatch for OrderBookDepth10 {
219 fn decode_batch(
220 metadata: &HashMap<String, String>,
221 record_batch: RecordBatch,
222 ) -> Result<Vec<Self>, EncodingError> {
223 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
224 let cols = record_batch.columns();
225
226 let mut bid_prices = Vec::with_capacity(DEPTH10_LEN);
227 let mut ask_prices = Vec::with_capacity(DEPTH10_LEN);
228 let mut bid_sizes = Vec::with_capacity(DEPTH10_LEN);
229 let mut ask_sizes = Vec::with_capacity(DEPTH10_LEN);
230 let mut bid_counts = Vec::with_capacity(DEPTH10_LEN);
231 let mut ask_counts = Vec::with_capacity(DEPTH10_LEN);
232
233 macro_rules! extract_depth_column {
234 ($array:ty, $name:literal, $i:expr, $offset:expr, $type:expr) => {
235 extract_column::<$array>(cols, concat!($name, "_", stringify!($i)), $offset, $type)?
236 };
237 }
238
239 for i in 0..DEPTH10_LEN {
240 bid_prices.push(extract_depth_column!(
241 FixedSizeBinaryArray,
242 "bid_price",
243 i,
244 i,
245 DataType::FixedSizeBinary(PRECISION_BYTES)
246 ));
247 ask_prices.push(extract_depth_column!(
248 FixedSizeBinaryArray,
249 "ask_price",
250 i,
251 DEPTH10_LEN + i,
252 DataType::FixedSizeBinary(PRECISION_BYTES)
253 ));
254 bid_sizes.push(extract_depth_column!(
255 FixedSizeBinaryArray,
256 "bid_size",
257 i,
258 2 * DEPTH10_LEN + i,
259 DataType::FixedSizeBinary(PRECISION_BYTES)
260 ));
261 ask_sizes.push(extract_depth_column!(
262 FixedSizeBinaryArray,
263 "ask_size",
264 i,
265 3 * DEPTH10_LEN + i,
266 DataType::FixedSizeBinary(PRECISION_BYTES)
267 ));
268 bid_counts.push(extract_depth_column!(
269 UInt32Array,
270 "bid_count",
271 i,
272 4 * DEPTH10_LEN + i,
273 DataType::UInt32
274 ));
275 ask_counts.push(extract_depth_column!(
276 UInt32Array,
277 "ask_count",
278 i,
279 5 * DEPTH10_LEN + i,
280 DataType::UInt32
281 ));
282 }
283
284 for i in 0..DEPTH10_LEN {
285 if bid_prices[i].value_length() != PRECISION_BYTES {
286 return Err(EncodingError::ParseError(
287 "bid_price",
288 format!(
289 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
290 bid_prices[i].value_length()
291 ),
292 ));
293 }
294 if ask_prices[i].value_length() != PRECISION_BYTES {
295 return Err(EncodingError::ParseError(
296 "ask_price",
297 format!(
298 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
299 ask_prices[i].value_length()
300 ),
301 ));
302 }
303 if bid_sizes[i].value_length() != PRECISION_BYTES {
304 return Err(EncodingError::ParseError(
305 "bid_size",
306 format!(
307 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
308 bid_sizes[i].value_length()
309 ),
310 ));
311 }
312 if ask_sizes[i].value_length() != PRECISION_BYTES {
313 return Err(EncodingError::ParseError(
314 "ask_size",
315 format!(
316 "Invalid value length at index {i}: expected {PRECISION_BYTES}, found {}",
317 ask_sizes[i].value_length()
318 ),
319 ));
320 }
321 }
322
323 let flags = extract_column::<UInt8Array>(cols, "flags", 6 * DEPTH10_LEN, DataType::UInt8)?;
324 let sequence =
325 extract_column::<UInt64Array>(cols, "sequence", 6 * DEPTH10_LEN + 1, DataType::UInt64)?;
326 let ts_event =
327 extract_column::<UInt64Array>(cols, "ts_event", 6 * DEPTH10_LEN + 2, DataType::UInt64)?;
328 let ts_init =
329 extract_column::<UInt64Array>(cols, "ts_init", 6 * DEPTH10_LEN + 3, DataType::UInt64)?;
330
331 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
333 .map(|row| {
334 let mut bids = [BookOrder::default(); DEPTH10_LEN];
335 let mut asks = [BookOrder::default(); DEPTH10_LEN];
336 let mut bid_count_arr = [0u32; DEPTH10_LEN];
337 let mut ask_count_arr = [0u32; DEPTH10_LEN];
338
339 for i in 0..DEPTH10_LEN {
341 bids[i] = BookOrder::new(
342 OrderSide::Buy,
343 Price::from_raw(
344 get_corrected_raw_price(bid_prices[i].value(row), price_precision),
345 price_precision,
346 ),
347 Quantity::from_raw(
348 get_corrected_raw_quantity(bid_sizes[i].value(row), size_precision),
349 size_precision,
350 ),
351 0, );
353 asks[i] = BookOrder::new(
354 OrderSide::Sell,
355 Price::from_raw(
356 get_corrected_raw_price(ask_prices[i].value(row), price_precision),
357 price_precision,
358 ),
359 Quantity::from_raw(
360 get_corrected_raw_quantity(ask_sizes[i].value(row), size_precision),
361 size_precision,
362 ),
363 0, );
365 bid_count_arr[i] = bid_counts[i].value(row);
366 ask_count_arr[i] = ask_counts[i].value(row);
367 }
368
369 Ok(Self {
370 instrument_id,
371 bids,
372 asks,
373 bid_counts: bid_count_arr,
374 ask_counts: ask_count_arr,
375 flags: flags.value(row),
376 sequence: sequence.value(row),
377 ts_event: ts_event.value(row).into(),
378 ts_init: ts_init.value(row).into(),
379 })
380 })
381 .collect();
382
383 result
384 }
385}
386
387impl DecodeDataFromRecordBatch for OrderBookDepth10 {
388 fn decode_data_batch(
389 metadata: &HashMap<String, String>,
390 record_batch: RecordBatch,
391 ) -> Result<Vec<Data>, EncodingError> {
392 let depths: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
393 Ok(depths.into_iter().map(Data::from).collect())
394 }
395}
396
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field};
    use nautilus_model::{
        data::stubs::stub_depth10,
        types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Downcasts the `DEPTH10_LEN` columns of per-level group `group` to `T`.
    ///
    /// Group `g` occupies column positions `g * DEPTH10_LEN ..= g * DEPTH10_LEN + DEPTH10_LEN - 1`.
    /// Panics if a column is not of type `T`, which fails the calling test.
    fn level_columns<T: Array + 'static>(columns: &[Arc<dyn Array>], group: usize) -> Vec<&T> {
        (0..DEPTH10_LEN)
            .map(|i| {
                columns[group * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<T>()
                    .unwrap()
            })
            .collect()
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDepth10::get_schema(Some(metadata));

        // Per-level groups come first, in `get_field_data()` order.
        let mut group_count = 0;
        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema.field(i + group_count * DEPTH10_LEN).clone();
                assert_eq!(
                    field,
                    Field::new(format!("{name}_{i}"), data_type.clone(), false)
                );
            }

            group_count += 1;
        }

        // Single-valued columns follow the per-level groups.
        let flags_field = schema.field(group_count * DEPTH10_LEN).clone();
        assert_eq!(flags_field, Field::new("flags", DataType::UInt8, false));
        let sequence_field = schema.field(group_count * DEPTH10_LEN + 1).clone();
        assert_eq!(
            sequence_field,
            Field::new("sequence", DataType::UInt64, false)
        );
        let ts_event_field = schema.field(group_count * DEPTH10_LEN + 2).clone();
        assert_eq!(
            ts_event_field,
            Field::new("ts_event", DataType::UInt64, false)
        );
        let ts_init_field = schema.field(group_count * DEPTH10_LEN + 3).clone();
        assert_eq!(
            ts_init_field,
            Field::new("ts_init", DataType::UInt64, false)
        );

        assert_eq!(schema.metadata()["instrument_id"], "AAPL.XNAS");
        assert_eq!(schema.metadata()["price_precision"], "2");
        assert_eq!(schema.metadata()["size_precision"], "0");
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDepth10::get_schema_map();

        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema_map.get(&format!("{name}_{i}")).map(String::as_str);
                assert_eq!(field, Some(format!("{data_type:?}").as_str()));
            }
        }

        assert_eq!(schema_map.get("flags").map(String::as_str), Some("UInt8"));
        assert_eq!(
            schema_map.get("sequence").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_event").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_init").map(String::as_str),
            Some("UInt64")
        );
    }

    #[rstest]
    fn test_encode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let price_precision = 2;
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, price_precision, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        assert_eq!(columns.len(), DEPTH10_LEN * 6 + 4);

        let bid_prices = level_columns::<FixedSizeBinaryArray>(columns, 0);

        let expected_bid_prices: Vec<f64> =
            vec![99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0, 92.0, 91.0, 90.0];

        for (i, bid_price) in bid_prices.iter().enumerate() {
            assert_eq!(bid_price.len(), 1);
            assert_eq!(
                get_raw_price(bid_price.value(0)),
                (expected_bid_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(bid_price.value(0)), price_precision).as_f64(),
                expected_bid_prices[i]
            );
        }

        let ask_prices = level_columns::<FixedSizeBinaryArray>(columns, 1);

        let expected_ask_prices: Vec<f64> = vec![
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0,
        ];

        for (i, ask_price) in ask_prices.iter().enumerate() {
            assert_eq!(ask_price.len(), 1);
            assert_eq!(
                get_raw_price(ask_price.value(0)),
                (expected_ask_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(ask_price.value(0)), price_precision).as_f64(),
                expected_ask_prices[i]
            );
        }

        let bid_sizes = level_columns::<FixedSizeBinaryArray>(columns, 2);

        for (i, bid_size) in bid_sizes.iter().enumerate() {
            assert_eq!(bid_size.len(), 1);
            assert_eq!(
                get_raw_quantity(bid_size.value(0)),
                ((100.0 * FIXED_SCALAR * (i + 1) as f64) as QuantityRaw)
            );
        }

        let ask_sizes = level_columns::<FixedSizeBinaryArray>(columns, 3);

        for (i, ask_size) in ask_sizes.iter().enumerate() {
            assert_eq!(ask_size.len(), 1);
            assert_eq!(
                get_raw_quantity(ask_size.value(0)),
                ((100.0 * FIXED_SCALAR * ((i + 1) as f64)) as QuantityRaw)
            );
        }

        let bid_counts = level_columns::<UInt32Array>(columns, 4);

        for count_values in bid_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        let ask_counts = level_columns::<UInt32Array>(columns, 5);

        for count_values in ask_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        let flags_values = columns[6 * DEPTH10_LEN]
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();
        let sequence_values = columns[6 * DEPTH10_LEN + 1]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_event_values = columns[6 * DEPTH10_LEN + 2]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_init_values = columns[6 * DEPTH10_LEN + 3]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();

        assert_eq!(flags_values.len(), 1);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(sequence_values.len(), 1);
        assert_eq!(sequence_values.value(0), 0);
        assert_eq!(ts_event_values.len(), 1);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_init_values.len(), 1);
        assert_eq!(ts_init_values.value(0), 2);
    }

    #[rstest]
    fn test_decode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let decoded_data = OrderBookDepth10::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 1);

        // Verify the round trip field by field. Order IDs are not compared
        // because the decoder always assigns 0 to them.
        let original = &data[0];
        let decoded = &decoded_data[0];

        // The instrument ID is taken from the metadata, not from the rows.
        assert_eq!(decoded.instrument_id, instrument_id);

        for i in 0..DEPTH10_LEN {
            assert_eq!(decoded.bids[i].price.raw, original.bids[i].price.raw);
            assert_eq!(decoded.asks[i].price.raw, original.asks[i].price.raw);
            assert_eq!(decoded.bids[i].size.raw, original.bids[i].size.raw);
            assert_eq!(decoded.asks[i].size.raw, original.asks[i].size.raw);
            assert_eq!(decoded.bid_counts[i], original.bid_counts[i]);
            assert_eq!(decoded.ask_counts[i], original.ask_counts[i]);
        }

        assert_eq!(decoded.flags, original.flags);
        assert_eq!(decoded.sequence, original.sequence);
        assert_eq!(decoded.ts_event.as_u64(), original.ts_event.as_u64());
        assert_eq!(decoded.ts_init.as_u64(), original.ts_init.as_u64());
    }
}