1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{BookOrder, OrderBookDelta},
26 enums::{BookAction, FromU8, OrderSide},
27 identifiers::InstrumentId,
28 types::{Price, Quantity, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 KEY_SIZE_PRECISION, extract_column,
34};
35use crate::arrow::{
36 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37 get_raw_quantity,
38};
39
40impl ArrowSchemaProvider for OrderBookDelta {
41 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
42 let fields = vec![
43 Field::new("action", DataType::UInt8, false),
44 Field::new("side", DataType::UInt8, false),
45 Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
46 Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
47 Field::new("order_id", DataType::UInt64, false),
48 Field::new("flags", DataType::UInt8, false),
49 Field::new("sequence", DataType::UInt64, false),
50 Field::new("ts_event", DataType::UInt64, false),
51 Field::new("ts_init", DataType::UInt64, false),
52 ];
53
54 match metadata {
55 Some(metadata) => Schema::new_with_metadata(fields, metadata),
56 None => Schema::new(fields),
57 }
58 }
59}
60
61fn parse_metadata(
62 metadata: &HashMap<String, String>,
63) -> Result<(InstrumentId, u8, u8), EncodingError> {
64 let instrument_id_str = metadata
65 .get(KEY_INSTRUMENT_ID)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
67 let instrument_id = InstrumentId::from_str(instrument_id_str)
68 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
69
70 let price_precision = metadata
71 .get(KEY_PRICE_PRECISION)
72 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
73 .parse::<u8>()
74 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
75
76 let size_precision = metadata
77 .get(KEY_SIZE_PRECISION)
78 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
79 .parse::<u8>()
80 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
81
82 Ok((instrument_id, price_precision, size_precision))
83}
84
85impl EncodeToRecordBatch for OrderBookDelta {
86 fn encode_batch(
87 metadata: &HashMap<String, String>,
88 data: &[Self],
89 ) -> Result<RecordBatch, ArrowError> {
90 let mut action_builder = UInt8Array::builder(data.len());
91 let mut side_builder = UInt8Array::builder(data.len());
92 let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut size_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
94 let mut order_id_builder = UInt64Array::builder(data.len());
95 let mut flags_builder = UInt8Array::builder(data.len());
96 let mut sequence_builder = UInt64Array::builder(data.len());
97 let mut ts_event_builder = UInt64Array::builder(data.len());
98 let mut ts_init_builder = UInt64Array::builder(data.len());
99
100 for delta in data {
101 action_builder.append_value(delta.action as u8);
102 side_builder.append_value(delta.order.side as u8);
103 price_builder
104 .append_value(delta.order.price.raw.to_le_bytes())
105 .unwrap();
106 size_builder
107 .append_value(delta.order.size.raw.to_le_bytes())
108 .unwrap();
109 order_id_builder.append_value(delta.order.order_id);
110 flags_builder.append_value(delta.flags);
111 sequence_builder.append_value(delta.sequence);
112 ts_event_builder.append_value(delta.ts_event.as_u64());
113 ts_init_builder.append_value(delta.ts_init.as_u64());
114 }
115
116 let action_array = action_builder.finish();
117 let side_array = side_builder.finish();
118 let price_array = price_builder.finish();
119 let size_array = size_builder.finish();
120 let order_id_array = order_id_builder.finish();
121 let flags_array = flags_builder.finish();
122 let sequence_array = sequence_builder.finish();
123 let ts_event_array = ts_event_builder.finish();
124 let ts_init_array = ts_init_builder.finish();
125
126 RecordBatch::try_new(
127 Self::get_schema(Some(metadata.clone())).into(),
128 vec![
129 Arc::new(action_array),
130 Arc::new(side_array),
131 Arc::new(price_array),
132 Arc::new(size_array),
133 Arc::new(order_id_array),
134 Arc::new(flags_array),
135 Arc::new(sequence_array),
136 Arc::new(ts_event_array),
137 Arc::new(ts_init_array),
138 ],
139 )
140 }
141
142 fn metadata(&self) -> HashMap<String, String> {
143 OrderBookDelta::get_metadata(
144 &self.instrument_id,
145 self.order.price.precision,
146 self.order.size.precision,
147 )
148 }
149
150 fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
154 let delta = chunk
155 .first()
156 .expect("Chunk should have at least one element to encode");
157
158 if delta.order.price.precision == 0
159 && delta.order.size.precision == 0
160 && let Some(delta) = chunk.get(1)
161 {
162 return EncodeToRecordBatch::metadata(delta);
163 }
164
165 EncodeToRecordBatch::metadata(delta)
166 }
167}
168
169impl DecodeFromRecordBatch for OrderBookDelta {
170 fn decode_batch(
171 metadata: &HashMap<String, String>,
172 record_batch: RecordBatch,
173 ) -> Result<Vec<Self>, EncodingError> {
174 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
175 let cols = record_batch.columns();
176
177 let action_values = extract_column::<UInt8Array>(cols, "action", 0, DataType::UInt8)?;
178 let side_values = extract_column::<UInt8Array>(cols, "side", 1, DataType::UInt8)?;
179 let price_values = extract_column::<FixedSizeBinaryArray>(
180 cols,
181 "price",
182 2,
183 DataType::FixedSizeBinary(PRECISION_BYTES),
184 )?;
185 let size_values = extract_column::<FixedSizeBinaryArray>(
186 cols,
187 "size",
188 3,
189 DataType::FixedSizeBinary(PRECISION_BYTES),
190 )?;
191 let order_id_values = extract_column::<UInt64Array>(cols, "order_id", 4, DataType::UInt64)?;
192 let flags_values = extract_column::<UInt8Array>(cols, "flags", 5, DataType::UInt8)?;
193 let sequence_values = extract_column::<UInt64Array>(cols, "sequence", 6, DataType::UInt64)?;
194 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
195 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 8, DataType::UInt64)?;
196
197 assert_eq!(
198 price_values.value_length(),
199 PRECISION_BYTES,
200 "Price precision uses {PRECISION_BYTES} byte value"
201 );
202
203 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
204 .map(|i| {
205 let action_value = action_values.value(i);
206 let action = BookAction::from_u8(action_value).ok_or_else(|| {
207 EncodingError::ParseError(
208 stringify!(BookAction),
209 format!("Invalid enum value, was {action_value}"),
210 )
211 })?;
212 let side_value = side_values.value(i);
213 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
214 EncodingError::ParseError(
215 stringify!(OrderSide),
216 format!("Invalid enum value, was {side_value}"),
217 )
218 })?;
219 let price = Price::from_raw(get_raw_price(price_values.value(i)), price_precision);
220 let size =
221 Quantity::from_raw(get_raw_quantity(size_values.value(i)), size_precision);
222 let order_id = order_id_values.value(i);
223 let flags = flags_values.value(i);
224 let sequence = sequence_values.value(i);
225 let ts_event = ts_event_values.value(i).into();
226 let ts_init = ts_init_values.value(i).into();
227
228 Ok(Self {
229 instrument_id,
230 action,
231 order: BookOrder {
232 side,
233 price,
234 size,
235 order_id,
236 },
237 flags,
238 sequence,
239 ts_event,
240 ts_init,
241 })
242 })
243 .collect();
244
245 result
246 }
247}
248
249impl DecodeDataFromRecordBatch for OrderBookDelta {
250 fn decode_data_batch(
251 metadata: &HashMap<String, String>,
252 record_batch: RecordBatch,
253 ) -> Result<Vec<Data>, EncodingError> {
254 let deltas: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
255 Ok(deltas.into_iter().map(Data::from).collect())
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);
        let schema = OrderBookDelta::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("action", DataType::UInt8, false),
            Field::new("side", DataType::UInt8, false),
            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("size", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("order_id", DataType::UInt64, false),
            Field::new("flags", DataType::UInt8, false),
            Field::new("sequence", DataType::UInt64, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = OrderBookDelta::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        assert_eq!(schema_map.get("action").unwrap(), "UInt8");
        assert_eq!(schema_map.get("side").unwrap(), "UInt8");
        assert_eq!(*schema_map.get("price").unwrap(), fixed_size_binary.clone());
        assert_eq!(*schema_map.get("size").unwrap(), fixed_size_binary);
        assert_eq!(schema_map.get("order_id").unwrap(), "UInt64");
        assert_eq!(schema_map.get("flags").unwrap(), "UInt8");
        assert_eq!(schema_map.get("sequence").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_event").unwrap(), "UInt64");
        assert_eq!(schema_map.get("ts_init").unwrap(), "UInt64");
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let delta1 = OrderBookDelta {
            instrument_id,
            action: BookAction::Add,
            order: BookOrder {
                side: OrderSide::Buy,
                price: Price::from("100.10"),
                size: Quantity::from(100),
                order_id: 1,
            },
            flags: 0,
            sequence: 1,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let delta2 = OrderBookDelta {
            instrument_id,
            action: BookAction::Update,
            order: BookOrder {
                side: OrderSide::Sell,
                price: Price::from("101.20"),
                size: Quantity::from(200),
                order_id: 2,
            },
            flags: 1,
            sequence: 2,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![delta1, delta2];
        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let action_values = columns[0].as_any().downcast_ref::<UInt8Array>().unwrap();
        let side_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let price_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let order_id_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let flags_values = columns[5].as_any().downcast_ref::<UInt8Array>().unwrap();
        let sequence_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_event_values = columns[7].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[8].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 9);
        assert_eq!(action_values.len(), 2);
        assert_eq!(action_values.value(0), 1);
        assert_eq!(action_values.value(1), 2);
        assert_eq!(side_values.len(), 2);
        assert_eq!(side_values.value(0), 1);
        assert_eq!(side_values.value(1), 2);

        assert_eq!(price_values.len(), 2);
        assert_eq!(
            get_raw_price(price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(price_values.value(1)),
            (101.20 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(size_values.len(), 2);
        assert_eq!(
            get_raw_price(size_values.value(0)),
            (100.0 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(size_values.value(1)),
            (200.0 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(order_id_values.len(), 2);
        assert_eq!(order_id_values.value(0), 1);
        assert_eq!(order_id_values.value(1), 2);
        assert_eq!(flags_values.len(), 2);
        assert_eq!(flags_values.value(0), 0);
        assert_eq!(flags_values.value(1), 1);
        assert_eq!(sequence_values.len(), 2);
        assert_eq!(sequence_values.value(0), 1);
        assert_eq!(sequence_values.value(1), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let action = UInt8Array::from(vec![1, 2]);
        let side = UInt8Array::from(vec![1, 1]);
        let price = FixedSizeBinaryArray::from(vec![
            &((101.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((101.20 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let size = FixedSizeBinaryArray::from(vec![
            &((10000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((9000.0 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let order_id = UInt64Array::from(vec![1, 2]);
        let flags = UInt8Array::from(vec![0, 0]);
        let sequence = UInt64Array::from(vec![1, 2]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            OrderBookDelta::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(action),
                Arc::new(side),
                Arc::new(price),
                Arc::new(size),
                Arc::new(order_id),
                Arc::new(flags),
                Arc::new(sequence),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Check decoded contents, not just the row count.
        let (first, second) = (&decoded_data[0], &decoded_data[1]);
        assert_eq!(first.instrument_id, instrument_id);
        assert_eq!(second.instrument_id, instrument_id);
        assert_eq!(first.action, BookAction::Add);
        assert_eq!(second.action, BookAction::Update);
        assert_eq!(first.order.side, OrderSide::Buy);
        assert_eq!(second.order.side, OrderSide::Buy);
        assert_eq!(first.order.price.raw, (101.10 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(second.order.price.raw, (101.20 * FIXED_SCALAR) as PriceRaw);
        assert_eq!(first.order.price.precision, 2);
        assert_eq!(first.order.size.precision, 0);
        assert_eq!(first.order.order_id, 1);
        assert_eq!(second.order.order_id, 2);
        assert_eq!(first.flags, 0);
        assert_eq!(first.sequence, 1);
        assert_eq!(second.sequence, 2);
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(first.ts_init.as_u64(), 3);
        assert_eq!(second.ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = OrderBookDelta::get_metadata(&instrument_id, 2, 0);

        let data = vec![
            OrderBookDelta {
                instrument_id,
                action: BookAction::Add,
                order: BookOrder {
                    side: OrderSide::Buy,
                    price: Price::from("100.10"),
                    size: Quantity::from(100),
                    order_id: 1,
                },
                flags: 0,
                sequence: 1,
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            OrderBookDelta {
                instrument_id,
                action: BookAction::Update,
                order: BookOrder {
                    side: OrderSide::Sell,
                    price: Price::from("101.20"),
                    size: Quantity::from(200),
                    order_id: 2,
                },
                flags: 1,
                sequence: 2,
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = OrderBookDelta::encode_batch(&metadata, &data).unwrap();
        let decoded = OrderBookDelta::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded, data);
    }
}