1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Bar, BarType},
26 types::{Price, Quantity, fixed::PRECISION_BYTES},
27};
28
29use super::{
30 DecodeDataFromRecordBatch, EncodingError, KEY_BAR_TYPE, KEY_PRICE_PRECISION,
31 KEY_SIZE_PRECISION, extract_column, get_corrected_raw_price, get_corrected_raw_quantity,
32};
33use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
34
35impl ArrowSchemaProvider for Bar {
36 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
37 let fields = vec![
38 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
39 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
41 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
43 Field::new("ts_event", DataType::UInt64, false),
44 Field::new("ts_init", DataType::UInt64, false),
45 ];
46
47 match metadata {
48 Some(metadata) => Schema::new_with_metadata(fields, metadata),
49 None => Schema::new(fields),
50 }
51 }
52}
53
54fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
55 let bar_type_str = metadata
56 .get(KEY_BAR_TYPE)
57 .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
58 let bar_type = BarType::from_str(bar_type_str)
59 .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
60
61 let price_precision = metadata
62 .get(KEY_PRICE_PRECISION)
63 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
64 .parse::<u8>()
65 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
66
67 let size_precision = metadata
68 .get(KEY_SIZE_PRECISION)
69 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
70 .parse::<u8>()
71 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
72
73 Ok((bar_type, price_precision, size_precision))
74}
75
76impl EncodeToRecordBatch for Bar {
77 fn encode_batch(
78 metadata: &HashMap<String, String>,
79 data: &[Self],
80 ) -> Result<RecordBatch, ArrowError> {
81 let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82 let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83 let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
84 let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85 let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
86 let mut ts_event_builder = UInt64Array::builder(data.len());
87 let mut ts_init_builder = UInt64Array::builder(data.len());
88
89 for bar in data {
90 open_builder
91 .append_value(bar.open.raw.to_le_bytes())
92 .unwrap();
93 high_builder
94 .append_value(bar.high.raw.to_le_bytes())
95 .unwrap();
96 low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
97 close_builder
98 .append_value(bar.close.raw.to_le_bytes())
99 .unwrap();
100 volume_builder
101 .append_value(bar.volume.raw.to_le_bytes())
102 .unwrap();
103 ts_event_builder.append_value(bar.ts_event.as_u64());
104 ts_init_builder.append_value(bar.ts_init.as_u64());
105 }
106
107 let open_array = open_builder.finish();
108 let high_array = high_builder.finish();
109 let low_array = low_builder.finish();
110 let close_array = close_builder.finish();
111 let volume_array = volume_builder.finish();
112 let ts_event_array = ts_event_builder.finish();
113 let ts_init_array = ts_init_builder.finish();
114
115 RecordBatch::try_new(
116 Self::get_schema(Some(metadata.clone())).into(),
117 vec![
118 Arc::new(open_array),
119 Arc::new(high_array),
120 Arc::new(low_array),
121 Arc::new(close_array),
122 Arc::new(volume_array),
123 Arc::new(ts_event_array),
124 Arc::new(ts_init_array),
125 ],
126 )
127 }
128
129 fn metadata(&self) -> HashMap<String, String> {
130 Self::get_metadata(&self.bar_type, self.open.precision, self.volume.precision)
131 }
132}
133
134impl DecodeFromRecordBatch for Bar {
135 fn decode_batch(
136 metadata: &HashMap<String, String>,
137 record_batch: RecordBatch,
138 ) -> Result<Vec<Self>, EncodingError> {
139 let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
140 let cols = record_batch.columns();
141
142 let open_values = extract_column::<FixedSizeBinaryArray>(
143 cols,
144 "open",
145 0,
146 DataType::FixedSizeBinary(PRECISION_BYTES),
147 )?;
148 let high_values = extract_column::<FixedSizeBinaryArray>(
149 cols,
150 "high",
151 1,
152 DataType::FixedSizeBinary(PRECISION_BYTES),
153 )?;
154 let low_values = extract_column::<FixedSizeBinaryArray>(
155 cols,
156 "low",
157 2,
158 DataType::FixedSizeBinary(PRECISION_BYTES),
159 )?;
160 let close_values = extract_column::<FixedSizeBinaryArray>(
161 cols,
162 "close",
163 3,
164 DataType::FixedSizeBinary(PRECISION_BYTES),
165 )?;
166 let volume_values = extract_column::<FixedSizeBinaryArray>(
167 cols,
168 "volume",
169 4,
170 DataType::FixedSizeBinary(PRECISION_BYTES),
171 )?;
172 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 5, DataType::UInt64)?;
173 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 6, DataType::UInt64)?;
174
175 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
176 .map(|i| {
177 let open = Price::from_raw(
179 get_corrected_raw_price(open_values.value(i), price_precision),
180 price_precision,
181 );
182 let high = Price::from_raw(
183 get_corrected_raw_price(high_values.value(i), price_precision),
184 price_precision,
185 );
186 let low = Price::from_raw(
187 get_corrected_raw_price(low_values.value(i), price_precision),
188 price_precision,
189 );
190 let close = Price::from_raw(
191 get_corrected_raw_price(close_values.value(i), price_precision),
192 price_precision,
193 );
194 let volume = Quantity::from_raw(
195 get_corrected_raw_quantity(volume_values.value(i), size_precision),
196 size_precision,
197 );
198 let ts_event = ts_event_values.value(i).into();
199 let ts_init = ts_init_values.value(i).into();
200
201 Ok(Self {
202 bar_type,
203 open,
204 high,
205 low,
206 close,
207 volume,
208 ts_event,
209 ts_init,
210 })
211 })
212 .collect();
213
214 result
215 }
216}
217
218impl DecodeDataFromRecordBatch for Bar {
219 fn decode_data_batch(
220 metadata: &HashMap<String, String>,
221 record_batch: RecordBatch,
222 ) -> Result<Vec<Data>, EncodingError> {
223 let bars: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
224 Ok(bars.into_iter().map(Data::from).collect())
225 }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw, quantity::QuantityRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::{get_raw_price, get_raw_quantity};

    /// Bar type string shared by the tests in this module.
    const BAR_TYPE_STR: &str = "AAPL.XNAS-1-MINUTE-LAST-INTERNAL";

    /// Parses the shared bar type string.
    fn sample_bar_type() -> BarType {
        BarType::from_str(BAR_TYPE_STR).unwrap()
    }

    /// Scales a decimal value by `FIXED_SCALAR` and casts it to a raw price.
    fn scaled_price(value: f64) -> PriceRaw {
        (value * FIXED_SCALAR) as PriceRaw
    }

    /// Scales a decimal value by `FIXED_SCALAR` and casts it to a raw quantity.
    fn scaled_quantity(value: f64) -> QuantityRaw {
        (value * FIXED_SCALAR) as QuantityRaw
    }

    /// Downcasts the column at `index` to the concrete array type `T`.
    fn column<T: 'static>(columns: &[Arc<dyn Array>], index: usize) -> &T {
        columns[index].as_any().downcast_ref::<T>().unwrap()
    }

    /// Builds a fixed-size binary column from equally sized byte rows.
    fn binary_column<const N: usize>(rows: &[[u8; N]]) -> FixedSizeBinaryArray {
        let slices: Vec<&[u8]> = rows.iter().map(|row| &row[..]).collect();
        FixedSizeBinaryArray::from(slices)
    }

    #[rstest]
    fn test_get_schema() {
        let metadata = Bar::get_metadata(&sample_bar_type(), 2, 0);
        let schema = Bar::get_schema(Some(metadata.clone()));

        let fixed_binary = || DataType::FixedSizeBinary(PRECISION_BYTES);
        let expected_fields = vec![
            Field::new("open", fixed_binary(), false),
            Field::new("high", fixed_binary(), false),
            Field::new("low", fixed_binary(), false),
            Field::new("close", fixed_binary(), false),
            Field::new("volume", fixed_binary(), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        assert_eq!(
            schema,
            Schema::new_with_metadata(expected_fields, metadata)
        );
    }

    #[rstest]
    fn test_get_schema_map() {
        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        let expected_map: HashMap<String, String> = [
            ("open", fixed_size_binary.as_str()),
            ("high", fixed_size_binary.as_str()),
            ("low", fixed_size_binary.as_str()),
            ("close", fixed_size_binary.as_str()),
            ("volume", fixed_size_binary.as_str()),
            ("ts_event", "UInt64"),
            ("ts_init", "UInt64"),
        ]
        .iter()
        .map(|&(name, data_type)| (name.to_string(), data_type.to_string()))
        .collect();

        assert_eq!(Bar::get_schema_map(), expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let bar_type = sample_bar_type();
        let metadata = Bar::get_metadata(&bar_type, 2, 0);

        let data = vec![
            Bar::new(
                bar_type,
                Price::from("100.10"),
                Price::from("102.00"),
                Price::from("100.00"),
                Price::from("101.00"),
                Quantity::from(1100),
                1.into(),
                3.into(),
            ),
            Bar::new(
                bar_type,
                Price::from("100.00"),
                Price::from("100.10"),
                Price::from("100.00"),
                Price::from("100.10"),
                Quantity::from(1110),
                2.into(),
                4.into(),
            ),
        ];
        let record_batch = Bar::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let open_values: &FixedSizeBinaryArray = column(columns, 0);
        let high_values: &FixedSizeBinaryArray = column(columns, 1);
        let low_values: &FixedSizeBinaryArray = column(columns, 2);
        let close_values: &FixedSizeBinaryArray = column(columns, 3);
        let volume_values: &FixedSizeBinaryArray = column(columns, 4);
        let ts_event_values: &UInt64Array = column(columns, 5);
        let ts_init_values: &UInt64Array = column(columns, 6);

        assert_eq!(columns.len(), 7);

        // Expected decimal prices per price column, in row order.
        let expected_prices = [
            (open_values, [100.10, 100.00]),
            (high_values, [102.00, 100.10]),
            (low_values, [100.00, 100.00]),
            (close_values, [101.00, 100.10]),
        ];
        for (values, expected) in expected_prices.iter() {
            assert_eq!(values.len(), 2);
            for (row, price) in expected.iter().enumerate() {
                assert_eq!(get_raw_price(values.value(row)), scaled_price(*price));
            }
        }

        assert_eq!(volume_values.len(), 2);
        assert_eq!(
            get_raw_quantity(volume_values.value(0)),
            scaled_quantity(1100.0)
        );
        assert_eq!(
            get_raw_quantity(volume_values.value(1)),
            scaled_quantity(1110.0)
        );

        // Expected timestamps per timestamp column, in row order.
        let expected_timestamps = [(ts_event_values, [1_u64, 2]), (ts_init_values, [3, 4])];
        for (values, expected) in expected_timestamps.iter() {
            assert_eq!(values.len(), 2);
            for (row, timestamp) in expected.iter().enumerate() {
                assert_eq!(values.value(row), *timestamp);
            }
        }
    }

    #[rstest]
    fn test_decode_batch() {
        let metadata = Bar::get_metadata(&sample_bar_type(), 2, 0);

        let price_rows = |first: f64, second: f64| {
            binary_column(&[
                scaled_price(first).to_le_bytes(),
                scaled_price(second).to_le_bytes(),
            ])
        };
        let open = price_rows(100.10, 10.00);
        let high = price_rows(102.00, 10.00);
        let low = price_rows(100.00, 10.00);
        let close = price_rows(101.00, 10.01);
        let volume = binary_column(&[
            scaled_quantity(11.0).to_le_bytes(),
            scaled_quantity(10.0).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            Bar::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(open),
                Arc::new(high),
                Arc::new(low),
                Arc::new(close),
                Arc::new(volume),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = Bar::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);
    }
}