1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
use arrow::{
    array::{
        Array, ArrayBuilder, FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array,
        UInt32Array, UInt64Array,
    },
    datatypes::{DataType, Field, Schema},
    error::ArrowError,
    record_batch::RecordBatch,
};
26use nautilus_model::{
27 data::{
28 depth::{DEPTH10_LEN, OrderBookDepth10},
29 order::BookOrder,
30 },
31 enums::OrderSide,
32 identifiers::InstrumentId,
33 types::{Price, Quantity, fixed::PRECISION_BYTES},
34};
35
36use super::{
37 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
38 KEY_SIZE_PRECISION, extract_column,
39};
40use crate::arrow::{
41 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
42 get_raw_quantity,
43};
44
45fn get_field_data() -> Vec<(&'static str, DataType)> {
46 vec![
47 ("bid_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
48 ("ask_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
49 ("bid_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
50 ("ask_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
51 ("bid_count", DataType::UInt32),
52 ("ask_count", DataType::UInt32),
53 ]
54}
55
56impl ArrowSchemaProvider for OrderBookDepth10 {
57 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
58 let mut fields = Vec::new();
59 let field_data = get_field_data();
60
61 for (name, data_type) in field_data {
64 for i in 0..DEPTH10_LEN {
65 fields.push(Field::new(format!("{name}_{i}"), data_type.clone(), false));
66 }
67 }
68
69 fields.push(Field::new("flags", DataType::UInt8, false));
70 fields.push(Field::new("sequence", DataType::UInt64, false));
71 fields.push(Field::new("ts_event", DataType::UInt64, false));
72 fields.push(Field::new("ts_init", DataType::UInt64, false));
73
74 match metadata {
75 Some(metadata) => Schema::new_with_metadata(fields, metadata),
76 None => Schema::new(fields),
77 }
78 }
79}
80
81fn parse_metadata(
82 metadata: &HashMap<String, String>,
83) -> Result<(InstrumentId, u8, u8), EncodingError> {
84 let instrument_id_str = metadata
85 .get(KEY_INSTRUMENT_ID)
86 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
87 let instrument_id = InstrumentId::from_str(instrument_id_str)
88 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
89
90 let price_precision = metadata
91 .get(KEY_PRICE_PRECISION)
92 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
93 .parse::<u8>()
94 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
95
96 let size_precision = metadata
97 .get(KEY_SIZE_PRECISION)
98 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
99 .parse::<u8>()
100 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
101
102 Ok((instrument_id, price_precision, size_precision))
103}
104
105impl EncodeToRecordBatch for OrderBookDepth10 {
106 fn encode_batch(
107 metadata: &HashMap<String, String>,
108 data: &[Self],
109 ) -> Result<RecordBatch, ArrowError> {
110 let mut bid_price_builders = Vec::with_capacity(DEPTH10_LEN);
111 let mut ask_price_builders = Vec::with_capacity(DEPTH10_LEN);
112 let mut bid_size_builders = Vec::with_capacity(DEPTH10_LEN);
113 let mut ask_size_builders = Vec::with_capacity(DEPTH10_LEN);
114 let mut bid_count_builders = Vec::with_capacity(DEPTH10_LEN);
115 let mut ask_count_builders = Vec::with_capacity(DEPTH10_LEN);
116
117 for _ in 0..DEPTH10_LEN {
118 bid_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
119 data.len(),
120 PRECISION_BYTES,
121 ));
122 ask_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
123 data.len(),
124 PRECISION_BYTES,
125 ));
126 bid_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
127 data.len(),
128 PRECISION_BYTES,
129 ));
130 ask_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
131 data.len(),
132 PRECISION_BYTES,
133 ));
134 bid_count_builders.push(UInt32Array::builder(data.len()));
135 ask_count_builders.push(UInt32Array::builder(data.len()));
136 }
137
138 let mut flags_builder = UInt8Array::builder(data.len());
139 let mut sequence_builder = UInt64Array::builder(data.len());
140 let mut ts_event_builder = UInt64Array::builder(data.len());
141 let mut ts_init_builder = UInt64Array::builder(data.len());
142
143 for depth in data {
144 for i in 0..DEPTH10_LEN {
145 bid_price_builders[i]
146 .append_value(depth.bids[i].price.raw.to_le_bytes())
147 .unwrap();
148 ask_price_builders[i]
149 .append_value(depth.asks[i].price.raw.to_le_bytes())
150 .unwrap();
151 bid_size_builders[i]
152 .append_value(depth.bids[i].size.raw.to_le_bytes())
153 .unwrap();
154 ask_size_builders[i]
155 .append_value(depth.asks[i].size.raw.to_le_bytes())
156 .unwrap();
157 bid_count_builders[i].append_value(depth.bid_counts[i]);
158 ask_count_builders[i].append_value(depth.ask_counts[i]);
159 }
160
161 flags_builder.append_value(depth.flags);
162 sequence_builder.append_value(depth.sequence);
163 ts_event_builder.append_value(depth.ts_event.as_u64());
164 ts_init_builder.append_value(depth.ts_init.as_u64());
165 }
166
167 let bid_price_arrays = bid_price_builders
168 .into_iter()
169 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
170 .collect::<Vec<_>>();
171 let ask_price_arrays = ask_price_builders
172 .into_iter()
173 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
174 .collect::<Vec<_>>();
175 let bid_size_arrays = bid_size_builders
176 .into_iter()
177 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
178 .collect::<Vec<_>>();
179 let ask_size_arrays = ask_size_builders
180 .into_iter()
181 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
182 .collect::<Vec<_>>();
183 let bid_count_arrays = bid_count_builders
184 .into_iter()
185 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
186 .collect::<Vec<_>>();
187 let ask_count_arrays = ask_count_builders
188 .into_iter()
189 .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
190 .collect::<Vec<_>>();
191
192 let flags_array = Arc::new(flags_builder.finish()) as Arc<dyn Array>;
193 let sequence_array = Arc::new(sequence_builder.finish()) as Arc<dyn Array>;
194 let ts_event_array = Arc::new(ts_event_builder.finish()) as Arc<dyn Array>;
195 let ts_init_array = Arc::new(ts_init_builder.finish()) as Arc<dyn Array>;
196
197 let mut columns = Vec::new();
198 columns.extend(bid_price_arrays);
199 columns.extend(ask_price_arrays);
200 columns.extend(bid_size_arrays);
201 columns.extend(ask_size_arrays);
202 columns.extend(bid_count_arrays);
203 columns.extend(ask_count_arrays);
204 columns.push(flags_array);
205 columns.push(sequence_array);
206 columns.push(ts_event_array);
207 columns.push(ts_init_array);
208
209 RecordBatch::try_new(Self::get_schema(Some(metadata.clone())).into(), columns)
210 }
211
212 fn metadata(&self) -> HashMap<String, String> {
213 OrderBookDepth10::get_metadata(
214 &self.instrument_id,
215 self.bids[0].price.precision,
216 self.bids[0].size.precision,
217 )
218 }
219}
220
impl DecodeFromRecordBatch for OrderBookDepth10 {
    /// Decodes a record batch (as produced by `encode_batch`) back into
    /// `OrderBookDepth10` values, reading columns by positional offset:
    /// six groups of `DEPTH10_LEN` columns, then flags/sequence/ts columns.
    ///
    /// # Errors
    ///
    /// Returns an `EncodingError` if required metadata keys are missing or
    /// unparsable, or if a column is absent or has an unexpected type.
    ///
    /// # Panics
    ///
    /// Panics (via `assert_eq!`) if any fixed-size binary column does not use
    /// exactly `PRECISION_BYTES` bytes per value.
    fn decode_batch(
        metadata: &HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Self>, EncodingError> {
        let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
        let cols = record_batch.columns();

        // Per-level column handles, one entry per depth level.
        let mut bid_prices = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_prices = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_sizes = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_sizes = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_counts = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_counts = Vec::with_capacity(DEPTH10_LEN);

        // NOTE(review): `stringify!($i)` expands to the literal token text
        // (i.e. "i"), not the loop value, so the name passed to
        // `extract_column` is e.g. "bid_price_i" rather than "bid_price_0".
        // Column access itself goes by `$offset`, so this presumably only
        // affects error reporting — confirm against `extract_column`.
        macro_rules! extract_depth_column {
            ($array:ty, $name:literal, $i:expr, $offset:expr, $type:expr) => {
                extract_column::<$array>(cols, concat!($name, "_", stringify!($i)), $offset, $type)?
            };
        }

        // Columns are laid out group-by-group (bid prices, ask prices, bid
        // sizes, ask sizes, bid counts, ask counts), each group DEPTH10_LEN wide.
        for i in 0..DEPTH10_LEN {
            bid_prices.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "bid_price",
                i,
                i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            ask_prices.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "ask_price",
                i,
                DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            bid_sizes.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "bid_size",
                i,
                2 * DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            ask_sizes.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "ask_size",
                i,
                3 * DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            bid_counts.push(extract_depth_column!(
                UInt32Array,
                "bid_count",
                i,
                4 * DEPTH10_LEN + i,
                DataType::UInt32
            ));
            ask_counts.push(extract_depth_column!(
                UInt32Array,
                "ask_count",
                i,
                5 * DEPTH10_LEN + i,
                DataType::UInt32
            ));
        }

        // Sanity-check the fixed-size binary width of every price/size column.
        // NOTE(review): these asserts panic on malformed input instead of
        // returning an `EncodingError` like the rest of this function.
        for i in 0..DEPTH10_LEN {
            assert_eq!(
                bid_prices[i].value_length(),
                PRECISION_BYTES,
                "Price precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                ask_prices[i].value_length(),
                PRECISION_BYTES,
                "Price precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                bid_sizes[i].value_length(),
                PRECISION_BYTES,
                "Size precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                ask_sizes[i].value_length(),
                PRECISION_BYTES,
                "Size precision uses {PRECISION_BYTES} byte value"
            );
        }

        // Trailing scalar columns follow the six per-level groups.
        let flags = extract_column::<UInt8Array>(cols, "flags", 6 * DEPTH10_LEN, DataType::UInt8)?;
        let sequence =
            extract_column::<UInt64Array>(cols, "sequence", 6 * DEPTH10_LEN + 1, DataType::UInt64)?;
        let ts_event =
            extract_column::<UInt64Array>(cols, "ts_event", 6 * DEPTH10_LEN + 2, DataType::UInt64)?;
        let ts_init =
            extract_column::<UInt64Array>(cols, "ts_init", 6 * DEPTH10_LEN + 3, DataType::UInt64)?;

        // Rebuild one depth snapshot per row.
        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
            .map(|row| {
                let mut bids = [BookOrder::default(); DEPTH10_LEN];
                let mut asks = [BookOrder::default(); DEPTH10_LEN];
                let mut bid_count_arr = [0u32; DEPTH10_LEN];
                let mut ask_count_arr = [0u32; DEPTH10_LEN];

                for i in 0..DEPTH10_LEN {
                    // Final `0` arg is not stored in the batch (presumably the
                    // order ID — confirm against `BookOrder::new`).
                    bids[i] = BookOrder::new(
                        OrderSide::Buy,
                        Price::from_raw(get_raw_price(bid_prices[i].value(row)), price_precision),
                        Quantity::from_raw(
                            get_raw_quantity(bid_sizes[i].value(row)),
                            size_precision,
                        ),
                        0, );
                    asks[i] = BookOrder::new(
                        OrderSide::Sell,
                        Price::from_raw(get_raw_price(ask_prices[i].value(row)), price_precision),
                        Quantity::from_raw(
                            get_raw_quantity(ask_sizes[i].value(row)),
                            size_precision,
                        ),
                        0, );
                    bid_count_arr[i] = bid_counts[i].value(row);
                    ask_count_arr[i] = ask_counts[i].value(row);
                }

                Ok(Self {
                    instrument_id,
                    bids,
                    asks,
                    bid_counts: bid_count_arr,
                    ask_counts: ask_count_arr,
                    flags: flags.value(row),
                    sequence: sequence.value(row),
                    ts_event: ts_event.value(row).into(),
                    ts_init: ts_init.value(row).into(),
                })
            })
            .collect();

        result
    }
}
366
367impl DecodeDataFromRecordBatch for OrderBookDepth10 {
368 fn decode_data_batch(
369 metadata: &HashMap<String, String>,
370 record_batch: RecordBatch,
371 ) -> Result<Vec<Data>, EncodingError> {
372 let depths: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
373 Ok(depths.into_iter().map(Data::from).collect())
374 }
375}
376
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field};
    use nautilus_model::{
        data::stubs::stub_depth10,
        types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;

    // Verifies schema field order (per-level groups, then scalar columns)
    // and that the supplied metadata is attached to the schema.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDepth10::get_schema(Some(metadata.clone()));

        // Walk each field group; fields within a group are level-suffixed.
        let mut group_count = 0;
        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema.field(i + group_count * DEPTH10_LEN).clone();
                assert_eq!(
                    field,
                    Field::new(format!("{name}_{i}"), data_type.clone(), false)
                );
            }

            group_count += 1;
        }

        // The four scalar columns follow the six per-level groups.
        let flags_field = schema.field(group_count * DEPTH10_LEN).clone();
        assert_eq!(flags_field, Field::new("flags", DataType::UInt8, false));
        let sequence_field = schema.field(group_count * DEPTH10_LEN + 1).clone();
        assert_eq!(
            sequence_field,
            Field::new("sequence", DataType::UInt64, false)
        );
        let ts_event_field = schema.field(group_count * DEPTH10_LEN + 2).clone();
        assert_eq!(
            ts_event_field,
            Field::new("ts_event", DataType::UInt64, false)
        );
        let ts_init_field = schema.field(group_count * DEPTH10_LEN + 3).clone();
        assert_eq!(
            ts_init_field,
            Field::new("ts_init", DataType::UInt64, false)
        );

        assert_eq!(schema.metadata()["instrument_id"], "AAPL.XNAS");
        assert_eq!(schema.metadata()["price_precision"], "2");
        assert_eq!(schema.metadata()["size_precision"], "0");
    }

    // Verifies the name -> type-string map contains every per-level field
    // plus the scalar columns.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDepth10::get_schema_map();

        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema_map.get(&format!("{name}_{i}")).map(String::as_str);
                assert_eq!(field, Some(format!("{data_type:?}").as_str()));
            }
        }

        assert_eq!(schema_map.get("flags").map(String::as_str), Some("UInt8"));
        assert_eq!(
            schema_map.get("sequence").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_event").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_init").map(String::as_str),
            Some("UInt64")
        );
    }

    // Encodes the stub depth and checks every column group's raw values and
    // the trailing flags/sequence/ts columns, using the same positional
    // offsets as `decode_batch`.
    #[rstest]
    fn test_encode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let price_precision = 2;
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, price_precision, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        // Six per-level groups plus four scalar columns.
        assert_eq!(columns.len(), DEPTH10_LEN * 6 + 4);

        // Group 0: bid prices.
        let bid_prices: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        let expected_bid_prices: Vec<f64> =
            vec![99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0, 92.0, 91.0, 90.0];

        for (i, bid_price) in bid_prices.iter().enumerate() {
            assert_eq!(bid_price.len(), 1);
            assert_eq!(
                get_raw_price(bid_price.value(0)),
                (expected_bid_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(bid_price.value(0)), price_precision).as_f64(),
                expected_bid_prices[i]
            );
        }

        // Group 1: ask prices.
        let ask_prices: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        let expected_ask_prices: Vec<f64> = vec![
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0,
        ];

        for (i, ask_price) in ask_prices.iter().enumerate() {
            assert_eq!(ask_price.len(), 1);
            assert_eq!(
                get_raw_price(ask_price.value(0)),
                (expected_ask_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(ask_price.value(0)), price_precision).as_f64(),
                expected_ask_prices[i]
            );
        }

        // Group 2: bid sizes (stub sizes grow by 100 per level).
        let bid_sizes: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[2 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        for (i, bid_size) in bid_sizes.iter().enumerate() {
            assert_eq!(bid_size.len(), 1);
            assert_eq!(
                get_raw_quantity(bid_size.value(0)),
                ((100.0 * FIXED_SCALAR * (i + 1) as f64) as QuantityRaw)
            );
        }

        // Group 3: ask sizes.
        let ask_sizes: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[3 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        for (i, ask_size) in ask_sizes.iter().enumerate() {
            assert_eq!(ask_size.len(), 1);
            assert_eq!(
                get_raw_quantity(ask_size.value(0)),
                ((100.0 * FIXED_SCALAR * ((i + 1) as f64)) as QuantityRaw)
            );
        }

        // Group 4: bid counts (all 1 in the stub).
        let bid_counts: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[4 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap()
            })
            .collect();

        for count_values in bid_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        // Group 5: ask counts (all 1 in the stub).
        let ask_counts: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[5 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap()
            })
            .collect();

        for count_values in ask_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

        // Trailing scalar columns.
        let flags_values = columns[6 * DEPTH10_LEN]
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();
        let sequence_values = columns[6 * DEPTH10_LEN + 1]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_event_values = columns[6 * DEPTH10_LEN + 2]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_init_values = columns[6 * DEPTH10_LEN + 3]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();

        assert_eq!(flags_values.len(), 1);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(sequence_values.len(), 1);
        assert_eq!(sequence_values.value(0), 0);
        assert_eq!(ts_event_values.len(), 1);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_init_values.len(), 1);
        assert_eq!(ts_init_values.value(0), 2);
    }

    // Round-trip smoke test: encode then decode the stub and check row count.
    #[rstest]
    fn test_decode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let decoded_data = OrderBookDepth10::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 1);
    }
}