nautilus_serialization/arrow/
index_price.rs1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::IndexPriceUpdate,
26 identifiers::InstrumentId,
27 types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 extract_column,
33};
34use crate::arrow::{
35 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
36};
37
38impl ArrowSchemaProvider for IndexPriceUpdate {
39 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40 let fields = vec![
41 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("ts_event", DataType::UInt64, false),
43 Field::new("ts_init", DataType::UInt64, false),
44 ];
45
46 match metadata {
47 Some(metadata) => Schema::new_with_metadata(fields, metadata),
48 None => Schema::new(fields),
49 }
50 }
51}
52
53fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
54 let instrument_id_str = metadata
55 .get(KEY_INSTRUMENT_ID)
56 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
57 let instrument_id = InstrumentId::from_str(instrument_id_str)
58 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
59
60 let price_precision = metadata
61 .get(KEY_PRICE_PRECISION)
62 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
63 .parse::<u8>()
64 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
65
66 Ok((instrument_id, price_precision))
67}
68
69impl EncodeToRecordBatch for IndexPriceUpdate {
70 fn encode_batch(
71 metadata: &HashMap<String, String>,
72 data: &[Self],
73 ) -> Result<RecordBatch, ArrowError> {
74 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
75 let mut ts_event_builder = UInt64Array::builder(data.len());
76 let mut ts_init_builder = UInt64Array::builder(data.len());
77
78 for update in data {
79 value_builder
80 .append_value(update.value.raw.to_le_bytes())
81 .unwrap();
82 ts_event_builder.append_value(update.ts_event.as_u64());
83 ts_init_builder.append_value(update.ts_init.as_u64());
84 }
85
86 RecordBatch::try_new(
87 Self::get_schema(Some(metadata.clone())).into(),
88 vec![
89 Arc::new(value_builder.finish()),
90 Arc::new(ts_event_builder.finish()),
91 Arc::new(ts_init_builder.finish()),
92 ],
93 )
94 }
95
96 fn metadata(&self) -> HashMap<String, String> {
97 let mut metadata = HashMap::new();
98 metadata.insert(
99 KEY_INSTRUMENT_ID.to_string(),
100 self.instrument_id.to_string(),
101 );
102 metadata.insert(
103 KEY_PRICE_PRECISION.to_string(),
104 self.value.precision.to_string(),
105 );
106 metadata
107 }
108}
109
110impl DecodeFromRecordBatch for IndexPriceUpdate {
111 fn decode_batch(
112 metadata: &HashMap<String, String>,
113 record_batch: RecordBatch,
114 ) -> Result<Vec<Self>, EncodingError> {
115 let (instrument_id, price_precision) = parse_metadata(metadata)?;
116 let cols = record_batch.columns();
117
118 let value_values = extract_column::<FixedSizeBinaryArray>(
119 cols,
120 "value",
121 0,
122 DataType::FixedSizeBinary(PRECISION_BYTES),
123 )?;
124 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
125 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
126
127 if value_values.value_length() != PRECISION_BYTES {
128 return Err(EncodingError::ParseError(
129 "value",
130 format!(
131 "Invalid value length: expected {PRECISION_BYTES}, found {}",
132 value_values.value_length()
133 ),
134 ));
135 }
136
137 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
138 .map(|row| {
139 Ok(Self {
140 instrument_id,
141 value: Price::from_raw(get_raw_price(value_values.value(row)), price_precision),
142 ts_event: ts_event_values.value(row).into(),
143 ts_init: ts_init_values.value(row).into(),
144 })
145 })
146 .collect();
147
148 result
149 }
150}
151
152impl DecodeDataFromRecordBatch for IndexPriceUpdate {
153 fn decode_data_batch(
154 metadata: &HashMap<String, String>,
155 record_batch: RecordBatch,
156 ) -> Result<Vec<Data>, EncodingError> {
157 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
158 Ok(updates.into_iter().map(Data::from).collect())
159 }
160}
161
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::price::PriceRaw;
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds batch metadata for `instrument_id` with a price precision of 2.
    fn metadata_for(instrument_id: InstrumentId) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    /// Two updates with distinct prices and timestamps, shared by the encoding tests.
    fn sample_updates(instrument_id: InstrumentId) -> Vec<IndexPriceUpdate> {
        vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.00"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ]
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);
        let schema = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = IndexPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);
        let data = sample_updates(instrument_id);

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        // Check the column count before indexing, so a malformed batch fails on this
        // assertion rather than with an index-out-of-bounds panic.
        let columns = record_batch.columns();
        assert_eq!(columns.len(), 3);

        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50000.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(51000.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let value = FixedSizeBinaryArray::from(vec![
            &(5000000 as PriceRaw).to_le_bytes(),
            &(5100000 as PriceRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(5000000, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(5100000, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);
        let data = sample_updates(instrument_id);

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), data.len());
        for (decoded_update, original) in decoded.iter().zip(&data) {
            assert_eq!(decoded_update.instrument_id, original.instrument_id);
            assert_eq!(decoded_update.value, original.value);
            assert_eq!(decoded_update.ts_event.as_u64(), original.ts_event.as_u64());
            assert_eq!(decoded_update.ts_init.as_u64(), original.ts_init.as_u64());
        }
    }

    #[rstest]
    fn test_decode_batch_missing_price_precision() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let data = sample_updates(instrument_id);
        let record_batch =
            IndexPriceUpdate::encode_batch(&metadata_for(instrument_id), &data).unwrap();

        // Drop the precision key: decoding must report missing metadata, not panic.
        let partial_metadata =
            HashMap::from([(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string())]);
        let result = IndexPriceUpdate::decode_batch(&partial_metadata, record_batch);

        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }
}