1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::IndexPriceUpdate,
26 identifiers::InstrumentId,
27 types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 extract_column,
33};
34use crate::arrow::{
35 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
36};
37
38impl ArrowSchemaProvider for IndexPriceUpdate {
39 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40 let fields = vec![
41 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("ts_event", DataType::UInt64, false),
43 Field::new("ts_init", DataType::UInt64, false),
44 ];
45
46 match metadata {
47 Some(metadata) => Schema::new_with_metadata(fields, metadata),
48 None => Schema::new(fields),
49 }
50 }
51}
52
53fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
54 let instrument_id_str = metadata
55 .get(KEY_INSTRUMENT_ID)
56 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
57 let instrument_id = InstrumentId::from_str(instrument_id_str)
58 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
59
60 let price_precision = metadata
61 .get(KEY_PRICE_PRECISION)
62 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
63 .parse::<u8>()
64 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
65
66 Ok((instrument_id, price_precision))
67}
68
69impl EncodeToRecordBatch for IndexPriceUpdate {
70 fn encode_batch(
71 metadata: &HashMap<String, String>,
72 data: &[Self],
73 ) -> Result<RecordBatch, ArrowError> {
74 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
75 let mut ts_event_builder = UInt64Array::builder(data.len());
76 let mut ts_init_builder = UInt64Array::builder(data.len());
77
78 for update in data {
79 value_builder
80 .append_value(update.value.raw.to_le_bytes())
81 .unwrap();
82 ts_event_builder.append_value(update.ts_event.as_u64());
83 ts_init_builder.append_value(update.ts_init.as_u64());
84 }
85
86 RecordBatch::try_new(
87 Self::get_schema(Some(metadata.clone())).into(),
88 vec![
89 Arc::new(value_builder.finish()),
90 Arc::new(ts_event_builder.finish()),
91 Arc::new(ts_init_builder.finish()),
92 ],
93 )
94 }
95
96 fn metadata(&self) -> HashMap<String, String> {
97 let mut metadata = HashMap::new();
98 metadata.insert(
99 KEY_INSTRUMENT_ID.to_string(),
100 self.instrument_id.to_string(),
101 );
102 metadata.insert(
103 KEY_PRICE_PRECISION.to_string(),
104 self.value.precision.to_string(),
105 );
106 metadata
107 }
108}
109
110impl DecodeFromRecordBatch for IndexPriceUpdate {
111 fn decode_batch(
112 metadata: &HashMap<String, String>,
113 record_batch: RecordBatch,
114 ) -> Result<Vec<Self>, EncodingError> {
115 let (instrument_id, price_precision) = parse_metadata(metadata)?;
116 let cols = record_batch.columns();
117
118 let value_values = extract_column::<FixedSizeBinaryArray>(
119 cols,
120 "value",
121 0,
122 DataType::FixedSizeBinary(PRECISION_BYTES),
123 )?;
124 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
125 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
126
127 assert_eq!(
128 value_values.value_length(),
129 PRECISION_BYTES,
130 "Price precision uses {PRECISION_BYTES} byte value"
131 );
132
133 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
134 .map(|row| {
135 Ok(Self {
136 instrument_id,
137 value: Price::from_raw(get_raw_price(value_values.value(row)), price_precision),
138 ts_event: ts_event_values.value(row).into(),
139 ts_init: ts_init_values.value(row).into(),
140 })
141 })
142 .collect();
143
144 result
145 }
146}
147
148impl DecodeDataFromRecordBatch for IndexPriceUpdate {
149 fn decode_data_batch(
150 metadata: &HashMap<String, String>,
151 record_batch: RecordBatch,
152 ) -> Result<Vec<Data>, EncodingError> {
153 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
154 Ok(updates.into_iter().map(Data::from).collect())
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds the metadata map shared by these tests: the given instrument ID
    /// and a price precision of "2".
    fn sample_metadata(instrument_id: InstrumentId) -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
        metadata.insert(KEY_PRICE_PRECISION.to_string(), "2".to_string());
        metadata
    }

    /// The schema must contain the three expected fields and carry the metadata.
    #[rstest]
    fn test_get_schema() {
        let metadata = sample_metadata(InstrumentId::from("BTC-USDT.BINANCE"));

        let expected_schema = Schema::new_with_metadata(
            vec![
                Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata.clone(),
        );

        let schema = IndexPriceUpdate::get_schema(Some(metadata));
        assert_eq!(schema, expected_schema);
    }

    /// The schema map must list each column name with its type name.
    #[rstest]
    fn test_get_schema_map() {
        let expected_map = HashMap::from([
            (
                "value".to_string(),
                format!("FixedSizeBinary({PRECISION_BYTES})"),
            ),
            ("ts_event".to_string(), "UInt64".to_string()),
            ("ts_init".to_string(), "UInt64".to_string()),
        ]);

        let schema_map = IndexPriceUpdate::get_schema_map();
        assert_eq!(schema_map, expected_map);
    }

    /// Encoding two updates must yield three columns with matching row values.
    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = sample_metadata(instrument_id);

        let data = vec![
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("50000.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            IndexPriceUpdate {
                instrument_id,
                value: Price::from("51000.00"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];

        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();
        let columns = record_batch.columns();

        let prices = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let event_times = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let init_times = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 3);

        let expected_prices = [
            (50000.00 * FIXED_SCALAR) as PriceRaw,
            (51000.00 * FIXED_SCALAR) as PriceRaw,
        ];
        assert_eq!(prices.len(), 2);
        for (row, expected) in expected_prices.iter().enumerate() {
            assert_eq!(get_raw_price(prices.value(row)), *expected);
        }

        let expected_event_times: [u64; 2] = [1, 2];
        assert_eq!(event_times.len(), 2);
        for (row, expected) in expected_event_times.iter().enumerate() {
            assert_eq!(event_times.value(row), *expected);
        }

        let expected_init_times: [u64; 2] = [3, 4];
        assert_eq!(init_times.len(), 2);
        for (row, expected) in expected_init_times.iter().enumerate() {
            assert_eq!(init_times.value(row), *expected);
        }
    }

    /// Decoding a hand-built batch must reproduce the rows with the metadata's
    /// instrument ID and precision.
    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = sample_metadata(instrument_id);

        let raw_prices: [PriceRaw; 2] = [5000000, 5100000];
        let first_bytes = raw_prices[0].to_le_bytes();
        let second_bytes = raw_prices[1].to_le_bytes();

        let value = FixedSizeBinaryArray::from(vec![&first_bytes, &second_bytes]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();
        assert_eq!(decoded_data.len(), 2);

        // (raw price, ts_event, ts_init) expected for each decoded row.
        let expected_rows: [(PriceRaw, u64, u64); 2] =
            [(raw_prices[0], 1, 3), (raw_prices[1], 2, 4)];
        for (decoded, (raw, event, init)) in decoded_data.iter().zip(expected_rows) {
            assert_eq!(decoded.instrument_id, instrument_id);
            assert_eq!(decoded.value, Price::from_raw(raw, 2));
            assert_eq!(decoded.ts_event.as_u64(), event);
            assert_eq!(decoded.ts_init.as_u64(), init);
        }
    }
}