1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::QuoteTick,
26 identifiers::InstrumentId,
27 types::{Price, Quantity, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for QuoteTick {
37 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38 let fields = vec![
39 Field::new(
40 "bid_price",
41 DataType::FixedSizeBinary(PRECISION_BYTES),
42 false,
43 ),
44 Field::new(
45 "ask_price",
46 DataType::FixedSizeBinary(PRECISION_BYTES),
47 false,
48 ),
49 Field::new(
50 "bid_size",
51 DataType::FixedSizeBinary(PRECISION_BYTES),
52 false,
53 ),
54 Field::new(
55 "ask_size",
56 DataType::FixedSizeBinary(PRECISION_BYTES),
57 false,
58 ),
59 Field::new("ts_event", DataType::UInt64, false),
60 Field::new("ts_init", DataType::UInt64, false),
61 ];
62
63 match metadata {
64 Some(metadata) => Schema::new_with_metadata(fields, metadata),
65 None => Schema::new(fields),
66 }
67 }
68}
69
70fn parse_metadata(
71 metadata: &HashMap<String, String>,
72) -> Result<(InstrumentId, u8, u8), EncodingError> {
73 let instrument_id_str = metadata
74 .get(KEY_INSTRUMENT_ID)
75 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
76 let instrument_id = InstrumentId::from_str(instrument_id_str)
77 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
78
79 let price_precision = metadata
80 .get(KEY_PRICE_PRECISION)
81 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
82 .parse::<u8>()
83 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
84
85 let size_precision = metadata
86 .get(KEY_SIZE_PRECISION)
87 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
88 .parse::<u8>()
89 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
90
91 Ok((instrument_id, price_precision, size_precision))
92}
93
94impl EncodeToRecordBatch for QuoteTick {
95 fn encode_batch(
96 metadata: &HashMap<String, String>,
97 data: &[Self],
98 ) -> Result<RecordBatch, ArrowError> {
99 let mut bid_price_builder =
100 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
101 let mut ask_price_builder =
102 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
103 let mut bid_size_builder =
104 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
105 let mut ask_size_builder =
106 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
107 let mut ts_event_builder = UInt64Array::builder(data.len());
108 let mut ts_init_builder = UInt64Array::builder(data.len());
109
110 for quote in data {
111 bid_price_builder
112 .append_value(quote.bid_price.raw.to_le_bytes())
113 .unwrap();
114 ask_price_builder
115 .append_value(quote.ask_price.raw.to_le_bytes())
116 .unwrap();
117 bid_size_builder
118 .append_value(quote.bid_size.raw.to_le_bytes())
119 .unwrap();
120 ask_size_builder
121 .append_value(quote.ask_size.raw.to_le_bytes())
122 .unwrap();
123 ts_event_builder.append_value(quote.ts_event.as_u64());
124 ts_init_builder.append_value(quote.ts_init.as_u64());
125 }
126
127 RecordBatch::try_new(
128 Self::get_schema(Some(metadata.clone())).into(),
129 vec![
130 Arc::new(bid_price_builder.finish()),
131 Arc::new(ask_price_builder.finish()),
132 Arc::new(bid_size_builder.finish()),
133 Arc::new(ask_size_builder.finish()),
134 Arc::new(ts_event_builder.finish()),
135 Arc::new(ts_init_builder.finish()),
136 ],
137 )
138 }
139
140 fn metadata(&self) -> HashMap<String, String> {
141 Self::get_metadata(
142 &self.instrument_id,
143 self.bid_price.precision,
144 self.bid_size.precision,
145 )
146 }
147}
148
149impl DecodeFromRecordBatch for QuoteTick {
150 fn decode_batch(
151 metadata: &HashMap<String, String>,
152 record_batch: RecordBatch,
153 ) -> Result<Vec<Self>, EncodingError> {
154 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
155 let cols = record_batch.columns();
156
157 let bid_price_values = extract_column::<FixedSizeBinaryArray>(
158 cols,
159 "bid_price",
160 0,
161 DataType::FixedSizeBinary(PRECISION_BYTES),
162 )?;
163 let ask_price_values = extract_column::<FixedSizeBinaryArray>(
164 cols,
165 "ask_price",
166 1,
167 DataType::FixedSizeBinary(PRECISION_BYTES),
168 )?;
169 let bid_size_values = extract_column::<FixedSizeBinaryArray>(
170 cols,
171 "bid_size",
172 2,
173 DataType::FixedSizeBinary(PRECISION_BYTES),
174 )?;
175 let ask_size_values = extract_column::<FixedSizeBinaryArray>(
176 cols,
177 "ask_size",
178 3,
179 DataType::FixedSizeBinary(PRECISION_BYTES),
180 )?;
181 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
182 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
183
184 if bid_price_values.value_length() != PRECISION_BYTES {
185 return Err(EncodingError::ParseError(
186 "bid_price",
187 format!(
188 "Invalid value length: expected {PRECISION_BYTES}, found {}",
189 bid_price_values.value_length()
190 ),
191 ));
192 }
193 if ask_price_values.value_length() != PRECISION_BYTES {
194 return Err(EncodingError::ParseError(
195 "ask_price",
196 format!(
197 "Invalid value length: expected {PRECISION_BYTES}, found {}",
198 ask_price_values.value_length()
199 ),
200 ));
201 }
202
203 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
205 .map(|row| {
206 Ok(Self {
207 instrument_id,
208 bid_price: Price::from_raw(
209 get_corrected_raw_price(bid_price_values.value(row), price_precision),
210 price_precision,
211 ),
212 ask_price: Price::from_raw(
213 get_corrected_raw_price(ask_price_values.value(row), price_precision),
214 price_precision,
215 ),
216 bid_size: Quantity::from_raw(
217 get_corrected_raw_quantity(bid_size_values.value(row), size_precision),
218 size_precision,
219 ),
220 ask_size: Quantity::from_raw(
221 get_corrected_raw_quantity(ask_size_values.value(row), size_precision),
222 size_precision,
223 ),
224 ts_event: ts_event_values.value(row).into(),
225 ts_init: ts_init_values.value(row).into(),
226 })
227 })
228 .collect();
229
230 result
231 }
232}
233
234impl DecodeDataFromRecordBatch for QuoteTick {
235 fn decode_data_batch(
236 metadata: &HashMap<String, String>,
237 record_batch: RecordBatch,
238 ) -> Result<Vec<Data>, EncodingError> {
239 let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
240 Ok(ticks.into_iter().map(Data::from).collect())
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Two quotes with distinct prices, sizes and timestamps shared by several tests.
    fn sample_quotes(instrument_id: InstrumentId) -> Vec<QuoteTick> {
        vec![
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.10"),
                ask_price: Price::from("101.50"),
                bid_size: Quantity::from(1000),
                ask_size: Quantity::from(500),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            QuoteTick {
                instrument_id,
                bid_price: Price::from("100.75"),
                ask_price: Price::from("100.20"),
                bid_size: Quantity::from(750),
                ask_size: Quantity::from(300),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ]
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "bid_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "ask_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "bid_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "ask_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let arrow_schema = QuoteTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("bid_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("bid_size".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_size".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(arrow_schema, expected_map);
    }

    #[rstest]
    fn test_encode_quote_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        assert_eq!(columns.len(), 6);

        let bid_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_price_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(
            get_raw_price(bid_price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(bid_price_values.value(1)),
            (100.75 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(0)),
            (101.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(1)),
            (100.20 * FIXED_SCALAR) as PriceRaw
        );

        let bid_size_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(bid_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(bid_size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(bid_size_values.value(1)),
            (750.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ask_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(ask_size_values.value(0)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(ask_size_values.value(1)),
            (300.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let raw_bid1 = (100.00 * FIXED_SCALAR) as PriceRaw;
        let raw_bid2 = (99.00 * FIXED_SCALAR) as PriceRaw;
        let raw_ask1 = (101.00 * FIXED_SCALAR) as PriceRaw;
        let raw_ask2 = (100.00 * FIXED_SCALAR) as PriceRaw;

        // Sizes are quantities, so encode them from `QuantityRaw`, not `PriceRaw`.
        let raw_bid_size1 = (100.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_bid_size2 = (90.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_ask_size1 = (110.0 * FIXED_SCALAR) as QuantityRaw;
        let raw_ask_size2 = (100.0 * FIXED_SCALAR) as QuantityRaw;

        let bid_price =
            FixedSizeBinaryArray::from(vec![&raw_bid1.to_le_bytes(), &raw_bid2.to_le_bytes()]);
        let ask_price =
            FixedSizeBinaryArray::from(vec![&raw_ask1.to_le_bytes(), &raw_ask2.to_le_bytes()]);
        let bid_size = FixedSizeBinaryArray::from(vec![
            &raw_bid_size1.to_le_bytes(),
            &raw_bid_size2.to_le_bytes(),
        ]);
        let ask_size = FixedSizeBinaryArray::from(vec![
            &raw_ask_size1.to_le_bytes(),
            &raw_ask_size2.to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            QuoteTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(bid_price),
                Arc::new(ask_price),
                Arc::new(bid_size),
                Arc::new(ask_size),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = QuoteTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].instrument_id, instrument_id);

        assert_eq!(decoded_data[0].bid_price, Price::from_raw(raw_bid1, 2));
        assert_eq!(decoded_data[0].ask_price, Price::from_raw(raw_ask1, 2));
        assert_eq!(decoded_data[1].bid_price, Price::from_raw(raw_bid2, 2));
        assert_eq!(decoded_data[1].ask_price, Price::from_raw(raw_ask2, 2));

        assert_eq!(decoded_data[0].bid_size, Quantity::from_raw(raw_bid_size1, 0));
        assert_eq!(decoded_data[0].ask_size, Quantity::from_raw(raw_ask_size1, 0));
        assert_eq!(decoded_data[1].bid_size, Quantity::from_raw(raw_bid_size2, 0));
        assert_eq!(decoded_data[1].ask_size, Quantity::from_raw(raw_ask_size2, 0));

        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();
        let decoded = QuoteTick::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded, data);
    }

    #[rstest]
    fn test_decode_batch_missing_metadata() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let data = sample_quotes(instrument_id);
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();

        let result = QuoteTick::decode_batch(&HashMap::new(), record_batch);

        assert!(matches!(
            result,
            Err(EncodingError::MissingMetadata(key)) if key == KEY_INSTRUMENT_ID
        ));
    }
}