nautilus_serialization/arrow/
index_price.rs1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::IndexPriceUpdate,
26 identifiers::InstrumentId,
27 types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 extract_column, get_corrected_raw_price,
33};
34use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
35
36impl ArrowSchemaProvider for IndexPriceUpdate {
37 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
38 let fields = vec![
39 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
40 Field::new("ts_event", DataType::UInt64, false),
41 Field::new("ts_init", DataType::UInt64, false),
42 ];
43
44 match metadata {
45 Some(metadata) => Schema::new_with_metadata(fields, metadata),
46 None => Schema::new(fields),
47 }
48 }
49}
50
51fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
52 let instrument_id_str = metadata
53 .get(KEY_INSTRUMENT_ID)
54 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
55 let instrument_id = InstrumentId::from_str(instrument_id_str)
56 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
57
58 let price_precision = metadata
59 .get(KEY_PRICE_PRECISION)
60 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
61 .parse::<u8>()
62 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
63
64 Ok((instrument_id, price_precision))
65}
66
67impl EncodeToRecordBatch for IndexPriceUpdate {
68 fn encode_batch(
69 metadata: &HashMap<String, String>,
70 data: &[Self],
71 ) -> Result<RecordBatch, ArrowError> {
72 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
73 let mut ts_event_builder = UInt64Array::builder(data.len());
74 let mut ts_init_builder = UInt64Array::builder(data.len());
75
76 for update in data {
77 value_builder
78 .append_value(update.value.raw.to_le_bytes())
79 .unwrap();
80 ts_event_builder.append_value(update.ts_event.as_u64());
81 ts_init_builder.append_value(update.ts_init.as_u64());
82 }
83
84 RecordBatch::try_new(
85 Self::get_schema(Some(metadata.clone())).into(),
86 vec![
87 Arc::new(value_builder.finish()),
88 Arc::new(ts_event_builder.finish()),
89 Arc::new(ts_init_builder.finish()),
90 ],
91 )
92 }
93
94 fn metadata(&self) -> HashMap<String, String> {
95 let mut metadata = HashMap::new();
96 metadata.insert(
97 KEY_INSTRUMENT_ID.to_string(),
98 self.instrument_id.to_string(),
99 );
100 metadata.insert(
101 KEY_PRICE_PRECISION.to_string(),
102 self.value.precision.to_string(),
103 );
104 metadata
105 }
106}
107
108impl DecodeFromRecordBatch for IndexPriceUpdate {
109 fn decode_batch(
110 metadata: &HashMap<String, String>,
111 record_batch: RecordBatch,
112 ) -> Result<Vec<Self>, EncodingError> {
113 let (instrument_id, price_precision) = parse_metadata(metadata)?;
114 let cols = record_batch.columns();
115
116 let value_values = extract_column::<FixedSizeBinaryArray>(
117 cols,
118 "value",
119 0,
120 DataType::FixedSizeBinary(PRECISION_BYTES),
121 )?;
122 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
123 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
124
125 if value_values.value_length() != PRECISION_BYTES {
126 return Err(EncodingError::ParseError(
127 "value",
128 format!(
129 "Invalid value length: expected {PRECISION_BYTES}, found {}",
130 value_values.value_length()
131 ),
132 ));
133 }
134
135 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
136 .map(|row| {
137 Ok(Self {
138 instrument_id,
140 value: Price::from_raw(
141 get_corrected_raw_price(value_values.value(row), price_precision),
142 price_precision,
143 ),
144 ts_event: ts_event_values.value(row).into(),
145 ts_init: ts_init_values.value(row).into(),
146 })
147 })
148 .collect();
149
150 result
151 }
152}
153
154impl DecodeDataFromRecordBatch for IndexPriceUpdate {
155 fn decode_data_batch(
156 metadata: &HashMap<String, String>,
157 record_batch: RecordBatch,
158 ) -> Result<Vec<Data>, EncodingError> {
159 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
160 Ok(updates.into_iter().map(Data::from).collect())
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Metadata shared by the tests: the given instrument with a price precision of 2.
    fn precision_2_metadata(instrument_id: InstrumentId) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = precision_2_metadata(instrument_id);
        let schema = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = IndexPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = precision_2_metadata(instrument_id);

        let update1 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("51000.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing, so a wrong count fails this assertion
        // instead of panicking on an out-of-bounds index.
        assert_eq!(columns.len(), 3);

        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50000.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(51000.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = precision_2_metadata(instrument_id);

        let raw_price1 = (50.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (51.00 * FIXED_SCALAR) as PriceRaw;
        let value =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(raw_price1, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(raw_price2, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let data = vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.25"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        // Use the encoder's own metadata so the round trip also covers `metadata()`.
        let metadata = EncodeToRecordBatch::metadata(&data[0]);
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), data.len());
        for (got, expected) in decoded.iter().zip(&data) {
            assert_eq!(got.instrument_id, expected.instrument_id);
            assert_eq!(got.value, expected.value);
            assert_eq!(got.ts_event, expected.ts_event);
            assert_eq!(got.ts_init, expected.ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch_missing_metadata() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = precision_2_metadata(instrument_id);
        let data = vec![IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1.into(),
            ts_init: 2.into(),
        }];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let result = IndexPriceUpdate::decode_batch(&HashMap::new(), record_batch);
        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }

    #[rstest]
    fn test_parse_metadata_invalid_precision() {
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "BTC-USDT.BINANCE".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "not-a-number".to_string()),
        ]);

        assert!(matches!(
            parse_metadata(&metadata),
            Err(EncodingError::ParseError(_, _))
        ));
    }
}