1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::QuoteTick,
26 identifiers::InstrumentId,
27 types::{Price, Quantity, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 KEY_SIZE_PRECISION, extract_column, get_raw_price, get_raw_quantity,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for QuoteTick {
37 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38 let fields = vec![
39 Field::new(
40 "bid_price",
41 DataType::FixedSizeBinary(PRECISION_BYTES),
42 false,
43 ),
44 Field::new(
45 "ask_price",
46 DataType::FixedSizeBinary(PRECISION_BYTES),
47 false,
48 ),
49 Field::new(
50 "bid_size",
51 DataType::FixedSizeBinary(PRECISION_BYTES),
52 false,
53 ),
54 Field::new(
55 "ask_size",
56 DataType::FixedSizeBinary(PRECISION_BYTES),
57 false,
58 ),
59 Field::new("ts_event", DataType::UInt64, false),
60 Field::new("ts_init", DataType::UInt64, false),
61 ];
62
63 match metadata {
64 Some(metadata) => Schema::new_with_metadata(fields, metadata),
65 None => Schema::new(fields),
66 }
67 }
68}
69
70fn parse_metadata(
71 metadata: &HashMap<String, String>,
72) -> Result<(InstrumentId, u8, u8), EncodingError> {
73 let instrument_id_str = metadata
74 .get(KEY_INSTRUMENT_ID)
75 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
76 let instrument_id = InstrumentId::from_str(instrument_id_str)
77 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
78
79 let price_precision = metadata
80 .get(KEY_PRICE_PRECISION)
81 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
82 .parse::<u8>()
83 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
84
85 let size_precision = metadata
86 .get(KEY_SIZE_PRECISION)
87 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
88 .parse::<u8>()
89 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
90
91 Ok((instrument_id, price_precision, size_precision))
92}
93
94impl EncodeToRecordBatch for QuoteTick {
95 fn encode_batch(
96 metadata: &HashMap<String, String>,
97 data: &[Self],
98 ) -> Result<RecordBatch, ArrowError> {
99 let mut bid_price_builder =
100 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
101 let mut ask_price_builder =
102 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
103 let mut bid_size_builder =
104 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
105 let mut ask_size_builder =
106 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
107 let mut ts_event_builder = UInt64Array::builder(data.len());
108 let mut ts_init_builder = UInt64Array::builder(data.len());
109
110 for quote in data {
111 bid_price_builder
112 .append_value(quote.bid_price.raw.to_le_bytes())
113 .unwrap();
114 ask_price_builder
115 .append_value(quote.ask_price.raw.to_le_bytes())
116 .unwrap();
117 bid_size_builder
118 .append_value(quote.bid_size.raw.to_le_bytes())
119 .unwrap();
120 ask_size_builder
121 .append_value(quote.ask_size.raw.to_le_bytes())
122 .unwrap();
123 ts_event_builder.append_value(quote.ts_event.as_u64());
124 ts_init_builder.append_value(quote.ts_init.as_u64());
125 }
126
127 RecordBatch::try_new(
128 Self::get_schema(Some(metadata.clone())).into(),
129 vec![
130 Arc::new(bid_price_builder.finish()),
131 Arc::new(ask_price_builder.finish()),
132 Arc::new(bid_size_builder.finish()),
133 Arc::new(ask_size_builder.finish()),
134 Arc::new(ts_event_builder.finish()),
135 Arc::new(ts_init_builder.finish()),
136 ],
137 )
138 }
139
140 fn metadata(&self) -> HashMap<String, String> {
141 QuoteTick::get_metadata(
142 &self.instrument_id,
143 self.bid_price.precision,
144 self.bid_size.precision,
145 )
146 }
147}
148
149impl DecodeFromRecordBatch for QuoteTick {
150 fn decode_batch(
151 metadata: &HashMap<String, String>,
152 record_batch: RecordBatch,
153 ) -> Result<Vec<Self>, EncodingError> {
154 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
155 let cols = record_batch.columns();
156
157 let bid_price_values = extract_column::<FixedSizeBinaryArray>(
158 cols,
159 "bid_price",
160 0,
161 DataType::FixedSizeBinary(PRECISION_BYTES),
162 )?;
163 let ask_price_values = extract_column::<FixedSizeBinaryArray>(
164 cols,
165 "ask_price",
166 1,
167 DataType::FixedSizeBinary(PRECISION_BYTES),
168 )?;
169 let bid_size_values = extract_column::<FixedSizeBinaryArray>(
170 cols,
171 "bid_size",
172 2,
173 DataType::FixedSizeBinary(PRECISION_BYTES),
174 )?;
175 let ask_size_values = extract_column::<FixedSizeBinaryArray>(
176 cols,
177 "ask_size",
178 3,
179 DataType::FixedSizeBinary(PRECISION_BYTES),
180 )?;
181 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
182 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
183
184 assert_eq!(
185 bid_price_values.value_length(),
186 PRECISION_BYTES,
187 "Price precision uses {PRECISION_BYTES} byte value"
188 );
189 assert_eq!(
190 ask_price_values.value_length(),
191 PRECISION_BYTES,
192 "Price precision uses {PRECISION_BYTES} byte value"
193 );
194
195 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
196 .map(|row| {
197 Ok(Self {
198 instrument_id,
199 bid_price: Price::from_raw(
200 get_raw_price(bid_price_values.value(row)),
201 price_precision,
202 ),
203 ask_price: Price::from_raw(
204 get_raw_price(ask_price_values.value(row)),
205 price_precision,
206 ),
207 bid_size: Quantity::from_raw(
208 get_raw_quantity(bid_size_values.value(row)),
209 size_precision,
210 ),
211 ask_size: Quantity::from_raw(
212 get_raw_quantity(ask_size_values.value(row)),
213 size_precision,
214 ),
215 ts_event: ts_event_values.value(row).into(),
216 ts_init: ts_init_values.value(row).into(),
217 })
218 })
219 .collect();
220
221 result
222 }
223}
224
225impl DecodeDataFromRecordBatch for QuoteTick {
226 fn decode_data_batch(
227 metadata: &HashMap<String, String>,
228 record_batch: RecordBatch,
229 ) -> Result<Vec<Data>, EncodingError> {
230 let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
231 Ok(ticks.into_iter().map(Data::from).collect())
232 }
233}
234
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds two sample quotes shared by the encode and round-trip tests.
    fn sample_quotes(instrument_id: InstrumentId) -> Vec<QuoteTick> {
        vec![
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.10"),
                ask_price: Price::from("101.50"),
                bid_size: Quantity::from(1000),
                ask_size: Quantity::from(500),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.75"),
                ask_price: Price::from("100.20"),
                bid_size: Quantity::from(750),
                ask_size: Quantity::from(300),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ]
    }

    /// The schema has the expected columns, in order, with metadata attached.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "bid_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "ask_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "bid_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "ask_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map reports each column's Arrow data type as a string.
    #[rstest]
    fn test_get_schema_map() {
        let arrow_schema = QuoteTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("bid_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("bid_size".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_size".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(arrow_schema, expected_map);
    }

    /// Encoding writes raw fixed-point bytes and timestamps into the expected columns.
    #[rstest]
    fn test_encode_quote_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        assert_eq!(columns.len(), 6);

        let bid_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_price_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(
            get_raw_price(bid_price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(bid_price_values.value(1)),
            (100.75 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(0)),
            (101.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(1)),
            (100.20 * FIXED_SCALAR) as PriceRaw
        );

        let bid_size_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(bid_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(bid_size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(bid_size_values.value(1)),
            (750.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ask_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(ask_size_values.value(0)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(ask_size_values.value(1)),
            (300.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch restores every field of each quote.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let (bid_price, ask_price) = (
            FixedSizeBinaryArray::from(vec![
                &(10000 as PriceRaw).to_le_bytes(),
                &(9900 as PriceRaw).to_le_bytes(),
            ]),
            FixedSizeBinaryArray::from(vec![
                &(10100 as PriceRaw).to_le_bytes(),
                &(10000 as PriceRaw).to_le_bytes(),
            ]),
        );

        // Size columns hold raw *quantity* values, so encode them as `QuantityRaw`.
        let bid_size = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((90.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ask_size = FixedSizeBinaryArray::from(vec![
            &((110.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            QuoteTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(bid_price),
                Arc::new(ask_price),
                Arc::new(bid_size),
                Arc::new(ask_size),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = QuoteTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);

        assert_eq!(decoded_data[0].bid_price, Price::from_raw(10000, 2));
        assert_eq!(decoded_data[0].ask_price, Price::from_raw(10100, 2));
        assert_eq!(decoded_data[1].bid_price, Price::from_raw(9900, 2));
        assert_eq!(decoded_data[1].ask_price, Price::from_raw(10000, 2));

        assert_eq!(
            decoded_data[0].bid_size,
            Quantity::from_raw((100.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[0].ask_size,
            Quantity::from_raw((110.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[1].bid_size,
            Quantity::from_raw((90.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[1].ask_size,
            Quantity::from_raw((100.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );

        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    /// Encoding then decoding with matching metadata yields the original quotes.
    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let quotes = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let record_batch = QuoteTick::encode_batch(&metadata, &quotes).unwrap();
        let decoded = QuoteTick::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded, quotes);
    }
}