1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{BookOrder, OrderBookDelta},
26 enums::{BookAction, FromU8, OrderSide},
27 identifiers::InstrumentId,
28 types::{Price, Quantity, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 KEY_SIZE_PRECISION, extract_column,
34};
35use crate::arrow::{
36 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37 get_raw_quantity,
38};
39
40impl ArrowSchemaProvider for OrderBookDelta {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("action", DataType::UInt8, false),
44 Field::new("side", DataType::UInt8, false),
45 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
46 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
47 Field::new("order_id", DataType::UInt64, false),
48 Field::new("flags", DataType::UInt8, false),
49 Field::new("sequence", DataType::UInt64, false),
50 Field::new("ts_event", DataType::UInt64, false),
51 Field::new("ts_init", DataType::UInt64, false),
52 ];
53
54 match metadata {
55 Some(metadata) => Schema::new_with_metadata(fields, metadata),
56 None => Schema::new(fields),
57 }
58 }
59}
60
61fn parse_metadata(
62 metadata: &HashMap<String, String>,
63) -> Result<(InstrumentId, u8, u8), EncodingError> {
64 let instrument_id_str = metadata
65 .get(KEY_INSTRUMENT_ID)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
67 let instrument_id = InstrumentId::from_str(instrument_id_str)
68 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
69
70 let price_precision = metadata
71 .get(KEY_PRICE_PRECISION)
72 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
73 .parse::<u8>()
74 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
75
76 let size_precision = metadata
77 .get(KEY_SIZE_PRECISION)
78 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
79 .parse::<u8>()
80 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
81
82 Ok((instrument_id, price_precision, size_precision))
83}
84
85impl EncodeToRecordBatch for OrderBookDelta {
86 fn encode_batch(
87 metadata: &HashMap<String, String>,
88 data: &[Self],
89 ) -> Result<RecordBatch, ArrowError> {
90 let mut action_builder = UInt8Array::builder(data.len());
91 let mut side_builder = UInt8Array::builder(data.len());
92 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
94 let mut order_id_builder = UInt64Array::builder(data.len());
95 let mut flags_builder = UInt8Array::builder(data.len());
96 let mut sequence_builder = UInt64Array::builder(data.len());
97 let mut ts_event_builder = UInt64Array::builder(data.len());
98 let mut ts_init_builder = UInt64Array::builder(data.len());
99
100 for delta in data {
101 action_builder.append_value(delta.action as u8);
102 side_builder.append_value(delta.order.side as u8);
103 price_builder
104 .append_value(delta.order.price.raw.to_le_bytes())
105 .unwrap();
106 size_builder
107 .append_value(delta.order.size.raw.to_le_bytes())
108 .unwrap();
109 order_id_builder.append_value(delta.order.order_id);
110 flags_builder.append_value(delta.flags);
111 sequence_builder.append_value(delta.sequence);
112 ts_event_builder.append_value(delta.ts_event.as_u64());
113 ts_init_builder.append_value(delta.ts_init.as_u64());
114 }
115
116 let action_array = action_builder.finish();
117 let side_array = side_builder.finish();
118 let price_array = price_builder.finish();
119 let size_array = size_builder.finish();
120 let order_id_array = order_id_builder.finish();
121 let flags_array = flags_builder.finish();
122 let sequence_array = sequence_builder.finish();
123 let ts_event_array = ts_event_builder.finish();
124 let ts_init_array = ts_init_builder.finish();
125
126 RecordBatch::try_new(
127 Self::get_schema(Some(metadata.clone())).into(),
128 vec![
129 Arc::new(action_array),
130 Arc::new(side_array),
131 Arc::new(price_array),
132 Arc::new(size_array),
133 Arc::new(order_id_array),
134 Arc::new(flags_array),
135 Arc::new(sequence_array),
136 Arc::new(ts_event_array),
137 Arc::new(ts_init_array),
138 ],
139 )
140 }
141
142 fn metadata(&self) -> HashMap<String, String> {
143 OrderBookDelta::get_metadata(
144 &self.instrument_id,
145 self.order.price.precision,
146 self.order.size.precision,
147 )
148 }
149
150 fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
154 let delta = chunk
155 .first()
156 .expect("Chunk should have at least one element to encode");
157
158 if delta.order.price.precision == 0
159 && delta.order.size.precision == 0
160 && let Some(delta) = chunk.get(1)
161 {
162 return EncodeToRecordBatch::metadata(delta);
163 }
164
165 EncodeToRecordBatch::metadata(delta)
166 }
167}
168
169impl DecodeFromRecordBatch for OrderBookDelta {
170 fn decode_batch(
171 metadata: &HashMap<String, String>,
172 record_batch: RecordBatch,
173 ) -> Result<Vec<Self>, EncodingError> {
174 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
175 let cols = record_batch.columns();
176
177 let action_values = extract_column::<UInt8Array>(cols, "action", 0, DataType::UInt8)?;
178 let side_values = extract_column::<UInt8Array>(cols, "side", 1, DataType::UInt8)?;
179 let price_values = extract_column::<FixedSizeBinaryArray>(
180 cols,
181 "price",
182 2,
183 DataType::FixedSizeBinary(PRECISION_BYTES),
184 )?;
185 let size_values = extract_column::<FixedSizeBinaryArray>(
186 cols,
187 "size",
188 3,
189 DataType::FixedSizeBinary(PRECISION_BYTES),
190 )?;
191 let order_id_values = extract_column::<UInt64Array>(cols, "order_id", 4, DataType::UInt64)?;
192 let flags_values = extract_column::<UInt8Array>(cols, "flags", 5, DataType::UInt8)?;
193 let sequence_values = extract_column::<UInt64Array>(cols, "sequence", 6, DataType::UInt64)?;
194 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
195 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 8, DataType::UInt64)?;
196
197 if price_values.value_length() != PRECISION_BYTES {
198 return Err(EncodingError::ParseError(
199 "price",
200 format!(
201 "Invalid value length: expected {PRECISION_BYTES}, found {}",
202 price_values.value_length()
203 ),
204 ));
205 }
206
207 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
208 .map(|i| {
209 let action_value = action_values.value(i);
210 let action = BookAction::from_u8(action_value).ok_or_else(|| {
211 EncodingError::ParseError(
212 stringify!(BookAction),
213 format!("Invalid enum value, was {action_value}"),
214 )
215 })?;
216 let side_value = side_values.value(i);
217 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
218 EncodingError::ParseError(
219 stringify!(OrderSide),
220 format!("Invalid enum value, was {side_value}"),
221 )
222 })?;
223 let price = Price::from_raw(get_raw_price(price_values.value(i)), price_precision);
224 let size =
225 Quantity::from_raw(get_raw_quantity(size_values.value(i)), size_precision);
226 let order_id = order_id_values.value(i);
227 let flags = flags_values.value(i);
228 let sequence = sequence_values.value(i);
229 let ts_event = ts_event_values.value(i).into();
230 let ts_init = ts_init_values.value(i).into();
231
232 Ok(Self {
233 instrument_id,
234 action,
235 order: BookOrder {
236 side,
237 price,
238 size,
239 order_id,
240 },
241 flags,
242 sequence,
243 ts_event,
244 ts_init,
245 })
246 })
247 .collect();
248
249 result
250 }
251}
252
253impl DecodeDataFromRecordBatch for OrderBookDelta {
254 fn decode_data_batch(
255 metadata: &HashMap<String, String>,
256 record_batch: RecordBatch,
257 ) -> Result<Vec<Data>, EncodingError> {
258 let deltas: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
259 Ok(deltas.into_iter().map(Data::from).collect())
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema must expose the nine delta columns and carry the given metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDelta::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("action", DataType::UInt8, false),
            Field::new("side", DataType::UInt8, false),
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("order_id", DataType::UInt64, false),
            Field::new("flags", DataType::UInt8, false),
            Field::new("sequence", DataType::UInt64, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must report the Arrow type name of every column.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDelta::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        assert_eq!(schema_map.get("action").unwrap(), "UInt8");
        assert_eq!(schema_map.get("side").unwrap(), "UInt8");
        assert_eq!(*schema_map.get("price").unwrap(), fixed_size_binary.clone());
        assert_eq!(*schema_map.get("size").unwrap(), fixed_size_binary);
        assert_eq!(schema_map.get("order_id").unwrap(), "UInt64");
        assert_eq!(schema_map.get("flags").unwrap(), "UInt8");
        assert_eq!(schema_map.get("sequence").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_event").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_init").unwrap(), "UInt64");
    }

    /// Parsing metadata that lacks the required keys must report a missing-metadata error.
    #[rstest]
    fn test_parse_metadata_missing_instrument_id() {
        let metadata = HashMap::new();
        let result = parse_metadata(&metadata);

        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }

    /// Encoding two deltas must produce nine columns with one row per delta.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let delta1 = OrderBookDelta {
            instrument_id,
            action: BookAction::Add,
            order: BookOrder {
                side: OrderSide::Buy,
                price: Price::from("100.10"),
                size: Quantity::from(100),
                order_id: 1,
            },
            flags: 0,
            sequence: 1,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let delta2 = OrderBookDelta {
            instrument_id,
            action: BookAction::Update,
            order: BookOrder {
                side: OrderSide::Sell,
                price: Price::from("101.20"),
                size: Quantity::from(200),
                order_id: 2,
            },
            flags: 1,
            sequence: 2,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![delta1, delta2];
        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing so a short batch fails with a
        // clear assertion rather than an out-of-bounds panic.
        assert_eq!(columns.len(), 9);

        let action_values = columns[0].as_any().downcast_ref::<UInt8Array>().unwrap();
        let side_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let price_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let order_id_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let flags_values = columns[5].as_any().downcast_ref::<UInt8Array>().unwrap();
        let sequence_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_event_values = columns[7].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[8].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(action_values.len(), 2);
        assert_eq!(action_values.value(0), 1);
        assert_eq!(action_values.value(1), 2);
        assert_eq!(side_values.len(), 2);
        assert_eq!(side_values.value(0), 1);
        assert_eq!(side_values.value(1), 2);

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (101.20 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_price(size_values.value(0)),
            (100.0 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(size_values.value(1)),
            (200.0 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(order_id_values.len(), 2);
        assert_eq!(order_id_values.value(0), 1);
        assert_eq!(order_id_values.value(1), 2);
        assert_eq!(flags_values.len(), 2);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(flags_values.value(1), 1);
        assert_eq!(sequence_values.len(), 2);
        assert_eq!(sequence_values.value(0), 1);
        assert_eq!(sequence_values.value(1), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch must reproduce every field of each row, not
    /// just the row count.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let action = UInt8Array::from(vec![1, 2]);
        let side = UInt8Array::from(vec![1, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &((101.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((101.20 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &((10000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((9000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![1, 2]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        let first = &decoded_data[0];
        assert_eq!(first.instrument_id, instrument_id);
        assert_eq!(first.action as u8, BookAction::Add as u8);
        assert_eq!(first.order.side as u8, OrderSide::Buy as u8);
        assert_eq!(first.order.price.raw, (101.10 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(first.order.price.precision, 2);
        assert_eq!(first.order.size.raw, Quantity::from(10_000).raw);
        assert_eq!(first.order.size.precision, 0);
        assert_eq!(first.order.order_id, 1);
        assert_eq!(first.flags, 0);
        assert_eq!(first.sequence, 1);
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded_data[1];
        assert_eq!(second.instrument_id, instrument_id);
        assert_eq!(second.action as u8, BookAction::Update as u8);
        assert_eq!(second.order.side as u8, OrderSide::Buy as u8);
        assert_eq!(second.order.price.raw, (101.20 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(second.order.price.precision, 2);
        assert_eq!(second.order.size.raw, Quantity::from(9_000).raw);
        assert_eq!(second.order.size.precision, 0);
        assert_eq!(second.order.order_id, 2);
        assert_eq!(second.flags, 0);
        assert_eq!(second.sequence, 2);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }
}