1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::close::InstrumentClose,
26 enums::{FromU8, InstrumentCloseType},
27 identifiers::InstrumentId,
28 types::{Price, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 extract_column, get_corrected_raw_price,
34};
35use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
36
37impl ArrowSchemaProvider for InstrumentClose {
38 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39 let fields = vec![
40 Field::new(
41 "close_price",
42 DataType::FixedSizeBinary(PRECISION_BYTES),
43 false,
44 ),
45 Field::new("close_type", DataType::UInt8, false),
46 Field::new("ts_event", DataType::UInt64, false),
47 Field::new("ts_init", DataType::UInt64, false),
48 ];
49
50 match metadata {
51 Some(metadata) => Schema::new_with_metadata(fields, metadata),
52 None => Schema::new(fields),
53 }
54 }
55}
56
57fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
58 let instrument_id_str = metadata
59 .get(KEY_INSTRUMENT_ID)
60 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
61 let instrument_id = InstrumentId::from_str(instrument_id_str)
62 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
63
64 let price_precision = metadata
65 .get(KEY_PRICE_PRECISION)
66 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
67 .parse::<u8>()
68 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
69
70 Ok((instrument_id, price_precision))
71}
72
73impl EncodeToRecordBatch for InstrumentClose {
74 fn encode_batch(
75 metadata: &HashMap<String, String>,
76 data: &[Self],
77 ) -> Result<RecordBatch, ArrowError> {
78 let mut close_price_builder =
79 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
80 let mut close_type_builder = UInt8Array::builder(data.len());
81 let mut ts_event_builder = UInt64Array::builder(data.len());
82 let mut ts_init_builder = UInt64Array::builder(data.len());
83
84 for item in data {
85 close_price_builder
86 .append_value(item.close_price.raw.to_le_bytes())
87 .unwrap();
88 close_type_builder.append_value(item.close_type as u8);
89 ts_event_builder.append_value(item.ts_event.as_u64());
90 ts_init_builder.append_value(item.ts_init.as_u64());
91 }
92
93 RecordBatch::try_new(
94 Self::get_schema(Some(metadata.clone())).into(),
95 vec![
96 Arc::new(close_price_builder.finish()),
97 Arc::new(close_type_builder.finish()),
98 Arc::new(ts_event_builder.finish()),
99 Arc::new(ts_init_builder.finish()),
100 ],
101 )
102 }
103
104 fn metadata(&self) -> HashMap<String, String> {
105 Self::get_metadata(&self.instrument_id, self.close_price.precision)
106 }
107}
108
109impl DecodeFromRecordBatch for InstrumentClose {
110 fn decode_batch(
111 metadata: &HashMap<String, String>,
112 record_batch: RecordBatch,
113 ) -> Result<Vec<Self>, EncodingError> {
114 let (instrument_id, price_precision) = parse_metadata(metadata)?;
115 let cols = record_batch.columns();
116
117 let close_price_values = extract_column::<FixedSizeBinaryArray>(
118 cols,
119 "close_price",
120 0,
121 DataType::FixedSizeBinary(PRECISION_BYTES),
122 )?;
123 let close_type_values =
124 extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
125 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
126 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
127
128 if close_price_values.value_length() != PRECISION_BYTES {
130 return Err(EncodingError::ParseError(
131 "close_price",
132 format!(
133 "Invalid value length: expected {PRECISION_BYTES}, found {}",
134 close_price_values.value_length()
135 ),
136 ));
137 }
138
139 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
140 .map(|row| {
141 let close_type_value = close_type_values.value(row);
142 let close_type =
143 InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
144 EncodingError::ParseError(
145 stringify!(InstrumentCloseType),
146 format!("Invalid enum value, was {close_type_value}"),
147 )
148 })?;
149
150 Ok(Self {
152 instrument_id,
153 close_price: Price::from_raw(
154 get_corrected_raw_price(close_price_values.value(row), price_precision),
155 price_precision,
156 ),
157 close_type,
158 ts_event: ts_event_values.value(row).into(),
159 ts_init: ts_init_values.value(row).into(),
160 })
161 })
162 .collect();
163
164 result
165 }
166}
167
168impl DecodeDataFromRecordBatch for InstrumentClose {
169 fn decode_data_batch(
170 metadata: &HashMap<String, String>,
171 record_batch: RecordBatch,
172 ) -> Result<Vec<Data>, EncodingError> {
173 let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
174 Ok(items.into_iter().map(Data::from).collect())
175 }
176}
177
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "close_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("close_type", DataType::UInt8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = InstrumentClose::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("close_price".to_string(), fixed_size_binary);
        expected_map.insert("close_type".to_string(), "UInt8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 4);
        assert_eq!(close_price_values.len(), 2);
        assert_eq!(
            get_raw_price(close_price_values.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_price_values.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_type_values.len(), 2);
        assert_eq!(
            close_type_values.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_type_values.value(1),
            InstrumentCloseType::ContractExpired as u8
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let raw_price1 = (150.50 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (151.25 * FIXED_SCALAR) as PriceRaw;
        let close_price =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);
        let close_type = UInt8Array::from(vec![
            InstrumentCloseType::EndOfSession as u8,
            InstrumentCloseType::ContractExpired as u8,
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].close_price, Price::from_raw(raw_price1, 2));
        assert_eq!(
            decoded_data[0].close_type,
            InstrumentCloseType::EndOfSession
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].close_price, Price::from_raw(raw_price2, 2));
        assert_eq!(
            decoded_data[1].close_type,
            InstrumentCloseType::ContractExpired
        );
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    /// A `close_type` byte outside the enum's range must surface as a `ParseError`, not a panic.
    #[rstest]
    fn test_decode_batch_invalid_close_type() {
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "AAPL.XNAS".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let raw_price = (150.50 * FIXED_SCALAR) as PriceRaw;
        let close_price = FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]);
        let close_type = UInt8Array::from(vec![u8::MAX]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(matches!(
            result,
            Err(EncodingError::ParseError(field, _)) if field == "InstrumentCloseType"
        ));
    }

    /// Decoding without the required metadata keys must report the missing instrument ID.
    #[rstest]
    fn test_decode_batch_missing_metadata() {
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "AAPL.XNAS".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let record_batch = InstrumentClose::encode_batch(&metadata, &[]).unwrap();

        let result = InstrumentClose::decode_batch(&HashMap::new(), record_batch);
        assert!(matches!(
            result,
            Err(EncodingError::MissingMetadata(key)) if key == KEY_INSTRUMENT_ID
        ));
    }
}