1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::IndexPriceUpdate, identifiers::InstrumentId, types::fixed::PRECISION_BYTES,
26};
27
28use super::{
29 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
30 extract_column, validate_precision_bytes,
31};
32use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
33
34impl ArrowSchemaProvider for IndexPriceUpdate {
35 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
36 let fields = vec![
37 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
38 Field::new("ts_event", DataType::UInt64, false),
39 Field::new("ts_init", DataType::UInt64, false),
40 ];
41
42 match metadata {
43 Some(metadata) => Schema::new_with_metadata(fields, metadata),
44 None => Schema::new(fields),
45 }
46 }
47}
48
49fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
50 let instrument_id_str = metadata
51 .get(KEY_INSTRUMENT_ID)
52 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
53 let instrument_id = InstrumentId::from_str(instrument_id_str)
54 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
55
56 let price_precision = metadata
57 .get(KEY_PRICE_PRECISION)
58 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
59 .parse::<u8>()
60 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
61
62 Ok((instrument_id, price_precision))
63}
64
65impl EncodeToRecordBatch for IndexPriceUpdate {
66 fn encode_batch(
67 metadata: &HashMap<String, String>,
68 data: &[Self],
69 ) -> Result<RecordBatch, ArrowError> {
70 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
71 let mut ts_event_builder = UInt64Array::builder(data.len());
72 let mut ts_init_builder = UInt64Array::builder(data.len());
73
74 for update in data {
75 value_builder
76 .append_value(update.value.raw.to_le_bytes())
77 .unwrap();
78 ts_event_builder.append_value(update.ts_event.as_u64());
79 ts_init_builder.append_value(update.ts_init.as_u64());
80 }
81
82 RecordBatch::try_new(
83 Self::get_schema(Some(metadata.clone())).into(),
84 vec![
85 Arc::new(value_builder.finish()),
86 Arc::new(ts_event_builder.finish()),
87 Arc::new(ts_init_builder.finish()),
88 ],
89 )
90 }
91
92 fn metadata(&self) -> HashMap<String, String> {
93 let mut metadata = HashMap::new();
94 metadata.insert(
95 KEY_INSTRUMENT_ID.to_string(),
96 self.instrument_id.to_string(),
97 );
98 metadata.insert(
99 KEY_PRICE_PRECISION.to_string(),
100 self.value.precision.to_string(),
101 );
102 metadata
103 }
104}
105
106impl DecodeFromRecordBatch for IndexPriceUpdate {
107 fn decode_batch(
108 metadata: &HashMap<String, String>,
109 record_batch: RecordBatch,
110 ) -> Result<Vec<Self>, EncodingError> {
111 let (instrument_id, price_precision) = parse_metadata(metadata)?;
112 let cols = record_batch.columns();
113
114 let value_values = extract_column::<FixedSizeBinaryArray>(
115 cols,
116 "value",
117 0,
118 DataType::FixedSizeBinary(PRECISION_BYTES),
119 )?;
120 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
121 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
122
123 validate_precision_bytes(value_values, "value")?;
124
125 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
126 .map(|row| {
127 let value = decode_price(value_values.value(row), price_precision, "value", row)?;
128 Ok(Self {
129 instrument_id,
130 value,
131 ts_event: ts_event_values.value(row).into(),
132 ts_init: ts_init_values.value(row).into(),
133 })
134 })
135 .collect();
136
137 result
138 }
139}
140
141impl DecodeDataFromRecordBatch for IndexPriceUpdate {
142 fn decode_data_batch(
143 metadata: &HashMap<String, String>,
144 record_batch: RecordBatch,
145 ) -> Result<Vec<Data>, EncodingError> {
146 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
147 Ok(updates.into_iter().map(Data::from).collect())
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Instrument used by every test in this module.
    const INSTRUMENT: &str = "BTC-USDT.BINANCE";

    /// Builds batch metadata for `instrument_id` with a price precision of 2.
    fn metadata_for(instrument_id: &str) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    /// Assembles a record batch with the `IndexPriceUpdate` schema from the
    /// three given columns, panicking if Arrow rejects them.
    fn batch_from(
        metadata: &HashMap<String, String>,
        value: FixedSizeBinaryArray,
        ts_event: UInt64Array,
        ts_init: UInt64Array,
    ) -> RecordBatch {
        RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap()
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let actual = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected = Schema::new_with_metadata(
            vec![
                Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata,
        );
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn test_get_schema_map() {
        let actual = IndexPriceUpdate::get_schema_map();

        let expected = HashMap::from([
            (
                "value".to_string(),
                format!("FixedSizeBinary({PRECISION_BYTES})"),
            ),
            ("ts_event".to_string(), "UInt64".to_string()),
            ("ts_init".to_string(), "UInt64".to_string()),
        ]);
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let data = vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.00"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 3);

        let prices = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(prices.len(), 2);
        assert_eq!(
            get_raw_price(prices.value(0)),
            Price::from(dec!(50000.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(prices.value(1)),
            Price::from(dec!(51000.00).to_string()).raw
        );

        let event_times = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        assert_eq!(event_times.len(), 2);
        assert_eq!(event_times.value(0), 1);
        assert_eq!(event_times.value(1), 2);

        let init_times = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();
        assert_eq!(init_times.len(), 2);
        assert_eq!(init_times.value(0), 3);
        assert_eq!(init_times.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let raw_price1 = (50.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (51.00 * FIXED_SCALAR) as PriceRaw;
        let record_batch = batch_from(
            &metadata,
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]),
            UInt64Array::from(vec![1, 2]),
            UInt64Array::from(vec![3, 4]),
        );

        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), 2);

        let first = &decoded[0];
        assert_eq!(first.instrument_id, instrument_id);
        assert_eq!(first.value, Price::from_raw(raw_price1, 2));
        assert_eq!(first.ts_event.as_u64(), 1);
        assert_eq!(first.ts_init.as_u64(), 3);

        let second = &decoded[1];
        assert_eq!(second.instrument_id, instrument_id);
        assert_eq!(second.value, Price::from_raw(raw_price2, 2));
        assert_eq!(second.ts_event.as_u64(), 2);
        assert_eq!(second.ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_decode_batch_invalid_value_returns_error() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        // A raw value just below the type maximum, which decoding is expected to reject.
        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = batch_from(
            &metadata,
            FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        );

        let outcome = IndexPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("value") && err.to_string().contains("row 0"),
            "Expected value error at row 0, was: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let mut metadata = metadata_for(INSTRUMENT);

        let raw_price = (50.00 * FIXED_SCALAR) as PriceRaw;
        let record_batch = batch_from(
            &metadata,
            FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        );

        // Drop the instrument ID only after the batch is built, so that
        // decoding is what observes the missing entry.
        metadata.remove(KEY_INSTRUMENT_ID);

        let outcome = IndexPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, was: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let original = vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1_000_000_000.into(),
                ts_init: 1_000_000_001.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.00"),
                ts_event: 2_000_000_000.into(),
                ts_init: 2_000_000_001.into(),
            },
        ];

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &original).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (expected, actual) in original.iter().zip(decoded.iter()) {
            assert_eq!(actual.instrument_id, expected.instrument_id);
            assert_eq!(actual.value, expected.value);
            assert_eq!(actual.ts_event, expected.ts_event);
            assert_eq!(actual.ts_init, expected.ts_init);
        }
    }
}