1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Bar, BarType},
26 types::fixed::PRECISION_BYTES,
27};
28
29use super::{
30 DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31 KEY_SIZE_PRECISION, decode_price, decode_quantity, extract_column,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37 let fields = vec![
38 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43 Field::new("ts_event", DataType::UInt64, false),
44 Field::new("ts_init", DataType::UInt64, false),
45 ];
46
47 match metadata {
48 Some(metadata) => Schema::new_with_metadata(fields, metadata),
49 None => Schema::new(fields),
50 }
51 }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55 let bar_type_str = metadata
56 .get(KEY_BAR_TYPE)
57 .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58 let bar_type = BarType::from_str(bar_type_str)
59 .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61 let price_precision = metadata
62 .get(KEY_PRICE_PRECISION)
63 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64 .parse::<u8>()
65 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67 let size_precision = metadata
68 .get(KEY_SIZE_PRECISION)
69 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70 .parse::<u8>()
71 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73 Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77 fn encode_batch(
78 metadata: &HashMap<String, String>,
79 data: &[Self],
80 ) -> Result<RecordBatch, ArrowError> {
81 let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82 let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83 let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84 let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85 let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86 let mut ts_event_builder = UInt64Array::builder(data.len());
87 let mut ts_init_builder = UInt64Array::builder(data.len());
88
89 for bar in data {
90 open_builder
91 .append_value(bar.open.raw.to_le_bytes())
92 .unwrap();
93 high_builder
94 .append_value(bar.high.raw.to_le_bytes())
95 .unwrap();
96 low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97 close_builder
98 .append_value(bar.close.raw.to_le_bytes())
99 .unwrap();
100 volume_builder
101 .append_value(bar.volume.raw.to_le_bytes())
102 .unwrap();
103 ts_event_builder.append_value(bar.ts_event.as_u64());
104 ts_init_builder.append_value(bar.ts_init.as_u64());
105 }
106
107 let open_array = open_builder.finish();
108 let high_array = high_builder.finish();
109 let low_array = low_builder.finish();
110 let close_array = close_builder.finish();
111 let volume_array = volume_builder.finish();
112 let ts_event_array = ts_event_builder.finish();
113 let ts_init_array = ts_init_builder.finish();
114
115 RecordBatch::try_new(
116 Self::get_schema(Some(metadata.clone())).into(),
117 vec![
118 Arc::new(open_array),
119 Arc::new(high_array),
120 Arc::new(low_array),
121 Arc::new(close_array),
122 Arc::new(volume_array),
123 Arc::new(ts_event_array),
124 Arc::new(ts_init_array),
125 ],
126 )
127 }
128
129 fn metadata(&self) -> HashMap<String, String> {
130 Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131 }
132}
133
134impl DecodeFromRecordBatch for Bar {
135 fn decode_batch(
136 metadata: &HashMap<String, String>,
137 record_batch: RecordBatch,
138 ) -> Result<Vec<Self>, EncodingError> {
139 let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140 let cols = record_batch.columns();
141
142 let open_values = extract_column::<FixedSizeBinaryArray>(
143 cols,
144 "open",
145 0,
146 DataType::FixedSizeBinary(PRECISION_BYTES),
147 )?;
148 let high_values = extract_column::<FixedSizeBinaryArray>(
149 cols,
150 "high",
151 1,
152 DataType::FixedSizeBinary(PRECISION_BYTES),
153 )?;
154 let low_values = extract_column::<FixedSizeBinaryArray>(
155 cols,
156 "low",
157 2,
158 DataType::FixedSizeBinary(PRECISION_BYTES),
159 )?;
160 let close_values = extract_column::<FixedSizeBinaryArray>(
161 cols,
162 "close",
163 3,
164 DataType::FixedSizeBinary(PRECISION_BYTES),
165 )?;
166 let volume_values = extract_column::<FixedSizeBinaryArray>(
167 cols,
168 "volume",
169 4,
170 DataType::FixedSizeBinary(PRECISION_BYTES),
171 )?;
172 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
176 .map(|i| {
177 let open = decode_price(open_values.value(i), price_precision, "open", i)?;
178 let high = decode_price(high_values.value(i), price_precision, "high", i)?;
179 let low = decode_price(low_values.value(i), price_precision, "low", i)?;
180 let close = decode_price(close_values.value(i), price_precision, "close", i)?;
181 let volume = decode_quantity(volume_values.value(i), size_precision, "volume", i)?;
182 let ts_event = ts_event_values.value(i).into();
183 let ts_init = ts_init_values.value(i).into();
184
185 Ok(Self {
186 bar_type,
187 open,
188 high,
189 low,
190 close,
191 volume,
192 ts_event,
193 ts_init,
194 })
195 })
196 .collect();
197
198 result
199 }
200}
201
202impl DecodeDataFromRecordBatch for Bar {
203 fn decode_data_batch(
204 metadata: &HashMap<String, String>,
205 record_batch: RecordBatch,
206 ) -> Result<Vec<Data>, EncodingError> {
207 let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
208 Ok(bars.into_iter().map(Data::from).collect())
209 }
210}
211
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{
        Price, Quantity, fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw,
    };
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Bar type used by every test in this module.
    fn sample_bar_type() -> BarType {
        BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap()
    }

    /// Metadata for [`sample_bar_type`] with price precision 2 and size precision 0.
    fn sample_metadata() -> HashMap<String, String> {
        Bar::get_metadata(&sample_bar_type(), 2, 0)
    }

    /// Builds a fixed-size binary column from already-serialized rows.
    fn fixed_column<const N: usize>(rows: &[[u8; N]]) -> FixedSizeBinaryArray {
        FixedSizeBinaryArray::from(rows.iter().collect::<Vec<_>>())
    }

    /// Assembles a record batch with the `Bar` schema from columns given in schema order.
    fn bar_record_batch(
        metadata: &HashMap<String, String>,
        columns: Vec<Arc<dyn Array>>,
    ) -> RecordBatch {
        RecordBatch::try_new(Bar::get_schema(Some(metadata.clone())).into(), columns).unwrap()
    }

    /// Builds a one-row batch whose `open` is `open_raw`; all other values are valid.
    fn single_row_batch(metadata: &HashMap<String, String>, open_raw: PriceRaw) -> RecordBatch {
        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;
        let valid_bytes = valid_price.to_le_bytes();

        bar_record_batch(
            metadata,
            vec![
                Arc::new(fixed_column(&[open_raw.to_le_bytes()])),
                Arc::new(fixed_column(&[valid_bytes])),
                Arc::new(fixed_column(&[valid_bytes])),
                Arc::new(fixed_column(&[valid_bytes])),
                Arc::new(fixed_column(&[
                    ((100.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes()
                ])),
                Arc::new(UInt64Array::from(vec![1])),
                Arc::new(UInt64Array::from(vec![2])),
            ],
        )
    }

    /// Returns the column at `index` downcast to a fixed-size binary array.
    fn fixed_binary_at(batch: &RecordBatch, index: usize) -> &FixedSizeBinaryArray {
        batch.columns()[index]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap()
    }

    /// Returns the column at `index` downcast to a `u64` array.
    fn uint64_at(batch: &RecordBatch, index: usize) -> &UInt64Array {
        batch.columns()[index]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap()
    }

    #[rstest]
    fn test_get_schema() {
        let metadata = sample_metadata();
        let schema = Bar::get_schema(Some(metadata.clone()));

        let mut expected_fields: Vec<Field> = ["open", "high", "low", "close", "volume"]
            .iter()
            .map(|name| Field::new(*name, DataType::FixedSizeBinary(PRECISION_BYTES), false))
            .collect();
        expected_fields.extend(
            ["ts_event", "ts_init"]
                .iter()
                .map(|name| Field::new(*name, DataType::UInt64, false)),
        );

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = Bar::get_schema_map();
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");

        let expected_map: HashMap<String, String> = [
            ("open", fixed_size_binary.as_str()),
            ("high", fixed_size_binary.as_str()),
            ("low", fixed_size_binary.as_str()),
            ("close", fixed_size_binary.as_str()),
            ("volume", fixed_size_binary.as_str()),
            ("ts_event", "UInt64"),
            ("ts_init", "UInt64"),
        ]
        .iter()
        .map(|&(name, data_type)| (name.to_string(), data_type.to_string()))
        .collect();

        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let bar_type = sample_bar_type();
        let metadata = sample_metadata();

        let data = vec![
            Bar::new(
                bar_type,
                Price::from("100.10"),
                Price::from("102.00"),
                Price::from("100.00"),
                Price::from("101.00"),
                Quantity::from(1100),
                1.into(),
                3.into(),
            ),
            Bar::new(
                bar_type,
                Price::from("100.00"),
                Price::from("100.10"),
                Price::from("100.00"),
                Price::from("100.10"),
                Quantity::from(1110),
                2.into(),
                4.into(),
            ),
        ];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let open_values = fixed_binary_at(&record_batch, 0);
        let high_values = fixed_binary_at(&record_batch, 1);
        let low_values = fixed_binary_at(&record_batch, 2);
        let close_values = fixed_binary_at(&record_batch, 3);
        let volume_values = fixed_binary_at(&record_batch, 4);
        let ts_event_values = uint64_at(&record_batch, 5);
        let ts_init_values = uint64_at(&record_batch, 6);

        assert_eq!(record_batch.columns().len(), 7);

        assert_eq!(open_values.len(), 2);
        assert_eq!(
            get_raw_price(open_values.value(0)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(open_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(high_values.len(), 2);
        assert_eq!(
            get_raw_price(high_values.value(0)),
            (102.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(high_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(low_values.len(), 2);
        assert_eq!(
            get_raw_price(low_values.value(0)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(low_values.value(1)),
            (100.00 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(close_values.len(), 2);
        assert_eq!(
            get_raw_price(close_values.value(0)),
            (101.00 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_values.value(1)),
            (100.10 * FIXED_SCALAR) as PriceRaw
        );

        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            (1100.0 * FIXED_SCALAR) as QuantityRaw
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            (1110.0 * FIXED_SCALAR) as QuantityRaw
        );

        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);

        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let metadata = sample_metadata();

        let open = fixed_column(&[
            ((100.10 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            ((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let high = fixed_column(&[
            ((102.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            ((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let low = fixed_column(&[
            ((100.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            ((10.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let close = fixed_column(&[
            ((101.00 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
            ((10.01 * FIXED_SCALAR) as PriceRaw).to_le_bytes(),
        ]);
        let volume = fixed_column(&[
            ((11.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
            ((10.0 * FIXED_SCALAR) as QuantityRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = bar_record_batch(
            &metadata,
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        );

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);
    }

    #[rstest]
    fn test_decode_batch_invalid_price_returns_error() {
        let metadata = sample_metadata();
        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = single_row_batch(&metadata, invalid_price);

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("open") && err.to_string().contains("row 0"),
            "Expected open error at row 0, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_bar_type_returns_error() {
        let mut metadata = sample_metadata();
        let valid_price = (100.00 * FIXED_SCALAR) as PriceRaw;

        // The batch is built with complete metadata; only the map handed to the decoder
        // is missing the bar type.
        let record_batch = single_row_batch(&metadata, valid_price);
        metadata.remove(KEY_BAR_TYPE);

        let result = Bar::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("bar_type"),
            "Expected missing bar_type error, got: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let bar_type = sample_bar_type();
        let metadata = sample_metadata();

        let original = vec![
            Bar::new(
                bar_type,
                Price::from("100.10"),
                Price::from("102.00"),
                Price::from("100.00"),
                Price::from("101.00"),
                Quantity::from(1100),
                1_000_000_000.into(),
                1_000_000_001.into(),
            ),
            Bar::new(
                bar_type,
                Price::from("101.00"),
                Price::from("103.00"),
                Price::from("100.50"),
                Price::from("102.50"),
                Quantity::from(2200),
                2_000_000_000.into(),
                2_000_000_001.into(),
            ),
        ];

        let record_batch = Bar::encode_batch(&metadata, &original).unwrap();
        let decoded = Bar::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (expected, actual) in original.iter().zip(&decoded) {
            assert_eq!(actual.bar_type, expected.bar_type);
            assert_eq!(actual.open, expected.open);
            assert_eq!(actual.high, expected.high);
            assert_eq!(actual.low, expected.low);
            assert_eq!(actual.close, expected.close);
            assert_eq!(actual.volume, expected.volume);
            assert_eq!(actual.ts_event, expected.ts_event);
            assert_eq!(actual.ts_init, expected.ts_init);
        }
    }
}