1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Bar, BarType},
26 types::{Price, Quantity, fixed::PRECISION_BYTES},
27};
28
29use super::{
30 DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31 KEY_SIZE_PRECISION, extract_column, get_raw_quantity,
32};
33use crate::arrow::{
34 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
35};
36
37impl ArrowSchemaProvider for Bar {
38 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39 let fields = vec![
40 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
44 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
45 Field::new("ts_event", DataType::UInt64, false),
46 Field::new("ts_init", DataType::UInt64, false),
47 ];
48
49 match metadata {
50 Some(metadata) => Schema::new_with_metadata(fields, metadata),
51 None => Schema::new(fields),
52 }
53 }
54}
55
56fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
57 let bar_type_str = metadata
58 .get(KEY_BAR_TYPE)
59 .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
60 let bar_type = BarType::from_str(bar_type_str)
61 .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
62
63 let price_precision = metadata
64 .get(KEY_PRICE_PRECISION)
65 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
66 .parse::<u8>()
67 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
68
69 let size_precision = metadata
70 .get(KEY_SIZE_PRECISION)
71 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
72 .parse::<u8>()
73 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
74
75 Ok((bar_type, price_precision, size_precision))
76}
77
78impl EncodeToRecordBatch for Bar {
79 fn encode_batch(
80 metadata: &HashMap<String, String>,
81 data: &[Self],
82 ) -> Result<RecordBatch, ArrowError> {
83 let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84 let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85 let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86 let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
87 let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
88 let mut ts_event_builder = UInt64Array::builder(data.len());
89 let mut ts_init_builder = UInt64Array::builder(data.len());
90
91 for bar in data {
92 open_builder
93 .append_value(bar.open.raw.to_le_bytes())
94 .unwrap();
95 high_builder
96 .append_value(bar.high.raw.to_le_bytes())
97 .unwrap();
98 low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
99 close_builder
100 .append_value(bar.close.raw.to_le_bytes())
101 .unwrap();
102 volume_builder
103 .append_value(bar.volume.raw.to_le_bytes())
104 .unwrap();
105 ts_event_builder.append_value(bar.ts_event.as_u64());
106 ts_init_builder.append_value(bar.ts_init.as_u64());
107 }
108
109 let open_array = open_builder.finish();
110 let high_array = high_builder.finish();
111 let low_array = low_builder.finish();
112 let close_array = close_builder.finish();
113 let volume_array = volume_builder.finish();
114 let ts_event_array = ts_event_builder.finish();
115 let ts_init_array = ts_init_builder.finish();
116
117 RecordBatch::try_new(
118 Self::get_schema(Some(metadata.clone())).into(),
119 vec![
120 Arc::new(open_array),
121 Arc::new(high_array),
122 Arc::new(low_array),
123 Arc::new(close_array),
124 Arc::new(volume_array),
125 Arc::new(ts_event_array),
126 Arc::new(ts_init_array),
127 ],
128 )
129 }
130
131 fn metadata(&self) -> HashMap<String, String> {
132 Bar::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
133 }
134}
135
136impl DecodeFromRecordBatch for Bar {
137 fn decode_batch(
138 metadata: &HashMap<String, String>,
139 record_batch: RecordBatch,
140 ) -> Result<Vec<Self>, EncodingError> {
141 let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
142 let cols = record_batch.columns();
143
144 let open_values = extract_column::<FixedSizeBinaryArray>(
145 cols,
146 "open",
147 0,
148 DataType::FixedSizeBinary(PRECISION_BYTES),
149 )?;
150 let high_values = extract_column::<FixedSizeBinaryArray>(
151 cols,
152 "high",
153 1,
154 DataType::FixedSizeBinary(PRECISION_BYTES),
155 )?;
156 let low_values = extract_column::<FixedSizeBinaryArray>(
157 cols,
158 "low",
159 2,
160 DataType::FixedSizeBinary(PRECISION_BYTES),
161 )?;
162 let close_values = extract_column::<FixedSizeBinaryArray>(
163 cols,
164 "close",
165 3,
166 DataType::FixedSizeBinary(PRECISION_BYTES),
167 )?;
168 let volume_values = extract_column::<FixedSizeBinaryArray>(
169 cols,
170 "volume",
171 4,
172 DataType::FixedSizeBinary(PRECISION_BYTES),
173 )?;
174 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
175 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
176
177 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
178 .map(|i| {
179 let open = Price::from_raw(get_raw_price(open_values.value(i)), price_precision);
180 let high = Price::from_raw(get_raw_price(high_values.value(i)), price_precision);
181 let low = Price::from_raw(get_raw_price(low_values.value(i)), price_precision);
182 let close = Price::from_raw(get_raw_price(close_values.value(i)), price_precision);
183 let volume =
184 Quantity::from_raw(get_raw_quantity(volume_values.value(i)), size_precision);
185 let ts_event = ts_event_values.value(i).into();
186 let ts_init = ts_init_values.value(i).into();
187
188 Ok(Self {
189 bar_type,
190 open,
191 high,
192 low,
193 close,
194 volume,
195 ts_event,
196 ts_init,
197 })
198 })
199 .collect();
200
201 result
202 }
203}
204
205impl DecodeDataFromRecordBatch for Bar {
206 fn decode_data_batch(
207 metadata: &HashMap<String, String>,
208 record_batch: RecordBatch,
209 ) -> Result<Vec<Data>, EncodingError> {
210 let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
211 Ok(bars.into_iter().map(Data::from).collect())
212 }
213}
214
215#[cfg(test)]
219mod tests {
220 use std::sync::Arc;
221
222 use arrow::{array::Array, record_batch::RecordBatch};
223 use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
224 use rstest::rstest;
225
226 use super::*;
227 use crate::arrow::get_raw_price;
228
229 #[rstest]
230 fn test_get_schema() {
231 let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
232 let metadata = Bar::get_metadata(&bar_type, 2, 0);
233 let schema = Bar::get_schema(Some(metadata.clone()));
234 let expected_fields = vec![
235 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
236 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
237 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
238 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
239 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
240 Field::new("ts_event", DataType::UInt64, false),
241 Field::new("ts_init", DataType::UInt64, false),
242 ];
243 let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
244 assert_eq!(schema, expected_schema);
245 }
246
247 #[rstest]
248 fn test_get_schema_map() {
249 let schema_map = Bar::get_schema_map();
250 let mut expected_map = HashMap::new();
251 let fixed_size_binary = format!("FixedSizeBinary({})", PRECISION_BYTES);
252 expected_map.insert("open".to_string(), fixed_size_binary.clone());
253 expected_map.insert("high".to_string(), fixed_size_binary.clone());
254 expected_map.insert("low".to_string(), fixed_size_binary.clone());
255 expected_map.insert("close".to_string(), fixed_size_binary.clone());
256 expected_map.insert("volume".to_string(), fixed_size_binary.clone());
257 expected_map.insert("ts_event".to_string(), "UInt64".to_string());
258 expected_map.insert("ts_init".to_string(), "UInt64".to_string());
259 assert_eq!(schema_map, expected_map);
260 }
261
262 #[rstest]
263 fn test_encode_batch() {
264 let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
265 let metadata = Bar::get_metadata(&bar_type, 2, 0);
266
267 let bar1 = Bar::new(
268 bar_type,
269 Price::from("100.10"),
270 Price::from("102.00"),
271 Price::from("100.00"),
272 Price::from("101.00"),
273 Quantity::from(1100),
274 1.into(),
275 3.into(),
276 );
277 let bar2 = Bar::new(
278 bar_type,
279 Price::from("100.00"),
280 Price::from("100.10"),
281 Price::from("100.00"),
282 Price::from("100.10"),
283 Quantity::from(1110),
284 2.into(),
285 4.into(),
286 );
287
288 let data = vec![bar1, bar2];
289 let record_batch = Bar::encode_batch(&metadata, &data).unwrap();
290
291 let columns = record_batch.columns();
292 let open_values = columns[0]
293 .as_any()
294 .downcast_ref::<FixedSizeBinaryArray>()
295 .unwrap();
296 let high_values = columns[1]
297 .as_any()
298 .downcast_ref::<FixedSizeBinaryArray>()
299 .unwrap();
300 let low_values = columns[2]
301 .as_any()
302 .downcast_ref::<FixedSizeBinaryArray>()
303 .unwrap();
304 let close_values = columns[3]
305 .as_any()
306 .downcast_ref::<FixedSizeBinaryArray>()
307 .unwrap();
308 let volume_values = columns[4]
309 .as_any()
310 .downcast_ref::<FixedSizeBinaryArray>()
311 .unwrap();
312 let ts_event_values = columns[5].as_any().downcast_ref::<UInt64Array>().unwrap();
313 let ts_init_values = columns[6].as_any().downcast_ref::<UInt64Array>().unwrap();
314
315 assert_eq!(columns.len(), 7);
316 assert_eq!(open_values.len(), 2);
317 assert_eq!(
318 get_raw_price(open_values.value(0)),
319 (100.10 * FIXED_SCALAR) as PriceRaw
320 );
321 assert_eq!(
322 get_raw_price(open_values.value(1)),
323 (100.00 * FIXED_SCALAR) as PriceRaw
324 );
325 assert_eq!(high_values.len(), 2);
326 assert_eq!(
327 get_raw_price(high_values.value(0)),
328 (102.00 * FIXED_SCALAR) as PriceRaw
329 );
330 assert_eq!(
331 get_raw_price(high_values.value(1)),
332 (100.10 * FIXED_SCALAR) as PriceRaw
333 );
334 assert_eq!(low_values.len(), 2);
335 assert_eq!(
336 get_raw_price(low_values.value(0)),
337 (100.00 * FIXED_SCALAR) as PriceRaw
338 );
339 assert_eq!(
340 get_raw_price(low_values.value(1)),
341 (100.00 * FIXED_SCALAR) as PriceRaw
342 );
343 assert_eq!(close_values.len(), 2);
344 assert_eq!(
345 get_raw_price(close_values.value(0)),
346 (101.00 * FIXED_SCALAR) as PriceRaw
347 );
348 assert_eq!(
349 get_raw_price(close_values.value(1)),
350 (100.10 * FIXED_SCALAR) as PriceRaw
351 );
352 assert_eq!(volume_values.len(), 2);
353 assert_eq!(
354 get_raw_quantity(volume_values.value(0)),
355 (1100.0 * FIXED_SCALAR) as QuantityRaw
356 );
357 assert_eq!(
358 get_raw_quantity(volume_values.value(1)),
359 (1110.0 * FIXED_SCALAR) as QuantityRaw
360 );
361 assert_eq!(ts_event_values.len(), 2);
362 assert_eq!(ts_event_values.value(0), 1);
363 assert_eq!(ts_event_values.value(1), 2);
364 assert_eq!(ts_init_values.len(), 2);
365 assert_eq!(ts_init_values.value(0), 3);
366 assert_eq!(ts_init_values.value(1), 4);
367 }
368
369 #[rstest]
370 fn test_decode_batch() {
371 use nautilus_model::types::{price::PriceRaw, quantity::QuantityRaw};
372
373 let bar_type = BarType::from_str("AAPL.XNAS-1-MINUTE-LAST-INTERNAL").unwrap();
374 let metadata = Bar::get_metadata(&bar_type, 2, 0);
375
376 let open = FixedSizeBinaryArray::from(vec![
377 &(100_100_000_000 as PriceRaw).to_le_bytes(),
378 &(10_000_000_000 as PriceRaw).to_le_bytes(),
379 ]);
380 let high = FixedSizeBinaryArray::from(vec![
381 &(102_000_000_000 as PriceRaw).to_le_bytes(),
382 &(10_000_000_000 as PriceRaw).to_le_bytes(),
383 ]);
384 let low = FixedSizeBinaryArray::from(vec![
385 &(100_000_000_000 as PriceRaw).to_le_bytes(),
386 &(10_000_000_000 as PriceRaw).to_le_bytes(),
387 ]);
388 let close = FixedSizeBinaryArray::from(vec![
389 &(101_000_000_000 as PriceRaw).to_le_bytes(),
390 &(10_010_000_000 as PriceRaw).to_le_bytes(),
391 ]);
392 let volume = FixedSizeBinaryArray::from(vec![
393 &(11_000_000_000 as QuantityRaw).to_le_bytes(),
394 &(10_000_000_000 as QuantityRaw).to_le_bytes(),
395 ]);
396 let ts_event = UInt64Array::from(vec![1, 2]);
397 let ts_init = UInt64Array::from(vec![3, 4]);
398
399 let record_batch = RecordBatch::try_new(
400 Bar::get_schema(Some(metadata.clone())).into(),
401 vec![
402 Arc::new(open),
403 Arc::new(high),
404 Arc::new(low),
405 Arc::new(close),
406 Arc::new(volume),
407 Arc::new(ts_event),
408 Arc::new(ts_init),
409 ],
410 )
411 .unwrap();
412
413 let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
414 assert_eq!(decoded_data.len(), 2);
415 }
416}