use std::{collections::HashMap, str::FromStr, sync::Arc};

use arrow::{
    array::{
        Array, FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt32Array, UInt64Array,
    },
    datatypes::{DataType, Field, Schema},
    error::ArrowError,
    record_batch::RecordBatch,
};
use nautilus_model::{
    data::{
        depth::{DEPTH10_LEN, OrderBookDepth10},
        order::BookOrder,
    },
    enums::OrderSide,
    identifiers::InstrumentId,
    types::{Price, Quantity, fixed::PRECISION_BYTES},
};

use super::{
    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
    KEY_SIZE_PRECISION, extract_column,
};
use crate::arrow::{
    ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
    get_raw_quantity,
};

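/// Returns the repeated column groups (field-name prefix and Arrow data type) that are expanded
/// into one column per book level in the `OrderBookDepth10` schema.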
fn get_field_data() -> Vec<(&'static str, DataType)> {
    vec![
        ("bid_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
        ("ask_price", DataType::FixedSizeBinary(PRECISION_BYTES)),
        ("bid_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
        ("ask_size", DataType::FixedSizeBinary(PRECISION_BYTES)),
        ("bid_count", DataType::UInt32),
        ("ask_count", DataType::UInt32),
    ]
}

impl ArrowSchemaProvider for OrderBookDepth10 {
    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
        let mut fields = Vec::new();
        let field_data = get_field_data();

        for (name, data_type) in field_data {
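            // Expand each group into one column per level, e.g. "bid_price_0" .. "bid_price_9".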
            for i in 0..DEPTH10_LEN {
                fields.push(Field::new(format!("{name}_{i}"), data_type.clone(), false));
            }
        }

        fields.push(Field::new("flags", DataType::UInt8, false));
        fields.push(Field::new("sequence", DataType::UInt64, false));
        fields.push(Field::new("ts_event", DataType::UInt64, false));
        fields.push(Field::new("ts_init", DataType::UInt64, false));

        match metadata {
            Some(metadata) => Schema::new_with_metadata(fields, metadata),
            None => Schema::new(fields),
        }
    }
}

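/// Parses the instrument ID, price precision, and size precision from the batch metadata.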
fn parse_metadata(
    metadata: &HashMap<String, String>,
) -> Result<(InstrumentId, u8, u8), EncodingError> {
    let instrument_id_str = metadata
        .get(KEY_INSTRUMENT_ID)
        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
    let instrument_id = InstrumentId::from_str(instrument_id_str)
        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;

    let price_precision = metadata
        .get(KEY_PRICE_PRECISION)
        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
        .parse::<u8>()
        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;

    let size_precision = metadata
        .get(KEY_SIZE_PRECISION)
        .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
        .parse::<u8>()
        .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;

    Ok((instrument_id, price_precision, size_precision))
}

impl EncodeToRecordBatch for OrderBookDepth10 {
    fn encode_batch(
        metadata: &HashMap<String, String>,
        data: &[Self],
    ) -> Result<RecordBatch, ArrowError> {
        let mut bid_price_builders = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_price_builders = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_size_builders = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_size_builders = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_count_builders = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_count_builders = Vec::with_capacity(DEPTH10_LEN);

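        // One builder per book level for each repeated column group.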
        for _ in 0..DEPTH10_LEN {
            bid_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
                data.len(),
                PRECISION_BYTES,
            ));
            ask_price_builders.push(FixedSizeBinaryBuilder::with_capacity(
                data.len(),
                PRECISION_BYTES,
            ));
            bid_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
                data.len(),
                PRECISION_BYTES,
            ));
            ask_size_builders.push(FixedSizeBinaryBuilder::with_capacity(
                data.len(),
                PRECISION_BYTES,
            ));
            bid_count_builders.push(UInt32Array::builder(data.len()));
            ask_count_builders.push(UInt32Array::builder(data.len()));
        }

        let mut flags_builder = UInt8Array::builder(data.len());
        let mut sequence_builder = UInt64Array::builder(data.len());
        let mut ts_event_builder = UInt64Array::builder(data.len());
        let mut ts_init_builder = UInt64Array::builder(data.len());

        for depth in data {
            for i in 0..DEPTH10_LEN {
                bid_price_builders[i]
                    .append_value(depth.bids[i].price.raw.to_le_bytes())
                    .unwrap();
                ask_price_builders[i]
                    .append_value(depth.asks[i].price.raw.to_le_bytes())
                    .unwrap();
                bid_size_builders[i]
                    .append_value(depth.bids[i].size.raw.to_le_bytes())
                    .unwrap();
                ask_size_builders[i]
                    .append_value(depth.asks[i].size.raw.to_le_bytes())
                    .unwrap();
                bid_count_builders[i].append_value(depth.bid_counts[i]);
                ask_count_builders[i].append_value(depth.ask_counts[i]);
            }

            flags_builder.append_value(depth.flags);
            sequence_builder.append_value(depth.sequence);
            ts_event_builder.append_value(depth.ts_event.as_u64());
            ts_init_builder.append_value(depth.ts_init.as_u64());
        }

        let bid_price_arrays = bid_price_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();
        let ask_price_arrays = ask_price_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();
        let bid_size_arrays = bid_size_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();
        let ask_size_arrays = ask_size_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();
        let bid_count_arrays = bid_count_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();
        let ask_count_arrays = ask_count_builders
            .into_iter()
            .map(|mut b| Arc::new(b.finish()) as Arc<dyn Array>)
            .collect::<Vec<_>>();

        let flags_array = Arc::new(flags_builder.finish()) as Arc<dyn Array>;
        let sequence_array = Arc::new(sequence_builder.finish()) as Arc<dyn Array>;
        let ts_event_array = Arc::new(ts_event_builder.finish()) as Arc<dyn Array>;
        let ts_init_array = Arc::new(ts_init_builder.finish()) as Arc<dyn Array>;

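        // Assemble columns in schema order: the six per-level groups, then the scalar fields.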
        let mut columns = Vec::new();
        columns.extend(bid_price_arrays);
        columns.extend(ask_price_arrays);
        columns.extend(bid_size_arrays);
        columns.extend(ask_size_arrays);
        columns.extend(bid_count_arrays);
        columns.extend(ask_count_arrays);
        columns.push(flags_array);
        columns.push(sequence_array);
        columns.push(ts_event_array);
        columns.push(ts_init_array);

        RecordBatch::try_new(Self::get_schema(Some(metadata.clone())).into(), columns)
    }

    fn metadata(&self) -> HashMap<String, String> {
        OrderBookDepth10::get_metadata(
            &self.instrument_id,
            self.bids[0].price.precision,
            self.bids[0].size.precision,
        )
    }
}

impl DecodeFromRecordBatch for OrderBookDepth10 {
    fn decode_batch(
        metadata: &HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Self>, EncodingError> {
        let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
        let cols = record_batch.columns();

        let mut bid_prices = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_prices = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_sizes = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_sizes = Vec::with_capacity(DEPTH10_LEN);
        let mut bid_counts = Vec::with_capacity(DEPTH10_LEN);
        let mut ask_counts = Vec::with_capacity(DEPTH10_LEN);

        macro_rules! extract_depth_column {
            ($array:ty, $name:literal, $i:expr, $offset:expr, $type:expr) => {
                extract_column::<$array>(cols, concat!($name, "_", stringify!($i)), $offset, $type)?
            };
        }

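        // Columns are laid out group by group, DEPTH10_LEN columns per group, in schema order:
        // bid_price, ask_price, bid_size, ask_size, bid_count, ask_count.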
        for i in 0..DEPTH10_LEN {
            bid_prices.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "bid_price",
                i,
                i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            ask_prices.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "ask_price",
                i,
                DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            bid_sizes.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "bid_size",
                i,
                2 * DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            ask_sizes.push(extract_depth_column!(
                FixedSizeBinaryArray,
                "ask_size",
                i,
                3 * DEPTH10_LEN + i,
                DataType::FixedSizeBinary(PRECISION_BYTES)
            ));
            bid_counts.push(extract_depth_column!(
                UInt32Array,
                "bid_count",
                i,
                4 * DEPTH10_LEN + i,
                DataType::UInt32
            ));
            ask_counts.push(extract_depth_column!(
                UInt32Array,
                "ask_count",
                i,
                5 * DEPTH10_LEN + i,
                DataType::UInt32
            ));
        }

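        // Each fixed-size binary column must match the expected byte width before decoding.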
        for i in 0..DEPTH10_LEN {
            assert_eq!(
                bid_prices[i].value_length(),
                PRECISION_BYTES,
                "Price precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                ask_prices[i].value_length(),
                PRECISION_BYTES,
                "Price precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                bid_sizes[i].value_length(),
                PRECISION_BYTES,
                "Size precision uses {PRECISION_BYTES} byte value"
            );
            assert_eq!(
                ask_sizes[i].value_length(),
                PRECISION_BYTES,
                "Size precision uses {PRECISION_BYTES} byte value"
            );
        }

        let flags = extract_column::<UInt8Array>(cols, "flags", 6 * DEPTH10_LEN, DataType::UInt8)?;
        let sequence =
            extract_column::<UInt64Array>(cols, "sequence", 6 * DEPTH10_LEN + 1, DataType::UInt64)?;
        let ts_event =
            extract_column::<UInt64Array>(cols, "ts_event", 6 * DEPTH10_LEN + 2, DataType::UInt64)?;
        let ts_init =
            extract_column::<UInt64Array>(cols, "ts_init", 6 * DEPTH10_LEN + 3, DataType::UInt64)?;

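        // Reassemble one `OrderBookDepth10` per row from the per-level columns.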
        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
            .map(|row| {
                let mut bids = [BookOrder::default(); DEPTH10_LEN];
                let mut asks = [BookOrder::default(); DEPTH10_LEN];
                let mut bid_count_arr = [0u32; DEPTH10_LEN];
                let mut ask_count_arr = [0u32; DEPTH10_LEN];

                for i in 0..DEPTH10_LEN {
                    bids[i] = BookOrder::new(
                        OrderSide::Buy,
                        Price::from_raw(get_raw_price(bid_prices[i].value(row)), price_precision),
                        Quantity::from_raw(
                            get_raw_quantity(bid_sizes[i].value(row)),
                            size_precision,
                        ),
                        0,
                    );
                    asks[i] = BookOrder::new(
                        OrderSide::Sell,
                        Price::from_raw(get_raw_price(ask_prices[i].value(row)), price_precision),
                        Quantity::from_raw(
                            get_raw_quantity(ask_sizes[i].value(row)),
                            size_precision,
                        ),
                        0,
                    );
                    bid_count_arr[i] = bid_counts[i].value(row);
                    ask_count_arr[i] = ask_counts[i].value(row);
                }

                Ok(Self {
                    instrument_id,
                    bids,
                    asks,
                    bid_counts: bid_count_arr,
                    ask_counts: ask_count_arr,
                    flags: flags.value(row),
                    sequence: sequence.value(row),
                    ts_event: ts_event.value(row).into(),
                    ts_init: ts_init.value(row).into(),
                })
            })
            .collect();

        result
    }
}

impl DecodeDataFromRecordBatch for OrderBookDepth10 {
    fn decode_data_batch(
        metadata: &HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Data>, EncodingError> {
        let depths: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
        Ok(depths.into_iter().map(Data::from).collect())
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field};
    use nautilus_model::{
        data::stubs::stub_depth10,
        types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDepth10::get_schema(Some(metadata.clone()));

        let mut group_count = 0;
        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema.field(i + group_count * DEPTH10_LEN).clone();
                assert_eq!(
                    field,
                    Field::new(format!("{name}_{i}"), data_type.clone(), false)
                );
            }

            group_count += 1;
        }

        let flags_field = schema.field(group_count * DEPTH10_LEN).clone();
        assert_eq!(flags_field, Field::new("flags", DataType::UInt8, false));
        let sequence_field = schema.field(group_count * DEPTH10_LEN + 1).clone();
        assert_eq!(
            sequence_field,
            Field::new("sequence", DataType::UInt64, false)
        );
        let ts_event_field = schema.field(group_count * DEPTH10_LEN + 2).clone();
        assert_eq!(
            ts_event_field,
            Field::new("ts_event", DataType::UInt64, false)
        );
        let ts_init_field = schema.field(group_count * DEPTH10_LEN + 3).clone();
        assert_eq!(
            ts_init_field,
            Field::new("ts_init", DataType::UInt64, false)
        );

        assert_eq!(schema.metadata()["instrument_id"], "AAPL.XNAS");
        assert_eq!(schema.metadata()["price_precision"], "2");
        assert_eq!(schema.metadata()["size_precision"], "0");
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDepth10::get_schema_map();

        let field_data = get_field_data();
        for (name, data_type) in field_data {
            for i in 0..DEPTH10_LEN {
                let field = schema_map.get(&format!("{name}_{i}")).map(String::as_str);
                assert_eq!(field, Some(format!("{data_type:?}").as_str()));
            }
        }

        assert_eq!(schema_map.get("flags").map(String::as_str), Some("UInt8"));
        assert_eq!(
            schema_map.get("sequence").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_event").map(String::as_str),
            Some("UInt64")
        );
        assert_eq!(
            schema_map.get("ts_init").map(String::as_str),
            Some("UInt64")
        );
    }

    #[rstest]
    fn test_encode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let price_precision = 2;
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, price_precision, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        assert_eq!(columns.len(), DEPTH10_LEN * 6 + 4);

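        // Verify bid price columns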
        let bid_prices: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        let expected_bid_prices: Vec<f64> =
            vec![99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0, 92.0, 91.0, 90.0];

        for (i, bid_price) in bid_prices.iter().enumerate() {
            assert_eq!(bid_price.len(), 1);
            assert_eq!(
                get_raw_price(bid_price.value(0)),
                (expected_bid_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(bid_price.value(0)), price_precision).as_f64(),
                expected_bid_prices[i]
            );
        }

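        // Verify ask price columns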
        let ask_prices: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        let expected_ask_prices: Vec<f64> = vec![
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0,
        ];

        for (i, ask_price) in ask_prices.iter().enumerate() {
            assert_eq!(ask_price.len(), 1);
            assert_eq!(
                get_raw_price(ask_price.value(0)),
                (expected_ask_prices[i] * FIXED_SCALAR) as PriceRaw
            );
            assert_eq!(
                Price::from_raw(get_raw_price(ask_price.value(0)), price_precision).as_f64(),
                expected_ask_prices[i]
            );
        }

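        // Verify bid size columns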
        let bid_sizes: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[2 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        for (i, bid_size) in bid_sizes.iter().enumerate() {
            assert_eq!(bid_size.len(), 1);
            assert_eq!(
                get_raw_quantity(bid_size.value(0)),
                ((100.0 * FIXED_SCALAR * (i + 1) as f64) as QuantityRaw)
            );
        }

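        // Verify ask size columns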
        let ask_sizes: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[3 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<FixedSizeBinaryArray>()
                    .unwrap()
            })
            .collect();

        for (i, ask_size) in ask_sizes.iter().enumerate() {
            assert_eq!(ask_size.len(), 1);
            assert_eq!(
                get_raw_quantity(ask_size.value(0)),
                ((100.0 * FIXED_SCALAR * ((i + 1) as f64)) as QuantityRaw)
            );
        }

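        // Verify bid count columns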
        let bid_counts: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[4 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap()
            })
            .collect();

        for count_values in bid_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

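        // Verify ask count columns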
        let ask_counts: Vec<_> = (0..DEPTH10_LEN)
            .map(|i| {
                columns[5 * DEPTH10_LEN + i]
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap()
            })
            .collect();

        for count_values in ask_counts {
            assert_eq!(count_values.len(), 1);
            assert_eq!(count_values.value(0), 1);
        }

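        // Verify the scalar columns: flags, sequence, ts_event, ts_init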
        let flags_values = columns[6 * DEPTH10_LEN]
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();
        let sequence_values = columns[6 * DEPTH10_LEN + 1]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_event_values = columns[6 * DEPTH10_LEN + 2]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        let ts_init_values = columns[6 * DEPTH10_LEN + 3]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();

        assert_eq!(flags_values.len(), 1);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(sequence_values.len(), 1);
        assert_eq!(sequence_values.value(0), 0);
        assert_eq!(ts_event_values.len(), 1);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_init_values.len(), 1);
        assert_eq!(ts_init_values.value(0), 2);
    }

    #[rstest]
    fn test_decode_batch(stub_depth10: OrderBookDepth10) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);

        let data = vec![stub_depth10];
        let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
        let decoded_data = OrderBookDepth10::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 1);
    }
}