1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::close::InstrumentClose,
26 enums::{FromU8, InstrumentCloseType},
27 identifiers::InstrumentId,
28 types::{Price, fixed::PRECISION_BYTES},
29};
30
31use super::{
32 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
33 extract_column,
34};
35use crate::arrow::{
36 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
37};
38
39impl ArrowSchemaProvider for InstrumentClose {
40 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
41 let fields = vec![
42 Field::new(
43 "close_price",
44 DataType::FixedSizeBinary(PRECISION_BYTES),
45 false,
46 ),
47 Field::new("close_type", DataType::UInt8, false),
48 Field::new("ts_event", DataType::UInt64, false),
49 Field::new("ts_init", DataType::UInt64, false),
50 ];
51
52 match metadata {
53 Some(metadata) => Schema::new_with_metadata(fields, metadata),
54 None => Schema::new(fields),
55 }
56 }
57}
58
59fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
60 let instrument_id_str = metadata
61 .get(KEY_INSTRUMENT_ID)
62 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
63 let instrument_id = InstrumentId::from_str(instrument_id_str)
64 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
65
66 let price_precision = metadata
67 .get(KEY_PRICE_PRECISION)
68 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
69 .parse::<u8>()
70 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
71
72 Ok((instrument_id, price_precision))
73}
74
75impl EncodeToRecordBatch for InstrumentClose {
76 fn encode_batch(
77 metadata: &HashMap<String, String>,
78 data: &[Self],
79 ) -> Result<RecordBatch, ArrowError> {
80 let mut close_price_builder =
81 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
82 let mut close_type_builder = UInt8Array::builder(data.len());
83 let mut ts_event_builder = UInt64Array::builder(data.len());
84 let mut ts_init_builder = UInt64Array::builder(data.len());
85
86 for item in data {
87 close_price_builder
88 .append_value(item.close_price.raw.to_le_bytes())
89 .unwrap();
90 close_type_builder.append_value(item.close_type as u8);
91 ts_event_builder.append_value(item.ts_event.as_u64());
92 ts_init_builder.append_value(item.ts_init.as_u64());
93 }
94
95 RecordBatch::try_new(
96 Self::get_schema(Some(metadata.clone())).into(),
97 vec![
98 Arc::new(close_price_builder.finish()),
99 Arc::new(close_type_builder.finish()),
100 Arc::new(ts_event_builder.finish()),
101 Arc::new(ts_init_builder.finish()),
102 ],
103 )
104 }
105
106 fn metadata(&self) -> HashMap<String, String> {
107 InstrumentClose::get_metadata(&self.instrument_id, self.close_price.precision)
108 }
109}
110
111impl DecodeFromRecordBatch for InstrumentClose {
112 fn decode_batch(
113 metadata: &HashMap<String, String>,
114 record_batch: RecordBatch,
115 ) -> Result<Vec<Self>, EncodingError> {
116 let (instrument_id, price_precision) = parse_metadata(metadata)?;
117 let cols = record_batch.columns();
118
119 let close_price_values = extract_column::<FixedSizeBinaryArray>(
120 cols,
121 "close_price",
122 0,
123 DataType::FixedSizeBinary(PRECISION_BYTES),
124 )?;
125 let close_type_values =
126 extract_column::<UInt8Array>(cols, "close_type", 1, DataType::UInt8)?;
127 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 2, DataType::UInt64)?;
128 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 3, DataType::UInt64)?;
129
130 if close_price_values.value_length() != PRECISION_BYTES {
132 return Err(EncodingError::ParseError(
133 "close_price",
134 format!(
135 "Invalid value length: expected {PRECISION_BYTES}, found {}",
136 close_price_values.value_length()
137 ),
138 ));
139 }
140
141 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
142 .map(|row| {
143 let close_type_value = close_type_values.value(row);
144 let close_type =
145 InstrumentCloseType::from_u8(close_type_value).ok_or_else(|| {
146 EncodingError::ParseError(
147 stringify!(InstrumentCloseType),
148 format!("Invalid enum value, was {close_type_value}"),
149 )
150 })?;
151
152 Ok(Self {
153 instrument_id,
154 close_price: Price::from_raw(
155 get_raw_price(close_price_values.value(row)),
156 price_precision,
157 ),
158 close_type,
159 ts_event: ts_event_values.value(row).into(),
160 ts_init: ts_init_values.value(row).into(),
161 })
162 })
163 .collect();
164
165 result
166 }
167}
168
169impl DecodeDataFromRecordBatch for InstrumentClose {
170 fn decode_data_batch(
171 metadata: &HashMap<String, String>,
172 record_batch: RecordBatch,
173 ) -> Result<Vec<Data>, EncodingError> {
174 let items: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
175 Ok(items.into_iter().map(Data::from).collect())
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = InstrumentClose::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new(
                "close_price",
                DataType::FixedSizeBinary(PRECISION_BYTES),
                false,
            ),
            Field::new("close_type", DataType::UInt8, false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = InstrumentClose::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("close_price".to_string(), fixed_size_binary);
        expected_map.insert("close_type".to_string(), "UInt8".to_string());
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close1 = InstrumentClose {
            instrument_id,
            close_price: Price::from("150.50"),
            close_type: InstrumentCloseType::EndOfSession,
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let close2 = InstrumentClose {
            instrument_id,
            close_price: Price::from("151.25"),
            close_type: InstrumentCloseType::ContractExpired,
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![close1, close2];
        let record_batch = InstrumentClose::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let close_price_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let close_type_values = columns[1].as_any().downcast_ref::<UInt8Array>().unwrap();
        let ts_event_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[3].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 4);
        assert_eq!(close_price_values.len(), 2);
        assert_eq!(
            get_raw_price(close_price_values.value(0)),
            (150.50 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(
            get_raw_price(close_price_values.value(1)),
            (151.25 * FIXED_SCALAR) as PriceRaw
        );
        assert_eq!(close_type_values.len(), 2);
        assert_eq!(
            close_type_values.value(0),
            InstrumentCloseType::EndOfSession as u8
        );
        assert_eq!(
            close_type_values.value(1),
            InstrumentCloseType::ContractExpired as u8
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_encode_batch_empty() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let record_batch = InstrumentClose::encode_batch(&metadata, &[]).unwrap();

        assert_eq!(record_batch.num_columns(), 4);
        assert_eq!(record_batch.num_rows(), 0);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close_price = FixedSizeBinaryArray::from(vec![
            &(15050 as PriceRaw).to_le_bytes(),
            &(15125 as PriceRaw).to_le_bytes(),
        ]);
        let close_type = UInt8Array::from(vec![
            InstrumentCloseType::EndOfSession as u8,
            InstrumentCloseType::ContractExpired as u8,
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let decoded_data = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].close_price, Price::from_raw(15050, 2));
        assert_eq!(
            decoded_data[0].close_type,
            InstrumentCloseType::EndOfSession
        );
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].close_price, Price::from_raw(15125, 2));
        assert_eq!(
            decoded_data[1].close_type,
            InstrumentCloseType::ContractExpired
        );
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let original = vec![
            InstrumentClose {
                instrument_id,
                close_price: Price::from("150.50"),
                close_type: InstrumentCloseType::EndOfSession,
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            InstrumentClose {
                instrument_id,
                close_price: Price::from("151.25"),
                close_type: InstrumentCloseType::ContractExpired,
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = InstrumentClose::encode_batch(&metadata, &original).unwrap();
        let decoded = InstrumentClose::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (got, expected) in decoded.iter().zip(original.iter()) {
            assert_eq!(got.instrument_id, expected.instrument_id);
            assert_eq!(got.close_price, expected.close_price);
            assert_eq!(got.close_type, expected.close_type);
            assert_eq!(got.ts_event.as_u64(), expected.ts_event.as_u64());
            assert_eq!(got.ts_init.as_u64(), expected.ts_init.as_u64());
        }
    }

    #[rstest]
    fn test_decode_batch_invalid_close_type() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let close_price = FixedSizeBinaryArray::from(vec![&(15050 as PriceRaw).to_le_bytes()]);
        // `u8::MAX` is not the discriminant of any `InstrumentCloseType` variant.
        let close_type = UInt8Array::from(vec![u8::MAX]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![3]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(Some(metadata.clone())).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = InstrumentClose::decode_batch(&metadata, record_batch);
        assert!(matches!(result, Err(EncodingError::ParseError(_, _))));
    }

    #[rstest]
    fn test_decode_batch_missing_metadata() {
        let close_price = FixedSizeBinaryArray::from(vec![&(15050 as PriceRaw).to_le_bytes()]);
        let close_type = UInt8Array::from(vec![InstrumentCloseType::EndOfSession as u8]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![3]);

        let record_batch = RecordBatch::try_new(
            InstrumentClose::get_schema(None).into(),
            vec![
                Arc::new(close_price),
                Arc::new(close_type),
                Arc::new(ts_event),
                Arc::new(ts_init),
            ],
        )
        .unwrap();

        let result = InstrumentClose::decode_batch(&HashMap::new(), record_batch);
        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }
}