1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::close::InstrumentClose,
26 enums::{FromU8, InstrumentCloseType},
27 identifiers::InstrumentId,
28 types::fixed::PRECISION_BYTES,
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
33 extract_column,
34};
35use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
36
37impl ArrowSchemaProvider for InstrumentClose {
38 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39 let fields = vec![
40 Field::new(
41 "close_price",
42 DataType::FixedSizeBinary(PRECISION_BYTES),
43 false,
44 ),
45 Field::new("close_type", DataType::UInt8, false),
46 Field::new("ts_event", DataType::UInt64, false),
47 Field::new("ts_init", DataType::UInt64, false),
48 ];
49
50 match metadata {
51 Some(metadata) => Schema::new_with_metadata(fields, metadata),
52 None => Schema::new(fields),
53 }
54 }
55}
56
57fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
58 let instrument_id_str = metadata
59 .get(KEY_INSTRUMENT_ID)
60 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
61 let instrument_id = InstrumentId::from_str(instrument_id_str)
62 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
63
64 let price_precision = metadata
65 .get(KEY_PRICE_PRECISION)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
67 .parse::<u8>()
68 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
69
70 Ok((instrument_id, price_precision))
71}
72
73impl EncodeToRecordBatch for InstrumentClose {
74 fn encode_batch(
75 metadata: &HashMap<String, String>,
76 data: &[Self],
77 ) -> Result<RecordBatch, ArrowError> {
78 let mut close_price_builder =
79 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
80 let mut close_type_builder = UInt8Array::builder(data.len());
81 let mut ts_event_builder = UInt64Array::builder(data.len());
82 let mut ts_init_builder = UInt64Array::builder(data.len());
83
84 for item in data {
85 close_price_builder
86 .append_value(item.close_price.raw.to_le_bytes())
87 .unwrap();
88 close_type_builder.append_value(item.close_type as u8);
89 ts_event_builder.append_value(item.ts_event.as_u64());
90 ts_init_builder.append_value(item.ts_init.as_u64());
91 }
92
93 RecordBatch::try_new(
94 Self::get_schema(Some(metadata.clone())).into(),
95 vec![
96 Arc::new(close_price_builder.finish()),
97 Arc::new(close_type_builder.finish()),
98 Arc::new(ts_event_builder.finish()),
99 Arc::new(ts_init_builder.finish()),
100 ],
101 )
102 }
103
104 fn metadata(&self) -> HashMap<String, String> {
105 Self::get_metadata(&self.instrument_id, self.close_price.precision)
106 }
107}
108
109impl DecodeFromRecordBatch for InstrumentClose {
110 fn decode_batch(
111 metadata: &HashMap<String, String>,
112 record_batch: RecordBatch,
113 ) -> Result<Vec<Self>, EncodingError> {
114 let (instrument_id, price_precision) = parse_metadata(metadata)?;
115 let cols = record_batch.columns();
116
117 let close_price_values = extract_column::<FixedSizeBinaryArray>(
118 cols,
119 "close_price",
120 0,
121 DataType::FixedSizeBinary(PRECISION_BYTES),
122 )?;
123 let close_type_values =
124 extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
125 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
126 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
127
128 if close_price_values.value_length() != PRECISION_BYTES {
130 return Err(EncodingError::ParseError(
131 "close_price",
132 format!(
133 "Invalid value length: expected {PRECISION_BYTES}, found {}",
134 close_price_values.value_length()
135 ),
136 ));
137 }
138
139 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
140 .map(|row| {
141 let close_price = decode_price(
142 close_price_values.value(row),
143 price_precision,
144 "close_price",
145 row,
146 )?;
147 let close_type_value = close_type_values.value(row);
148 let close_type =
149 InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
150 EncodingError::ParseError(
151 stringify!(InstrumentCloseType),
152 format!("Invalid enum value, was {close_type_value}"),
153 )
154 })?;
155 Ok(Self {
156 instrument_id,
157 close_price,
158 close_type,
159 ts_event: ts_event_values.value(row).into(),
160 ts_init: ts_init_values.value(row).into(),
161 })
162 })
163 .collect();
164
165 result
166 }
167}
168
169impl DecodeDataFromRecordBatch for InstrumentClose {
170 fn decode_data_batch(
171 metadata: &HashMap<String, String>,
172 record_batch: RecordBatch,
173 ) -> Result<Vec<Data>, EncodingError> {
174 let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
175 Ok(items.into_iter().map(Data::from).collect())
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Returns the `AAPL.XNAS` instrument ID with metadata declaring price precision 2.
    fn test_metadata() -> (InstrumentId, HashMap<String, String>) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        (instrument_id, metadata)
    }

    /// Converts a decimal price into its fixed-point raw representation.
    fn raw(value: f64) -> PriceRaw {
        (value * FIXED_SCALAR) as PriceRaw
    }

    /// Assembles a record batch with the `InstrumentClose` schema from raw column values.
    fn build_batch(
        metadata: &HashMap<String, String>,
        prices: &[PriceRaw],
        close_types: Vec<u8>,
        ts_events: Vec<u64>,
        ts_inits: Vec<u64>,
    ) -> RecordBatch {
        let price_bytes: Vec<_> = prices.iter().map(|price| price.to_le_bytes()).collect();
        RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(FixedSizeBinaryArray::from(
                    price_bytes.iter().collect::<Vec<_>>(),
                )),
                Arc::new(UInt8Array::from(close_types)),
                Arc::new(UInt64Array::from(ts_events)),
                Arc::new(UInt64Array::from(ts_inits)),
            ],
        )
        .unwrap()
    }

    #[rstest]
    fn test_get_schema() {
        let (_, metadata) = test_metadata();
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_schema = Schema::new_with_metadata(
            vec![
                Field::new(
                    "close_price",
                    DataType::FixedSizeBinary(PRECISION_BYTES),
                    false,
                ),
                Field::new("close_type", DataType::UInt8, false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata,
        );
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let expected_map: HashMap<String, String> = [
            ("close_price", format!("FixedSizeBinary({PRECISION_BYTES})")),
            ("close_type", "UInt8".to_string()),
            ("ts_event", "UInt64".to_string()),
            ("ts_init", "UInt64".to_string()),
        ]
        .into_iter()
        .map(|(column, type_name)| (column.to_string(), type_name))
        .collect();

        assert_eq!(InstrumentClose::get_schema_map(), expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let (instrument_id, metadata) = test_metadata();
        let data = vec![
            InstrumentClose {
                instrument_id,
                close_price: Price::from("150.50"),
                close_type: InstrumentCloseType::EndOfSession,
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            InstrumentClose {
                instrument_id,
                close_price: Price::from("151.25"),
                close_type: InstrumentCloseType::ContractExpired,
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 4);

        let prices = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_types = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_events = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_inits = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(prices.len(), 2);
        assert_eq!(get_raw_price(prices.value(0)), raw(150.50));
        assert_eq!(get_raw_price(prices.value(1)), raw(151.25));

        assert_eq!(close_types.len(), 2);
        assert_eq!(close_types.value(0), InstrumentCloseType::EndOfSession as u8);
        assert_eq!(close_types.value(1), InstrumentCloseType::ContractExpired as u8);

        assert_eq!(ts_events.len(), 2);
        assert_eq!(ts_events.value(0), 1);
        assert_eq!(ts_events.value(1), 2);

        assert_eq!(ts_inits.len(), 2);
        assert_eq!(ts_inits.value(0), 3);
        assert_eq!(ts_inits.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let (instrument_id, metadata) = test_metadata();
        let (raw_price1, raw_price2) = (raw(150.50), raw(151.25));
        let record_batch = build_batch(
            &metadata,
            &[raw_price1, raw_price2],
            vec![
                InstrumentCloseType::EndOfSession as u8,
                InstrumentCloseType::ContractExpired as u8,
            ],
            vec![1, 2],
            vec![3, 4],
        );

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        let expectations = [
            (raw_price1, InstrumentCloseType::EndOfSession, 1_u64, 3_u64),
            (raw_price2, InstrumentCloseType::ContractExpired, 2_u64, 4_u64),
        ];
        for (close, (raw_price, close_type, ts_event, ts_init)) in
            decoded_data.iter().zip(expectations)
        {
            assert_eq!(close.instrument_id, instrument_id);
            assert_eq!(close.close_price, Price::from_raw(raw_price, 2));
            assert_eq!(close.close_type, close_type);
            assert_eq!(close.ts_event.as_u64(), ts_event);
            assert_eq!(close.ts_init.as_u64(), ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch_invalid_close_price_returns_error() {
        let (_, metadata) = test_metadata();
        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = build_batch(
            &metadata,
            &[invalid_price],
            vec![InstrumentCloseType::EndOfSession as u8],
            vec![1],
            vec![2],
        );

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("close_price") && message.contains("row 0"),
            "Expected close_price error at row 0, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_invalid_close_type_returns_error() {
        let (_, metadata) = test_metadata();
        let record_batch = build_batch(&metadata, &[raw(150.50)], vec![99], vec![1], vec![2]);

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("InstrumentCloseType"),
            "Expected InstrumentCloseType error, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let (_, mut metadata) = test_metadata();
        let record_batch = build_batch(
            &metadata,
            &[raw(150.50)],
            vec![InstrumentCloseType::EndOfSession as u8],
            vec![1],
            vec![2],
        );

        // Drop the key only after the batch exists, so only decoding sees it missing.
        metadata.remove(KEY_INSTRUMENT_ID);

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, got: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let (instrument_id, metadata) = test_metadata();
        let original = vec![
            InstrumentClose {
                instrument_id,
                close_price: Price::from("150.50"),
                close_type: InstrumentCloseType::EndOfSession,
                ts_event: 1_000_000_000.into(),
                ts_init: 1_000_000_001.into(),
            },
            InstrumentClose {
                instrument_id,
                close_price: Price::from("151.25"),
                close_type: InstrumentCloseType::ContractExpired,
                ts_event: 2_000_000_000.into(),
                ts_init: 2_000_000_001.into(),
            },
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &original).unwrap();
        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (before, after) in original.iter().zip(&decoded) {
            assert_eq!(after.instrument_id, before.instrument_id);
            assert_eq!(after.close_price, before.close_price);
            assert_eq!(after.close_type, before.close_type);
            assert_eq!(after.ts_event, before.ts_event);
            assert_eq!(after.ts_init, before.ts_init);
        }
    }
}