1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{BookOrder, OrderBookDelta},
26 enums::{BookAction, FromU8, OrderSide},
27 identifiers::InstrumentId,
28 types::{Price, Quantity, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 KEY_SIZE_PRECISION, extract_column,
34};
35use crate::arrow::{
36 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37 get_raw_quantity,
38};
39
40impl ArrowSchemaProvider for OrderBookDelta {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("action", DataType::UInt8, false),
44 Field::new("side", DataType::UInt8, false),
45 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
46 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
47 Field::new("order_id", DataType::UInt64, false),
48 Field::new("flags", DataType::UInt8, false),
49 Field::new("sequence", DataType::UInt64, false),
50 Field::new("ts_event", DataType::UInt64, false),
51 Field::new("ts_init", DataType::UInt64, false),
52 ];
53
54 match metadata {
55 Some(metadata) => Schema::new_with_metadata(fields, metadata),
56 None => Schema::new(fields),
57 }
58 }
59}
60
61fn parse_metadata(
62 metadata: &HashMap<String, String>,
63) -> Result<(InstrumentId, u8, u8), EncodingError> {
64 let instrument_id_str = metadata
65 .get(KEY_INSTRUMENT_ID)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
67 let instrument_id = InstrumentId::from_str(instrument_id_str)
68 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
69
70 let price_precision = metadata
71 .get(KEY_PRICE_PRECISION)
72 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
73 .parse::<u8>()
74 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
75
76 let size_precision = metadata
77 .get(KEY_SIZE_PRECISION)
78 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
79 .parse::<u8>()
80 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
81
82 Ok((instrument_id, price_precision, size_precision))
83}
84
85impl EncodeToRecordBatch for OrderBookDelta {
86 fn encode_batch(
87 metadata: &HashMap<String, String>,
88 data: &[Self],
89 ) -> Result<RecordBatch, ArrowError> {
90 let mut action_builder = UInt8Array::builder(data.len());
91 let mut side_builder = UInt8Array::builder(data.len());
92 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
94 let mut order_id_builder = UInt64Array::builder(data.len());
95 let mut flags_builder = UInt8Array::builder(data.len());
96 let mut sequence_builder = UInt64Array::builder(data.len());
97 let mut ts_event_builder = UInt64Array::builder(data.len());
98 let mut ts_init_builder = UInt64Array::builder(data.len());
99
100 for delta in data {
101 action_builder.append_value(delta.action as u8);
102 side_builder.append_value(delta.order.side as u8);
103 price_builder
104 .append_value(delta.order.price.raw.to_le_bytes())
105 .unwrap();
106 size_builder
107 .append_value(delta.order.size.raw.to_le_bytes())
108 .unwrap();
109 order_id_builder.append_value(delta.order.order_id);
110 flags_builder.append_value(delta.flags);
111 sequence_builder.append_value(delta.sequence);
112 ts_event_builder.append_value(delta.ts_event.as_u64());
113 ts_init_builder.append_value(delta.ts_init.as_u64());
114 }
115
116 let action_array = action_builder.finish();
117 let side_array = side_builder.finish();
118 let price_array = price_builder.finish();
119 let size_array = size_builder.finish();
120 let order_id_array = order_id_builder.finish();
121 let flags_array = flags_builder.finish();
122 let sequence_array = sequence_builder.finish();
123 let ts_event_array = ts_event_builder.finish();
124 let ts_init_array = ts_init_builder.finish();
125
126 RecordBatch::try_new(
127 Self::get_schema(Some(metadata.clone())).into(),
128 vec![
129 Arc::new(action_array),
130 Arc::new(side_array),
131 Arc::new(price_array),
132 Arc::new(size_array),
133 Arc::new(order_id_array),
134 Arc::new(flags_array),
135 Arc::new(sequence_array),
136 Arc::new(ts_event_array),
137 Arc::new(ts_init_array),
138 ],
139 )
140 }
141
142 fn metadata(&self) -> HashMap<String, String> {
143 OrderBookDelta::get_metadata(
144 &self.instrument_id,
145 self.order.price.precision,
146 self.order.size.precision,
147 )
148 }
149
150 fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
154 let delta = chunk
155 .first()
156 .expect("Chunk should have at least one element to encode");
157
158 if delta.order.price.precision == 0 && delta.order.size.precision == 0 {
159 if let Some(delta) = chunk.get(1) {
160 return EncodeToRecordBatch::metadata(delta);
161 }
162 }
163
164 EncodeToRecordBatch::metadata(delta)
165 }
166}
167
168impl DecodeFromRecordBatch for OrderBookDelta {
169 fn decode_batch(
170 metadata: &HashMap<String, String>,
171 record_batch: RecordBatch,
172 ) -> Result<Vec<Self>, EncodingError> {
173 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
174 let cols = record_batch.columns();
175
176 let action_values = extract_column::<UInt8Array>(cols, "action", 0, DataType::UInt8)?;
177 let side_values = extract_column::<UInt8Array>(cols, "side", 1, DataType::UInt8)?;
178 let price_values = extract_column::<FixedSizeBinaryArray>(
179 cols,
180 "price",
181 2,
182 DataType::FixedSizeBinary(PRECISION_BYTES),
183 )?;
184 let size_values = extract_column::<FixedSizeBinaryArray>(
185 cols,
186 "size",
187 3,
188 DataType::FixedSizeBinary(PRECISION_BYTES),
189 )?;
190 let order_id_values = extract_column::<UInt64Array>(cols, "order_id", 4, DataType::UInt64)?;
191 let flags_values = extract_column::<UInt8Array>(cols, "flags", 5, DataType::UInt8)?;
192 let sequence_values = extract_column::<UInt64Array>(cols, "sequence", 6, DataType::UInt64)?;
193 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
194 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 8, DataType::UInt64)?;
195
196 assert_eq!(
197 price_values.value_length(),
198 PRECISION_BYTES,
199 "Price precision uses {PRECISION_BYTES} byte value"
200 );
201
202 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
203 .map(|i| {
204 let action_value = action_values.value(i);
205 let action = BookAction::from_u8(action_value).ok_or_else(|| {
206 EncodingError::ParseError(
207 stringify!(BookAction),
208 format!("Invalid enum value, was {action_value}"),
209 )
210 })?;
211 let side_value = side_values.value(i);
212 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
213 EncodingError::ParseError(
214 stringify!(OrderSide),
215 format!("Invalid enum value, was {side_value}"),
216 )
217 })?;
218 let price = Price::from_raw(get_raw_price(price_values.value(i)), price_precision);
219 let size =
220 Quantity::from_raw(get_raw_quantity(size_values.value(i)), size_precision);
221 let order_id = order_id_values.value(i);
222 let flags = flags_values.value(i);
223 let sequence = sequence_values.value(i);
224 let ts_event = ts_event_values.value(i).into();
225 let ts_init = ts_init_values.value(i).into();
226
227 Ok(Self {
228 instrument_id,
229 action,
230 order: BookOrder {
231 side,
232 price,
233 size,
234 order_id,
235 },
236 flags,
237 sequence,
238 ts_event,
239 ts_init,
240 })
241 })
242 .collect();
243
244 result
245 }
246}
247
248impl DecodeDataFromRecordBatch for OrderBookDelta {
249 fn decode_data_batch(
250 metadata: &HashMap<String, String>,
251 record_batch: RecordBatch,
252 ) -> Result<Vec<Data>, EncodingError> {
253 let deltas: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
254 Ok(deltas.into_iter().map(Data::from).collect())
255 }
256}
257
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema exposes the nine delta columns and carries the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDelta::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("action", DataType::UInt8, false),
            Field::new("side", DataType::UInt8, false),
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("order_id", DataType::UInt64, false),
            Field::new("flags", DataType::UInt8, false),
            Field::new("sequence", DataType::UInt64, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map reports the Arrow type name of every column.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDelta::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        assert_eq!(schema_map.get("action").unwrap(), "UInt8");
        assert_eq!(schema_map.get("side").unwrap(), "UInt8");
        assert_eq!(*schema_map.get("price").unwrap(), fixed_size_binary.clone());
        assert_eq!(*schema_map.get("size").unwrap(), fixed_size_binary);
        assert_eq!(schema_map.get("order_id").unwrap(), "UInt64");
        assert_eq!(schema_map.get("flags").unwrap(), "UInt8");
        assert_eq!(schema_map.get("sequence").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_event").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_init").unwrap(), "UInt64");
    }

    /// Encoding two deltas yields nine columns holding the expected values.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let delta1 = OrderBookDelta {
            instrument_id,
            action: BookAction::Add,
            order: BookOrder {
                side: OrderSide::Buy,
                price: Price::from("100.10"),
                size: Quantity::from(100),
                order_id: 1,
            },
            flags: 0,
            sequence: 1,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let delta2 = OrderBookDelta {
            instrument_id,
            action: BookAction::Update,
            order: BookOrder {
                side: OrderSide::Sell,
                price: Price::from("101.20"),
                size: Quantity::from(200),
                order_id: 2,
            },
            flags: 1,
            sequence: 2,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![delta1, delta2];
        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let action_values = columns[0].as_any().downcast_ref::<UInt8Array>().unwrap();
        let side_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let price_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let order_id_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let flags_values = columns[5].as_any().downcast_ref::<UInt8Array>().unwrap();
        let sequence_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_event_values = columns[7].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[8].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 9);
        assert_eq!(action_values.len(), 2);
        assert_eq!(action_values.value(0), 1);
        assert_eq!(action_values.value(1), 2);
        assert_eq!(side_values.len(), 2);
        assert_eq!(side_values.value(0), 1);
        assert_eq!(side_values.value(1), 2);

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (101.20 * FIXED_SCALAR) as PriceRaw
        );

        // Size bytes are decoded with the quantity helper (not the price helper) and
        // compared against independently constructed quantities.
        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(size_values.value(0)),
            Quantity::from(100).raw
        );
        assert_eq!(
            get_raw_quantity(size_values.value(1)),
            Quantity::from(200).raw
        );
        assert_eq!(order_id_values.len(), 2);
        assert_eq!(order_id_values.value(0), 1);
        assert_eq!(order_id_values.value(1), 2);
        assert_eq!(flags_values.len(), 2);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(flags_values.value(1), 1);
        assert_eq!(sequence_values.len(), 2);
        assert_eq!(sequence_values.value(0), 1);
        assert_eq!(sequence_values.value(1), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch restores the per-row fields and applies the
    /// precisions taken from the metadata.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let action = UInt8Array::from(vec![1, 2]);
        let side = UInt8Array::from(vec![1, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &((101.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((101.20 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &((10000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((9000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![1, 2]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        let first = &decoded_data[0];
        assert_eq!(first.action as u8, 1);
        assert_eq!(first.order.side as u8, 1);
        assert_eq!(first.order.price.precision, 2);
        assert_eq!(first.order.size.precision, 0);
        assert_eq!(first.order.order_id, 1);
        assert_eq!(first.flags, 0);
        assert_eq!(first.sequence, 1);
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded_data[1];
        assert_eq!(second.action as u8, 2);
        assert_eq!(second.order.side as u8, 1);
        assert_eq!(second.order.price.precision, 2);
        assert_eq!(second.order.size.precision, 0);
        assert_eq!(second.order.order_id, 2);
        assert_eq!(second.flags, 0);
        assert_eq!(second.sequence, 2);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }
}