1use std::{
17 cell::RefCell,
18 collections::{HashMap, HashSet},
19 rc::Rc,
20 sync::Arc,
21};
22
23use datafusion::arrow::{
24 datatypes::Schema, error::ArrowError, ipc::writer::StreamWriter, record_batch::RecordBatch,
25};
26use nautilus_common::clock::Clock;
27use nautilus_core::UnixNanos;
28use nautilus_serialization::arrow::{EncodeToRecordBatch, KEY_INSTRUMENT_ID};
29use object_store::{ObjectStore, path::Path};
30
31use super::catalog::CatalogPathPrefix;
32
33#[derive(Debug, Default, PartialEq, PartialOrd, Hash, Eq, Clone)]
34pub struct FileWriterPath {
35 path: Path,
36 type_str: String,
37 instrument_id: Option<String>,
38}
39
40pub struct FeatherBuffer {
44 writer: StreamWriter<Vec<u8>>,
46 size: u64,
48 schema: Schema,
52 max_buffer_size: u64,
54 rotation_config: RotationConfig,
56}
57
58impl FeatherBuffer {
59 pub fn new(schema: &Schema, rotation_config: RotationConfig) -> Result<Self, ArrowError> {
61 let writer = StreamWriter::try_new(Vec::new(), schema)?;
62 let mut max_buffer_size = 1_000_000_000_000; if let RotationConfig::Size { max_size } = &rotation_config {
65 max_buffer_size = *max_size;
66 }
67
68 Ok(Self {
69 writer,
70 size: 0,
71 max_buffer_size,
73 schema: schema.clone(),
74 rotation_config,
75 })
76 }
77
78 pub fn write_record_batch(&mut self, batch: &RecordBatch) -> Result<bool, ArrowError> {
82 self.writer.write(batch)?;
83 self.size += batch.get_array_memory_size() as u64;
84 Ok(self.size >= self.max_buffer_size)
85 }
86
87 pub fn take_buffer(&mut self) -> Result<Vec<u8>, ArrowError> {
89 let mut writer = StreamWriter::try_new(Vec::new(), &self.schema)?;
90 std::mem::swap(&mut self.writer, &mut writer);
91 let buffer = writer.into_inner()?;
92 self.size = 0;
94 Ok(buffer)
95 }
96
97 #[must_use]
99 pub const fn should_rotate(&self) -> bool {
100 match &self.rotation_config {
101 RotationConfig::Size { max_size } => self.size >= *max_size,
102 _ => false,
103 }
104 }
105}
106
107#[derive(Debug, Clone)]
109pub enum RotationConfig {
110 Size {
112 max_size: u64,
114 },
115 Interval {
117 interval_ns: u64,
119 },
120 ScheduledDates {
122 interval_ns: u64,
124 schedule_ns: UnixNanos,
126 },
127 NoRotation,
129}
130
131pub struct FeatherWriter {
138 base_path: String,
140 store: Arc<dyn ObjectStore>,
142 clock: Rc<RefCell<dyn Clock>>,
144 rotation_config: RotationConfig,
146 included_types: Option<HashSet<String>>,
148 per_instrument_types: HashSet<String>,
150 writers: HashMap<FileWriterPath, FeatherBuffer>,
152}
153
154impl FeatherWriter {
155 pub fn new(
157 base_path: String,
158 store: Arc<dyn ObjectStore>,
159 clock: Rc<RefCell<dyn Clock>>,
160 rotation_config: RotationConfig,
161 included_types: Option<HashSet<String>>,
162 per_instrument_types: Option<HashSet<String>>,
163 ) -> Self {
164 Self {
165 base_path,
166 store,
167 clock,
168 rotation_config,
169 included_types,
170 per_instrument_types: per_instrument_types.unwrap_or_default(),
171 writers: HashMap::new(),
172 }
173 }
174
175 pub async fn write<T>(&mut self, data: T) -> Result<(), Box<dyn std::error::Error>>
180 where
181 T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
182 {
183 if !self.should_write::<T>() {
184 return Ok(());
185 }
186
187 let path = self.get_writer_path(&data)?;
188
189 if !self.writers.contains_key(&path) {
191 self.create_writer::<T>(path.clone(), &data)?;
192 }
193
194 let batch = T::encode_batch(&T::metadata(&data), &[data])?;
196
197 if let Some(writer) = self.writers.get_mut(&path) {
199 let should_rotate = writer.write_record_batch(&batch)?;
200 if should_rotate {
201 self.rotate_writer(&path).await?;
202 }
203 }
204
205 Ok(())
206 }
207
208 async fn rotate_writer(
211 &mut self,
212 path: &FileWriterPath,
213 ) -> Result<(), Box<dyn std::error::Error>> {
214 let mut writer = self.writers.remove(path).unwrap();
215 let bytes = writer.take_buffer()?;
216 self.store.put(&path.path, bytes.into()).await?;
217 let new_path = self.regen_writer_path(path)?;
218 self.writers.insert(new_path, writer);
219 Ok(())
220 }
221
222 fn create_writer<T>(&mut self, path: FileWriterPath, data: &T) -> Result<(), ArrowError>
224 where
225 T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
226 {
227 let schema = if self.per_instrument_types.contains(T::path_prefix()) {
228 let metadata = T::metadata(data);
229 T::get_schema(Some(metadata))
230 } else {
231 T::get_schema(None)
232 };
233
234 let writer = FeatherBuffer::new(&schema, self.rotation_config.clone())?;
235 self.writers.insert(path, writer);
236 Ok(())
237 }
238
239 pub async fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
244 for (path, mut writer) in self.writers.drain() {
245 let bytes = writer.take_buffer()?;
246 self.store.put(&path.path, bytes.into()).await?;
247 }
248 Ok(())
249 }
250
251 fn should_write<T: CatalogPathPrefix>(&self) -> bool {
253 self.included_types.as_ref().is_none_or(|included| {
254 let path = T::path_prefix();
255 included.contains(path)
256 })
257 }
258
259 fn regen_writer_path(
260 &self,
261 path: &FileWriterPath,
262 ) -> Result<FileWriterPath, Box<dyn std::error::Error>> {
263 let type_str = path.type_str.clone();
264 let instrument_id = path.instrument_id.clone();
265 let timestamp = self.clock.borrow().timestamp_ns();
266 let mut path = Path::from(self.base_path.clone());
268 if let Some(ref instrument_id) = instrument_id {
269 path = path.child(type_str.clone());
270 path = path.child(format!("{instrument_id}_{timestamp}.feather"));
271 } else {
272 path = path.child(format!("{type_str}_{timestamp}.feather"));
273 }
274
275 Ok(FileWriterPath {
276 path,
277 type_str,
278 instrument_id,
279 })
280 }
281
282 fn get_writer_path<T>(&self, data: &T) -> Result<FileWriterPath, Box<dyn std::error::Error>>
284 where
285 T: EncodeToRecordBatch + CatalogPathPrefix,
286 {
287 let type_str = T::path_prefix();
288 let instrument_id = self.per_instrument_types.contains(type_str).then(|| {
289 let metadata = T::metadata(data);
290 metadata
291 .get(KEY_INSTRUMENT_ID)
292 .cloned()
293 .expect("Data {type_str} expected instrument_id metadata for per instrument writer")
294 });
295
296 let timestamp = self.clock.borrow().timestamp_ns();
297 let mut path = Path::from(self.base_path.clone());
298 if let Some(ref instrument_id) = instrument_id {
299 path = path.child(type_str);
300 path = path.child(format!("{instrument_id}_{timestamp}.feather"));
301 } else {
302 path = path.child(format!("{type_str}_{timestamp}.feather"));
303 }
304
305 Ok(FileWriterPath {
306 path,
307 type_str: type_str.to_string(),
308 instrument_id,
309 })
310 }
311}
312
#[cfg(test)]
mod tests {
    use std::{io::Cursor, sync::Arc};

    use datafusion::arrow::ipc::reader::StreamReader;
    use nautilus_common::clock::TestClock;
    use nautilus_model::{
        data::{Data, QuoteTick, TradeTick},
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };
    use nautilus_serialization::arrow::{
        ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch,
    };
    use object_store::{ObjectStore, local::LocalFileSystem};
    use tempfile::TempDir;

    use super::*;

    #[tokio::test]
    async fn test_writer_manager_keys() {
        // Temporary directory used as the base path.
        let temp_dir = TempDir::new().unwrap();
        let base_path = temp_dir.path().to_str().unwrap().to_string();

        // Local object store rooted at the temporary directory.
        let local_fs = LocalFileSystem::new_with_prefix(temp_dir.path()).unwrap();
        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);

        // The test never advances the clock, so every path shares this timestamp.
        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));
        let timestamp = clock.borrow().timestamp_ns();

        let quote_type_str = QuoteTick::path_prefix();

        let mut per_instrument = HashSet::new();
        per_instrument.insert(quote_type_str.to_string());

        let mut manager = FeatherWriter::new(
            base_path.clone(),
            store,
            clock,
            RotationConfig::NoRotation,
            None,
            Some(per_instrument),
        );

        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(1000000000000000000),
            UnixNanos::from(1000000000000000000),
        );

        let trade = TradeTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Quantity::from("100.0"),
            AggressorSide::Buyer,
            TradeId::from("1"),
            UnixNanos::from(1000000000000000000),
            UnixNanos::from(1000000000000000000),
        );

        manager.write(quote).await.unwrap();
        manager.write(trade).await.unwrap();

        // Quotes are per instrument: <base>/quotes/<instrument_id>_<ts>.feather
        let path = manager.get_writer_path(&quote).unwrap();
        let expected_path = Path::from(format!(
            "{base_path}/quotes/{instrument_id}_{timestamp}.feather"
        ));
        assert_eq!(path.path, expected_path);
        assert!(manager.writers.contains_key(&path));
        let writer = manager.writers.get(&path).unwrap();
        assert!(writer.size > 0);

        // Trades are not per instrument: <base>/trades_<ts>.feather
        let path = manager.get_writer_path(&trade).unwrap();
        let expected_path = Path::from(format!("{base_path}/trades_{timestamp}.feather"));
        assert_eq!(path.path, expected_path);
        assert!(manager.writers.contains_key(&path));
        let writer = manager.writers.get(&path).unwrap();
        assert!(writer.size > 0);
    }

    #[test]
    fn test_file_writer_round_trip() {
        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );
        let metadata = QuoteTick::metadata(&quote);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));
        let batch = QuoteTick::encode_batch(&QuoteTick::metadata(&quote), &[quote]).unwrap();

        let mut writer = FeatherBuffer::new(&schema, RotationConfig::NoRotation).unwrap();
        writer.write_record_batch(&batch).unwrap();

        let buffer = writer.take_buffer().unwrap();
        let mut reader = StreamReader::try_new(Cursor::new(buffer.as_slice()), None).unwrap();

        let read_metadata = reader.schema().metadata().clone();
        assert_eq!(read_metadata, metadata);

        let read_batch = reader.next().unwrap().unwrap();
        assert_eq!(read_batch.column(0), batch.column(0));
        assert!(reader.next().is_none(), "Expected exactly one batch in the stream");

        // Decode what was read back from the buffer (not the original batch), so the
        // assertion actually covers the IPC round trip.
        let decoded = QuoteTick::decode_data_batch(&read_metadata, read_batch).unwrap();
        assert_eq!(decoded[0], Data::from(quote));
    }

    #[tokio::test]
    async fn test_round_trip() {
        // Temporary directory used as the base path.
        let temp_dir = TempDir::new_in(".").unwrap();
        let base_path = temp_dir.path().to_str().unwrap().to_string();

        // Local object store rooted at the temporary directory.
        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);

        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));

        let quote_type_str = QuoteTick::path_prefix();
        let trade_type_str = TradeTick::path_prefix();

        let mut per_instrument = HashSet::new();
        per_instrument.insert(quote_type_str.to_string());
        per_instrument.insert(trade_type_str.to_string());

        let mut manager = FeatherWriter::new(
            base_path.clone(),
            store,
            clock,
            RotationConfig::NoRotation,
            None,
            Some(per_instrument),
        );

        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );

        let trade = TradeTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Quantity::from("100.0"),
            AggressorSide::Buyer,
            TradeId::from("1"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );

        manager.write(quote).await.unwrap();
        manager.write(trade).await.unwrap();

        let paths = manager.writers.keys().cloned().collect::<Vec<_>>();
        assert_eq!(paths.len(), 2);

        manager.flush().await.unwrap();

        // Read every flushed file back and decode it according to its data type.
        let mut recovered_quotes = Vec::new();
        let mut recovered_trades = Vec::new();
        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
        for path in paths {
            let path_str = local_fs.path_to_filesystem(&path.path).unwrap();
            let buffer = std::fs::File::open(&path_str).unwrap();
            let reader = StreamReader::try_new(buffer, None).unwrap();
            let metadata = reader.schema().metadata().clone();
            for batch in reader {
                let batch = batch.unwrap();
                if path_str.to_str().unwrap().contains("quotes") {
                    let decoded = QuoteTick::decode_data_batch(&metadata, batch).unwrap();
                    recovered_quotes.extend(decoded);
                } else if path_str.to_str().unwrap().contains("trades") {
                    let decoded = TradeTick::decode_data_batch(&metadata, batch).unwrap();
                    recovered_trades.extend(decoded);
                }
            }
        }

        assert_eq!(recovered_quotes.len(), 1, "Expected one QuoteTick record");
        assert_eq!(recovered_trades.len(), 1, "Expected one TradeTick record");

        assert_eq!(recovered_quotes[0], Data::from(quote));
        assert_eq!(recovered_trades[0], Data::from(trade));
    }
}