1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::close::InstrumentClose,
26 enums::{FromU8, InstrumentCloseType},
27 identifiers::InstrumentId,
28 types::{Price, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 extract_column,
34};
35use crate::arrow::{
36 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37};
38
39impl ArrowSchemaProvider for InstrumentClose {
40 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
41 let fields = vec![
42 Field::new(
43 "close_price",
44 DataType::FixedSizeBinary(PRECISION_BYTES),
45 false,
46 ),
47 Field::new("close_type", DataType::UInt8, false),
48 Field::new("ts_event", DataType::UInt64, false),
49 Field::new("ts_init", DataType::UInt64, false),
50 ];
51
52 match metadata {
53 Some(metadata) => Schema::new_with_metadata(fields, metadata),
54 None => Schema::new(fields),
55 }
56 }
57}
58
59fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
60 let instrument_id_str = metadata
61 .get(KEY_INSTRUMENT_ID)
62 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
63 let instrument_id = InstrumentId::from_str(instrument_id_str)
64 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
65
66 let price_precision = metadata
67 .get(KEY_PRICE_PRECISION)
68 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
69 .parse::<u8>()
70 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
71
72 Ok((instrument_id, price_precision))
73}
74
75impl EncodeToRecordBatch for InstrumentClose {
76 fn encode_batch(
77 metadata: &HashMap<String, String>,
78 data: &[Self],
79 ) -> Result<RecordBatch, ArrowError> {
80 let mut close_price_builder =
81 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82 let mut close_type_builder = UInt8Array::builder(data.len());
83 let mut ts_event_builder = UInt64Array::builder(data.len());
84 let mut ts_init_builder = UInt64Array::builder(data.len());
85
86 for item in data {
87 close_price_builder
88 .append_value(item.close_price.raw.to_le_bytes())
89 .unwrap();
90 close_type_builder.append_value(item.close_type as u8);
91 ts_event_builder.append_value(item.ts_event.as_u64());
92 ts_init_builder.append_value(item.ts_init.as_u64());
93 }
94
95 RecordBatch::try_new(
96 Self::get_schema(Some(metadata.clone())).into(),
97 vec![
98 Arc::new(close_price_builder.finish()),
99 Arc::new(close_type_builder.finish()),
100 Arc::new(ts_event_builder.finish()),
101 Arc::new(ts_init_builder.finish()),
102 ],
103 )
104 }
105
106 fn metadata(&self) -> HashMap<String, String> {
107 InstrumentClose::get_metadata(&self.instrument_id, self.close_price.precision)
108 }
109}
110
111impl DecodeFromRecordBatch for InstrumentClose {
112 fn decode_batch(
113 metadata: &HashMap<String, String>,
114 record_batch: RecordBatch,
115 ) -> Result<Vec<Self>, EncodingError> {
116 let (instrument_id, price_precision) = parse_metadata(metadata)?;
117 let cols = record_batch.columns();
118
119 let close_price_values = extract_column::<FixedSizeBinaryArray>(
120 cols,
121 "close_price",
122 0,
123 DataType::FixedSizeBinary(PRECISION_BYTES),
124 )?;
125 let close_type_values =
126 extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
127 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
128 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
129
130 assert_eq!(
131 close_price_values.value_length(),
132 PRECISION_BYTES,
133 "Price precision uses {PRECISION_BYTES} byte value"
134 );
135
136 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
137 .map(|row| {
138 let close_type_value = close_type_values.value(row);
139 let close_type =
140 InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
141 EncodingError::ParseError(
142 stringify!(InstrumentCloseType),
143 format!("Invalid enum value, was {close_type_value}"),
144 )
145 })?;
146
147 Ok(Self {
148 instrument_id,
149 close_price: Price::from_raw(
150 get_raw_price(close_price_values.value(row)),
151 price_precision,
152 ),
153 close_type,
154 ts_event: ts_event_values.value(row).into(),
155 ts_init: ts_init_values.value(row).into(),
156 })
157 })
158 .collect();
159
160 result
161 }
162}
163
164impl DecodeDataFromRecordBatch for InstrumentClose {
165 fn decode_data_batch(
166 metadata: &HashMap<String, String>,
167 record_batch: RecordBatch,
168 ) -> Result<Vec<Data>, EncodingError> {
169 let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
170 Ok(items.into_iter().map(Data::from).collect())
171 }
172}
173
#[cfg(test)]
mod tests {
    // Tests for the `InstrumentClose` Arrow schema, the encoder and the decoder.
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::{
        enums::InstrumentCloseType,
        types::{fixed::FIXED_SCALAR, price::PriceRaw},
    };
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    // A schema built with metadata has the four expected non-nullable fields in
    // order, and it carries exactly the supplied metadata.
    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "close_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("close_type", DataType::UInt8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    // `get_schema_map` maps each column name to the string form of its Arrow data type.
    #[rstest]
    fn test_get_schema_map() {
        let schema_map = InstrumentClose::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("close_price".to_string(), fixed_size_binary);
        expected_map.insert("close_type".to_string(), "UInt8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    // Encoding two closes yields four columns of length 2 whose values mirror the
    // inputs: raw price bytes, the close-type discriminant, and both timestamps.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 4);
        assert_eq!(close_price_values.len(), 2);
        // Expected raw value is the decimal price scaled by `FIXED_SCALAR`.
        assert_eq!(
            get_raw_price(close_price_values.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_price_values.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_type_values.len(), 2);
        assert_eq!(
            close_type_values.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_type_values.value(1),
            InstrumentCloseType::ContractExpired as u8
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    // Decoding a hand-built batch restores every field.
    // - The instrument ID and price precision come from the metadata.
    // - The raw price bytes are passed straight to `Price::from_raw`.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        // Raw values written directly as little-endian bytes, matching the encoder's layout.
        let close_price = FixedSizeBinaryArray::from(vec![
            &(15050 as PriceRaw).to_le_bytes(),
            &(15125 as PriceRaw).to_le_bytes(),
        ]);
        let close_type = UInt8Array::from(vec![
            InstrumentCloseType::EndOfSession as u8,
            InstrumentCloseType::ContractExpired as u8,
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].close_price, Price::from_raw(15050, 2));
        assert_eq!(
            decoded_data[0].close_type,
            InstrumentCloseType::EndOfSession
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].close_price, Price::from_raw(15125, 2));
        assert_eq!(
            decoded_data[1].close_type,
            InstrumentCloseType::ContractExpired
        );
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }
}