1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::QuoteTick,
26 identifiers::InstrumentId,
27 types::{Price, Quantity, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 KEY_SIZE_PRECISION, extract_column, get_raw_price, get_raw_quantity,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for QuoteTick {
37 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38 let fields = vec![
39 Field::new(
40 "bid_price",
41 DataType::FixedSizeBinary(PRECISION_BYTES),
42 false,
43 ),
44 Field::new(
45 "ask_price",
46 DataType::FixedSizeBinary(PRECISION_BYTES),
47 false,
48 ),
49 Field::new(
50 "bid_size",
51 DataType::FixedSizeBinary(PRECISION_BYTES),
52 false,
53 ),
54 Field::new(
55 "ask_size",
56 DataType::FixedSizeBinary(PRECISION_BYTES),
57 false,
58 ),
59 Field::new("ts_event", DataType::UInt64, false),
60 Field::new("ts_init", DataType::UInt64, false),
61 ];
62
63 match metadata {
64 Some(metadata) => Schema::new_with_metadata(fields, metadata),
65 None => Schema::new(fields),
66 }
67 }
68}
69
70fn parse_metadata(
71 metadata: &HashMap<String, String>,
72) -> Result<(InstrumentId, u8, u8), EncodingError> {
73 let instrument_id_str = metadata
74 .get(KEY_INSTRUMENT_ID)
75 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
76 let instrument_id = InstrumentId::from_str(instrument_id_str)
77 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
78
79 let price_precision = metadata
80 .get(KEY_PRICE_PRECISION)
81 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
82 .parse::<u8>()
83 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
84
85 let size_precision = metadata
86 .get(KEY_SIZE_PRECISION)
87 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
88 .parse::<u8>()
89 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
90
91 Ok((instrument_id, price_precision, size_precision))
92}
93
94impl EncodeToRecordBatch for QuoteTick {
95 fn encode_batch(
96 metadata: &HashMap<String, String>,
97 data: &[Self],
98 ) -> Result<RecordBatch, ArrowError> {
99 let mut bid_price_builder =
100 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
101 let mut ask_price_builder =
102 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
103 let mut bid_size_builder =
104 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
105 let mut ask_size_builder =
106 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
107 let mut ts_event_builder = UInt64Array::builder(data.len());
108 let mut ts_init_builder = UInt64Array::builder(data.len());
109
110 for quote in data {
111 bid_price_builder
112 .append_value(quote.bid_price.raw.to_le_bytes())
113 .unwrap();
114 ask_price_builder
115 .append_value(quote.ask_price.raw.to_le_bytes())
116 .unwrap();
117 bid_size_builder
118 .append_value(quote.bid_size.raw.to_le_bytes())
119 .unwrap();
120 ask_size_builder
121 .append_value(quote.ask_size.raw.to_le_bytes())
122 .unwrap();
123 ts_event_builder.append_value(quote.ts_event.as_u64());
124 ts_init_builder.append_value(quote.ts_init.as_u64());
125 }
126
127 RecordBatch::try_new(
128 Self::get_schema(Some(metadata.clone())).into(),
129 vec![
130 Arc::new(bid_price_builder.finish()),
131 Arc::new(ask_price_builder.finish()),
132 Arc::new(bid_size_builder.finish()),
133 Arc::new(ask_size_builder.finish()),
134 Arc::new(ts_event_builder.finish()),
135 Arc::new(ts_init_builder.finish()),
136 ],
137 )
138 }
139
140 fn metadata(&self) -> HashMap<String, String> {
141 QuoteTick::get_metadata(
142 &self.instrument_id,
143 self.bid_price.precision,
144 self.bid_size.precision,
145 )
146 }
147}
148
149impl DecodeFromRecordBatch for QuoteTick {
150 fn decode_batch(
151 metadata: &HashMap<String, String>,
152 record_batch: RecordBatch,
153 ) -> Result<Vec<Self>, EncodingError> {
154 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
155 let cols = record_batch.columns();
156
157 let bid_price_values = extract_column::<FixedSizeBinaryArray>(
158 cols,
159 "bid_price",
160 0,
161 DataType::FixedSizeBinary(PRECISION_BYTES),
162 )?;
163 let ask_price_values = extract_column::<FixedSizeBinaryArray>(
164 cols,
165 "ask_price",
166 1,
167 DataType::FixedSizeBinary(PRECISION_BYTES),
168 )?;
169 let bid_size_values = extract_column::<FixedSizeBinaryArray>(
170 cols,
171 "bid_size",
172 2,
173 DataType::FixedSizeBinary(PRECISION_BYTES),
174 )?;
175 let ask_size_values = extract_column::<FixedSizeBinaryArray>(
176 cols,
177 "ask_size",
178 3,
179 DataType::FixedSizeBinary(PRECISION_BYTES),
180 )?;
181 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 4, DataType::UInt64)?;
182 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 5, DataType::UInt64)?;
183
184 if bid_price_values.value_length() != PRECISION_BYTES {
185 return Err(EncodingError::ParseError(
186 "bid_price",
187 format!(
188 "Invalid value length: expected {PRECISION_BYTES}, found {}",
189 bid_price_values.value_length()
190 ),
191 ));
192 }
193 if ask_price_values.value_length() != PRECISION_BYTES {
194 return Err(EncodingError::ParseError(
195 "ask_price",
196 format!(
197 "Invalid value length: expected {PRECISION_BYTES}, found {}",
198 ask_price_values.value_length()
199 ),
200 ));
201 }
202
203 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
204 .map(|row| {
205 Ok(Self {
206 instrument_id,
207 bid_price: Price::from_raw(
208 get_raw_price(bid_price_values.value(row)),
209 price_precision,
210 ),
211 ask_price: Price::from_raw(
212 get_raw_price(ask_price_values.value(row)),
213 price_precision,
214 ),
215 bid_size: Quantity::from_raw(
216 get_raw_quantity(bid_size_values.value(row)),
217 size_precision,
218 ),
219 ask_size: Quantity::from_raw(
220 get_raw_quantity(ask_size_values.value(row)),
221 size_precision,
222 ),
223 ts_event: ts_event_values.value(row).into(),
224 ts_init: ts_init_values.value(row).into(),
225 })
226 })
227 .collect();
228
229 result
230 }
231}
232
233impl DecodeDataFromRecordBatch for QuoteTick {
234 fn decode_data_batch(
235 metadata: &HashMap<String, String>,
236 record_batch: RecordBatch,
237 ) -> Result<Vec<Data>, EncodingError> {
238 let ticks: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
239 Ok(ticks.into_iter().map(Data::from).collect())
240 }
241}
242
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// The schema must list the six quote columns in order and carry the metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));

        let mut expected_fields = Vec::with_capacity(6);

        expected_fields.push(Field::new(
            "bid_price",
            DataType::FixedSizeBinary(PRECISION_BYTES),
            false,
        ));
        expected_fields.push(Field::new(
            "ask_price",
            DataType::FixedSizeBinary(PRECISION_BYTES),
            false,
        ));

        expected_fields.extend(vec![
            Field::new(
                "bid_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new(
                "ask_size",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ]);

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must describe each column by name and data type string.
    #[rstest]
    fn test_get_schema_map() {
        let arrow_schema = QuoteTick::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("bid_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_price".to_string(), fixed_size_binary.clone());
        expected_map.insert("bid_size".to_string(), fixed_size_binary.clone());
        expected_map.insert("ask_size".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(arrow_schema, expected_map);
    }

    /// Encoding two quotes must produce six columns holding the expected raw values.
    #[rstest]
    fn test_encode_quote_tick() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let tick1 = QuoteTick {
            instrument_id,
            bid_price: Price::from("100.10"),
            ask_price: Price::from("101.50"),
            bid_size: Quantity::from(1000),
            ask_size: Quantity::from(500),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let tick2 = QuoteTick {
            instrument_id,
            bid_price: Price::from("100.75"),
            ask_price: Price::from("100.20"),
            bid_size: Quantity::from(750),
            ask_size: Quantity::from(300),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![tick1, tick2];
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);
        let record_batch = QuoteTick::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing so a short batch fails with a clear
        // assertion message rather than an out-of-bounds panic.
        assert_eq!(columns.len(), 6);

        let bid_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_price_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(
            get_raw_price(bid_price_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(bid_price_values.value(1)),
            (100.75 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(0)),
            (101.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(ask_price_values.value(1)),
            (100.20 * FIXED_SCALAR) as PriceRaw
        );

        let bid_size_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ask_size_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[4].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(bid_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(bid_size_values.value(0)),
            (1000.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(bid_size_values.value(1)),
            (750.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ask_size_values.len(), 2);
        assert_eq!(
            get_raw_quantity(ask_size_values.value(0)),
            (500.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(ask_size_values.value(1)),
            (300.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch must reproduce prices, sizes and timestamps.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = QuoteTick::get_metadata(&instrument_id, 2, 0);

        let (bid_price, ask_price) = (
            FixedSizeBinaryArray::from(vec![
                &(10000 as PriceRaw).to_le_bytes(),
                &(9900 as PriceRaw).to_le_bytes(),
            ]),
            FixedSizeBinaryArray::from(vec![
                &(10100 as PriceRaw).to_le_bytes(),
                &(10000 as PriceRaw).to_le_bytes(),
            ]),
        );

        // Size columns hold quantity raws, so build them with `QuantityRaw`.
        let bid_size = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((90.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ask_size = FixedSizeBinaryArray::from(vec![
            &((110.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            QuoteTick::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(bid_price),
                Arc::new(ask_price),
                Arc::new(bid_size),
                Arc::new(ask_size),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = QuoteTick::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        assert_eq!(decoded_data[0].bid_price, Price::from_raw(10000, 2));
        assert_eq!(decoded_data[0].ask_price, Price::from_raw(10100, 2));
        assert_eq!(decoded_data[1].bid_price, Price::from_raw(9900, 2));
        assert_eq!(decoded_data[1].ask_price, Price::from_raw(10000, 2));

        // Sizes and timestamps are decoded from their own columns; verify them too.
        assert_eq!(
            decoded_data[0].bid_size,
            Quantity::from_raw((100.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[0].ask_size,
            Quantity::from_raw((110.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[1].bid_size,
            Quantity::from_raw((90.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(
            decoded_data[1].ask_size,
            Quantity::from_raw((100.0 * FIXED_SCALAR) as QuantityRaw, 0)
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }
}