1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Bar, BarType},
26 types::fixed::PRECISION_BYTES,
27};
28
29use super::{
30 DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31 KEY_SIZE_PRECISION, decode_price, decode_quantity, extract_column, validate_precision_bytes,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37 let fields = vec![
38 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43 Field::new("ts_event", DataType::UInt64, false),
44 Field::new("ts_init", DataType::UInt64, false),
45 ];
46
47 match metadata {
48 Some(metadata) => Schema::new_with_metadata(fields, metadata),
49 None => Schema::new(fields),
50 }
51 }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55 let bar_type_str = metadata
56 .get(KEY_BAR_TYPE)
57 .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58 let bar_type = BarType::from_str(bar_type_str)
59 .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61 let price_precision = metadata
62 .get(KEY_PRICE_PRECISION)
63 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64 .parse::<u8>()
65 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67 let size_precision = metadata
68 .get(KEY_SIZE_PRECISION)
69 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70 .parse::<u8>()
71 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73 Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77 fn encode_batch(
78 metadata: &HashMap<String, String>,
79 data: &[Self],
80 ) -> Result<RecordBatch, ArrowError> {
81 let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82 let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83 let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84 let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85 let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86 let mut ts_event_builder = UInt64Array::builder(data.len());
87 let mut ts_init_builder = UInt64Array::builder(data.len());
88
89 for bar in data {
90 open_builder
91 .append_value(bar.open.raw.to_le_bytes())
92 .unwrap();
93 high_builder
94 .append_value(bar.high.raw.to_le_bytes())
95 .unwrap();
96 low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97 close_builder
98 .append_value(bar.close.raw.to_le_bytes())
99 .unwrap();
100 volume_builder
101 .append_value(bar.volume.raw.to_le_bytes())
102 .unwrap();
103 ts_event_builder.append_value(bar.ts_event.as_u64());
104 ts_init_builder.append_value(bar.ts_init.as_u64());
105 }
106
107 let open_array = open_builder.finish();
108 let high_array = high_builder.finish();
109 let low_array = low_builder.finish();
110 let close_array = close_builder.finish();
111 let volume_array = volume_builder.finish();
112 let ts_event_array = ts_event_builder.finish();
113 let ts_init_array = ts_init_builder.finish();
114
115 RecordBatch::try_new(
116 Self::get_schema(Some(metadata.clone())).into(),
117 vec![
118 Arc::new(open_array),
119 Arc::new(high_array),
120 Arc::new(low_array),
121 Arc::new(close_array),
122 Arc::new(volume_array),
123 Arc::new(ts_event_array),
124 Arc::new(ts_init_array),
125 ],
126 )
127 }
128
129 fn metadata(&self) -> HashMap<String, String> {
130 Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131 }
132}
133
134impl DecodeFromRecordBatch for Bar {
135 fn decode_batch(
136 metadata: &HashMap<String, String>,
137 record_batch: RecordBatch,
138 ) -> Result<Vec<Self>, EncodingError> {
139 let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140 let cols = record_batch.columns();
141
142 let open_values = extract_column::<FixedSizeBinaryArray>(
143 cols,
144 "open",
145 0,
146 DataType::FixedSizeBinary(PRECISION_BYTES),
147 )?;
148 let high_values = extract_column::<FixedSizeBinaryArray>(
149 cols,
150 "high",
151 1,
152 DataType::FixedSizeBinary(PRECISION_BYTES),
153 )?;
154 let low_values = extract_column::<FixedSizeBinaryArray>(
155 cols,
156 "low",
157 2,
158 DataType::FixedSizeBinary(PRECISION_BYTES),
159 )?;
160 let close_values = extract_column::<FixedSizeBinaryArray>(
161 cols,
162 "close",
163 3,
164 DataType::FixedSizeBinary(PRECISION_BYTES),
165 )?;
166 let volume_values = extract_column::<FixedSizeBinaryArray>(
167 cols,
168 "volume",
169 4,
170 DataType::FixedSizeBinary(PRECISION_BYTES),
171 )?;
172 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175 validate_precision_bytes(open_values, "open")?;
176 validate_precision_bytes(high_values, "high")?;
177 validate_precision_bytes(low_values, "low")?;
178 validate_precision_bytes(close_values, "close")?;
179 validate_precision_bytes(volume_values, "volume")?;
180
181 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
182 .map(|i| {
183 let open = decode_price(open_values.value(i), price_precision, "open", i)?;
184 let high = decode_price(high_values.value(i), price_precision, "high", i)?;
185 let low = decode_price(low_values.value(i), price_precision, "low", i)?;
186 let close = decode_price(close_values.value(i), price_precision, "close", i)?;
187 let volume = decode_quantity(volume_values.value(i), size_precision, "volume", i)?;
188 let ts_event = ts_event_values.value(i).into();
189 let ts_init = ts_init_values.value(i).into();
190
191 Ok(Self {
192 bar_type,
193 open,
194 high,
195 low,
196 close,
197 volume,
198 ts_event,
199 ts_init,
200 })
201 })
202 .collect();
203
204 result
205 }
206}
207
208impl DecodeDataFromRecordBatch for Bar {
209 fn decode_data_batch(
210 metadata: &HashMap<String, String>,
211 record_batch: RecordBatch,
212 ) -> Result<Vec<Data>, EncodingError> {
213 let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
214 Ok(bars.into_iter().map(Data::from).collect())
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{
        Price, Quantity, fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw,
    };
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// The schema carries the expected seven fields and the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);
        let schema = Bar::get_schema(Some(metadata.clone()));
        let expected_fields = vec![
            Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];
        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    /// The schema map reports each column name with its Arrow type rendered as text.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = Bar::get_schema_map();
        let mut expected_map = HashMap::new();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("open".to_string(), fixed_size_binary.clone());
        expected_map.insert("high".to_string(), fixed_size_binary.clone());
        expected_map.insert("low".to_string(), fixed_size_binary.clone());
        expected_map.insert("close".to_string(), fixed_size_binary.clone());
        expected_map.insert("volume".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two bars yields seven columns holding the raw values row by row.
    #[rstest]
    fn test_encode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1.into(),
            3.into(),
        );
        let bar2 = Bar::new(
            bar_type,
            Price::from("100.00"),
            Price::from("100.10"),
            Price::from("100.00"),
            Price::from("100.10"),
            Quantity::from(1110),
            2.into(),
            4.into(),
        );

        let data = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing so a mismatch fails with a
        // clear assertion message rather than an out-of-bounds panic.
        assert_eq!(columns.len(), 7);

        let open_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let high_values = columns[1]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let low_values = columns[2]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_values = columns[3]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let volume_values = columns[4]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(open_values.len(), 2);
        assert_eq!(
            get_raw_price(open_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(open_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(high_values.len(), 2);
        assert_eq!(
            get_raw_price(high_values.value(0)),
            (102.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(high_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(low_values.len(), 2);
        assert_eq!(
            get_raw_price(low_values.value(0)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(low_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_values.len(), 2);
        assert_eq!(
            get_raw_price(close_values.value(0)),
            (101.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            (1100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            (1110.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    /// Decoding a hand-built batch yields one bar per row, each carrying the bar
    /// type from the metadata and the timestamps from its row.
    #[rstest]
    fn test_decode_batch() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let open = FixedSizeBinaryArray::from(vec![
            &((100.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let high = FixedSizeBinaryArray::from(vec![
            &((102.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let low = FixedSizeBinaryArray::from(vec![
            &((100.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let close = FixedSizeBinaryArray::from(vec![
            &((101.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            &((10.01 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((11.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            &((10.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // Beyond the row count, pin the fields that decode exactly: the bar type
        // taken from metadata and the per-row timestamps.
        assert_eq!(decoded_data[0].bar_type, bar_type);
        assert_eq!(decoded_data[1].bar_type, bar_type);
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    /// An out-of-range raw open price is reported as an error naming the field and row.
    #[rstest]
    fn test_decode_batch_invalid_price_returns_error() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;

        let open = FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]);
        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("open") && err.to_string().contains("row 0"),
            "Expected open error at row 0, was: {err}"
        );
    }

    /// Decoding fails with a bar-type error when that metadata key has been removed.
    #[rstest]
    fn test_decode_batch_missing_bar_type_returns_error() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let mut metadata = Bar::get_metadata(&bar_type, 2, 0);

        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;
        let open = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let high = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let low = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let close = FixedSizeBinaryArray::from(vec![&valid_price.to_le_bytes()]);
        let volume = FixedSizeBinaryArray::from(vec![
            &((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        // Remove the key only after the batch is built, so the failure comes from
        // the metadata passed to `decode_batch`.
        metadata.remove(KEY_BAR_TYPE);

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("bar_type"),
            "Expected missing bar_type error, was: {err}"
        );
    }

    /// Bars survive an encode/decode round trip with every field intact.
    #[rstest]
    fn test_encode_decode_round_trip() {
        let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let bar1 = Bar::new(
            bar_type,
            Price::from("100.10"),
            Price::from("102.00"),
            Price::from("100.00"),
            Price::from("101.00"),
            Quantity::from(1100),
            1_000_000_000.into(),
            1_000_000_001.into(),
        );

        let bar2 = Bar::new(
            bar_type,
            Price::from("101.00"),
            Price::from("103.00"),
            Price::from("100.50"),
            Price::from("102.50"),
            Quantity::from(2200),
            2_000_000_000.into(),
            2_000_000_001.into(),
        );

        let original = vec![bar1, bar2];
        let record_batch = Bar::encode_batch(&metadata, &original).unwrap();
        let decoded = Bar::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec.bar_type, orig.bar_type);
            assert_eq!(dec.open, orig.open);
            assert_eq!(dec.high, orig.high);
            assert_eq!(dec.low, orig.low);
            assert_eq!(dec.close, orig.close);
            assert_eq!(dec.volume, orig.volume);
            assert_eq!(dec.ts_event, orig.ts_event);
            assert_eq!(dec.ts_init, orig.ts_init);
        }
    }
}