1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::close::InstrumentClose,
26 enums::{FromU8, InstrumentCloseType},
27 identifiers::InstrumentId,
28 types::fixed::PRECISION_BYTES,
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
33 extract_column, validate_precision_bytes,
34};
35use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
36
37impl ArrowSchemaProvider for InstrumentClose {
38 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39 let fields = vec![
40 Field::new(
41 "close_price",
42 DataType::FixedSizeBinary(PRECISION_BYTES),
43 false,
44 ),
45 Field::new("close_type", DataType::UInt8, false),
46 Field::new("ts_event", DataType::UInt64, false),
47 Field::new("ts_init", DataType::UInt64, false),
48 ];
49
50 match metadata {
51 Some(metadata) => Schema::new_with_metadata(fields, metadata),
52 None => Schema::new(fields),
53 }
54 }
55}
56
57fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
58 let instrument_id_str = metadata
59 .get(KEY_INSTRUMENT_ID)
60 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
61 let instrument_id = InstrumentId::from_str(instrument_id_str)
62 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
63
64 let price_precision = metadata
65 .get(KEY_PRICE_PRECISION)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
67 .parse::<u8>()
68 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
69
70 Ok((instrument_id, price_precision))
71}
72
73impl EncodeToRecordBatch for InstrumentClose {
74 fn encode_batch(
75 metadata: &HashMap<String, String>,
76 data: &[Self],
77 ) -> Result<RecordBatch, ArrowError> {
78 let mut close_price_builder =
79 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
80 let mut close_type_builder = UInt8Array::builder(data.len());
81 let mut ts_event_builder = UInt64Array::builder(data.len());
82 let mut ts_init_builder = UInt64Array::builder(data.len());
83
84 for item in data {
85 close_price_builder
86 .append_value(item.close_price.raw.to_le_bytes())
87 .unwrap();
88 close_type_builder.append_value(item.close_type as u8);
89 ts_event_builder.append_value(item.ts_event.as_u64());
90 ts_init_builder.append_value(item.ts_init.as_u64());
91 }
92
93 RecordBatch::try_new(
94 Self::get_schema(Some(metadata.clone())).into(),
95 vec![
96 Arc::new(close_price_builder.finish()),
97 Arc::new(close_type_builder.finish()),
98 Arc::new(ts_event_builder.finish()),
99 Arc::new(ts_init_builder.finish()),
100 ],
101 )
102 }
103
104 fn metadata(&self) -> HashMap<String, String> {
105 Self::get_metadata(&self.instrument_id, self.close_price.precision)
106 }
107}
108
109impl DecodeFromRecordBatch for InstrumentClose {
110 fn decode_batch(
111 metadata: &HashMap<String, String>,
112 record_batch: RecordBatch,
113 ) -> Result<Vec<Self>, EncodingError> {
114 let (instrument_id, price_precision) = parse_metadata(metadata)?;
115 let cols = record_batch.columns();
116
117 let close_price_values = extract_column::<FixedSizeBinaryArray>(
118 cols,
119 "close_price",
120 0,
121 DataType::FixedSizeBinary(PRECISION_BYTES),
122 )?;
123 let close_type_values =
124 extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
125 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
126 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
127
128 validate_precision_bytes(close_price_values, "close_price")?;
129
130 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
131 .map(|row| {
132 let close_price = decode_price(
133 close_price_values.value(row),
134 price_precision,
135 "close_price",
136 row,
137 )?;
138 let close_type_value = close_type_values.value(row);
139 let close_type =
140 InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
141 EncodingError::ParseError(
142 stringify!(InstrumentCloseType),
143 format!("Invalid enum value, was {close_type_value}"),
144 )
145 })?;
146 Ok(Self {
147 instrument_id,
148 close_price,
149 close_type,
150 ts_event: ts_event_values.value(row).into(),
151 ts_init: ts_init_values.value(row).into(),
152 })
153 })
154 .collect();
155
156 result
157 }
158}
159
160impl DecodeDataFromRecordBatch for InstrumentClose {
161 fn decode_data_batch(
162 metadata: &HashMap<String, String>,
163 record_batch: RecordBatch,
164 ) -> Result<Vec<Data>, EncodingError> {
165 let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
166 Ok(items.into_iter().map(Data::from).collect())
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Returns the `AAPL.XNAS` instrument ID and batch metadata carrying that ID with a price
    /// precision of 2.
    fn aapl_metadata() -> (InstrumentId, HashMap<String, String>) {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        (instrument_id, metadata)
    }

    /// Converts a decimal price to its raw fixed-point value using `FIXED_SCALAR`.
    fn to_raw(price: f64) -> PriceRaw {
        (price * FIXED_SCALAR) as PriceRaw
    }

    /// Builds an `InstrumentClose` for `instrument_id` from a price string and raw timestamps.
    fn make_close(
        instrument_id: InstrumentId,
        price: &str,
        close_type: InstrumentCloseType,
        ts_event: u64,
        ts_init: u64,
    ) -> InstrumentClose {
        InstrumentClose {
            instrument_id,
            close_price: Price::from(price),
            close_type,
            ts_event: ts_event.into(),
            ts_init: ts_init.into(),
        }
    }

    /// Assembles a record batch with the `InstrumentClose` schema from prepared columns.
    fn build_batch(
        metadata: &HashMap<String, String>,
        close_price: FixedSizeBinaryArray,
        close_type: UInt8Array,
        ts_event: UInt64Array,
        ts_init: UInt64Array,
    ) -> RecordBatch {
        RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap()
    }

    /// Builds a one-row batch with the given raw price and close-type byte,
    /// `ts_event = 1` and `ts_init = 2`.
    fn single_row_batch(
        metadata: &HashMap<String, String>,
        raw_price: PriceRaw,
        close_type: u8,
    ) -> RecordBatch {
        build_batch(
            metadata,
            FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]),
            UInt8Array::from(vec![close_type]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        )
    }

    #[rstest]
    fn test_get_schema() {
        let (_, metadata) = aapl_metadata();
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_schema = Schema::new_with_metadata(
            vec![
                Field::new("close_price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
                Field::new("close_type", DataType::UInt8, false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata,
        );
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let expected_map: HashMap<String, String> = [
            ("close_price", format!("FixedSizeBinary({PRECISION_BYTES})")),
            ("close_type", "UInt8".to_string()),
            ("ts_event", "UInt64".to_string()),
            ("ts_init", "UInt64".to_string()),
        ]
        .into_iter()
        .map(|(name, data_type)| (name.to_string(), data_type))
        .collect();

        assert_eq!(InstrumentClose::get_schema_map(), expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let (instrument_id, metadata) = aapl_metadata();
        let data = vec![
            make_close(instrument_id, "150.50", InstrumentCloseType::EndOfSession, 1, 3),
            make_close(instrument_id, "151.25", InstrumentCloseType::ContractExpired, 2, 4),
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 4);

        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(close_price_values.len(), 2);
        assert_eq!(get_raw_price(close_price_values.value(0)), to_raw(150.50));
        assert_eq!(get_raw_price(close_price_values.value(1)), to_raw(151.25));

        assert_eq!(close_type_values.len(), 2);
        assert_eq!(close_type_values.value(0), InstrumentCloseType::EndOfSession as u8);
        assert_eq!(close_type_values.value(1), InstrumentCloseType::ContractExpired as u8);

        assert_eq!(ts_event_values.len(), 2);
        assert_eq!((ts_event_values.value(0), ts_event_values.value(1)), (1, 2));

        assert_eq!(ts_init_values.len(), 2);
        assert_eq!((ts_init_values.value(0), ts_init_values.value(1)), (3, 4));
    }

    #[rstest]
    fn test_decode_batch() {
        let (instrument_id, metadata) = aapl_metadata();
        let raw_price1 = to_raw(150.50);
        let raw_price2 = to_raw(151.25);

        let close_price =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);
        let record_batch = build_batch(
            &metadata,
            close_price,
            UInt8Array::from(vec![
                InstrumentCloseType::EndOfSession as u8,
                InstrumentCloseType::ContractExpired as u8,
            ]),
            UInt64Array::from(vec![1, 2]),
            UInt64Array::from(vec![3, 4]),
        );

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        let expected = [
            (raw_price1, InstrumentCloseType::EndOfSession, 1, 3),
            (raw_price2, InstrumentCloseType::ContractExpired, 2, 4),
        ];
        for (close, (raw_price, close_type, ts_event, ts_init)) in decoded_data.iter().zip(expected) {
            assert_eq!(close.instrument_id, instrument_id);
            assert_eq!(close.close_price, Price::from_raw(raw_price, 2));
            assert_eq!(close.close_type, close_type);
            assert_eq!(close.ts_event.as_u64(), ts_event);
            assert_eq!(close.ts_init.as_u64(), ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch_invalid_close_price_returns_error() {
        let (_, metadata) = aapl_metadata();
        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = single_row_batch(
            &metadata,
            invalid_price,
            InstrumentCloseType::EndOfSession as u8,
        );

        let err = InstrumentClose::decode_batch(&metadata, record_batch).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("close_price") && message.contains("row 0"),
            "Expected close_price error at row 0, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_invalid_close_type_returns_error() {
        let (_, metadata) = aapl_metadata();
        let record_batch = single_row_batch(&metadata, to_raw(150.50), 99);

        let err = InstrumentClose::decode_batch(&metadata, record_batch).unwrap_err();
        assert!(
            err.to_string().contains("InstrumentCloseType"),
            "Expected InstrumentCloseType error, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let (_, mut metadata) = aapl_metadata();
        let record_batch = single_row_batch(
            &metadata,
            to_raw(150.50),
            InstrumentCloseType::EndOfSession as u8,
        );

        // The batch was built with full metadata; only the decode-side metadata loses the key.
        metadata.remove(KEY_INSTRUMENT_ID);

        let err = InstrumentClose::decode_batch(&metadata, record_batch).unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, was: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let (instrument_id, metadata) = aapl_metadata();
        let original = vec![
            make_close(
                instrument_id,
                "150.50",
                InstrumentCloseType::EndOfSession,
                1_000_000_000,
                1_000_000_001,
            ),
            make_close(
                instrument_id,
                "151.25",
                InstrumentCloseType::ContractExpired,
                2_000_000_000,
                2_000_000_001,
            ),
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &original).unwrap();
        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (orig, dec) in original.iter().zip(&decoded) {
            assert_eq!(dec.instrument_id, orig.instrument_id);
            assert_eq!(dec.close_price, orig.close_price);
            assert_eq!(dec.close_type, orig.close_type);
            assert_eq!(dec.ts_event, orig.ts_event);
            assert_eq!(dec.ts_init, orig.ts_init);
        }
    }
}