nautilus_persistence/backend/
feather.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    cell::RefCell,
18    collections::{HashMap, HashSet},
19    rc::Rc,
20    sync::Arc,
21};
22
23use datafusion::arrow::{
24    datatypes::Schema, error::ArrowError, ipc::writer::StreamWriter, record_batch::RecordBatch,
25};
26use nautilus_common::clock::Clock;
27use nautilus_core::UnixNanos;
28use nautilus_serialization::arrow::{EncodeToRecordBatch, KEY_INSTRUMENT_ID};
29use object_store::{ObjectStore, path::Path};
30
31use super::catalog::CatalogPathPrefix;
32
33#[derive(Debug, Default, PartialEq, PartialOrd, Hash, Eq, Clone)]
34pub struct FileWriterPath {
35    path: Path,
36    type_str: String,
37    instrument_id: Option<String>,
38}
39
40/// A `FeatherBuffer` encodes data via an Arrow `StreamWriter`.
41///
42/// It flushes the internal byte buffer according to rotation policy.
43pub struct FeatherBuffer {
44    /// Arrow `StreamWriter` that writes to an in-memory `Vec<u8>`.
45    writer: StreamWriter<Vec<u8>>,
46    /// Current size in bytes.
47    size: u64,
48    /// TODO: Optional next rotation timestamp.
49    // next_rotation: Option<UnixNanos>,
50    /// Schema of the data being written.
51    schema: Schema,
52    /// Maximum buffer size in bytes.
53    max_buffer_size: u64,
54    /// Rotation config
55    rotation_config: RotationConfig,
56}
57
58impl FeatherBuffer {
59    /// Creates a new `FileWriter` using the given path, schema and maximum buffer size.
60    pub fn new(schema: &Schema, rotation_config: RotationConfig) -> Result<Self, ArrowError> {
61        let writer = StreamWriter::try_new(Vec::new(), schema)?;
62        let mut max_buffer_size = 1_000_000_000_000; // 1 GB
63
64        if let RotationConfig::Size { max_size } = &rotation_config {
65            max_buffer_size = *max_size;
66        }
67
68        Ok(Self {
69            writer,
70            size: 0,
71            // next_rotation: None,
72            max_buffer_size,
73            schema: schema.clone(),
74            rotation_config,
75        })
76    }
77
78    /// Writes the given `RecordBatch` to the internal buffer.
79    ///
80    /// Returns true if it should be rotated according rotation policy
81    pub fn write_record_batch(&mut self, batch: &RecordBatch) -> Result<bool, ArrowError> {
82        self.writer.write(batch)?;
83        self.size += batch.get_array_memory_size() as u64;
84        Ok(self.size >= self.max_buffer_size)
85    }
86
87    /// Consumes the writer and returns the buffer of bytes from the `StreamWriter`
88    pub fn take_buffer(&mut self) -> Result<Vec<u8>, ArrowError> {
89        let mut writer = StreamWriter::try_new(Vec::new(), &self.schema)?;
90        std::mem::swap(&mut self.writer, &mut writer);
91        let buffer = writer.into_inner()?;
92        // TODO: Handle rotation config here
93        self.size = 0;
94        Ok(buffer)
95    }
96
97    /// Should rotate
98    #[must_use]
99    pub const fn should_rotate(&self) -> bool {
100        match &self.rotation_config {
101            RotationConfig::Size { max_size } => self.size >= *max_size,
102            _ => false,
103        }
104    }
105}
106
107/// Configuration for file rotation.
108#[derive(Debug, Clone)]
109pub enum RotationConfig {
110    /// Rotate based on file size.
111    Size {
112        /// Maximum buffer size in bytes before rotation.
113        max_size: u64,
114    },
115    /// Rotate based on a time interval.
116    Interval {
117        /// Interval in nanoseconds.
118        interval_ns: u64,
119    },
120    /// Rotate based on scheduled dates.
121    ScheduledDates {
122        /// Interval in nanoseconds.
123        interval_ns: u64,
124        /// Start of the scheduled rotation period.
125        schedule_ns: UnixNanos,
126    },
127    /// No automatic rotation.
128    NoRotation,
129}
130
131/// Manages multiple `FeatherBuffers` and handles encoding, rotation, and flushing to the object store.
132///
133/// The `write()` method is the single entry point for clients: they supply a data value (of generic type T)
134/// and the manager encodes it (using T's metadata via `EncodeToRecordBatch`), routes it by `CatalogPathPrefix`,
135/// and writes it to the appropriate `FileWriter`. When a writer's buffer is full or rotation criteria are met,
136/// its contents are flushed to the object store and it is replaced.
137pub struct FeatherWriter {
138    /// Base directory for writing files.
139    base_path: String,
140    /// Object store for persistence.
141    store: Arc<dyn ObjectStore>,
142    /// Clock for timestamps and rotation.
143    clock: Rc<RefCell<dyn Clock>>,
144    /// Rotation configuration.
145    rotation_config: RotationConfig,
146    /// Optional set of type names to include.
147    included_types: Option<HashSet<String>>,
148    /// Set of types that should be split by instrument.
149    per_instrument_types: HashSet<String>,
150    /// Map of active `FeatherBuffers` keyed by their path.
151    writers: HashMap<FileWriterPath, FeatherBuffer>,
152}
153
154impl FeatherWriter {
155    /// Creates a new `FileWriterManager` instance.
156    pub fn new(
157        base_path: String,
158        store: Arc<dyn ObjectStore>,
159        clock: Rc<RefCell<dyn Clock>>,
160        rotation_config: RotationConfig,
161        included_types: Option<HashSet<String>>,
162        per_instrument_types: Option<HashSet<String>>,
163    ) -> Self {
164        Self {
165            base_path,
166            store,
167            clock,
168            rotation_config,
169            included_types,
170            per_instrument_types: per_instrument_types.unwrap_or_default(),
171            writers: HashMap::new(),
172        }
173    }
174
175    /// Writes a single data value.
176    /// This is the user entry point. The data is encoded into a `RecordBatch` and written to the appropriate `FileWriter`.
177    /// If the writer's buffer reaches capacity or meets rotation criteria (based on the rotation configuration),
178    /// the `FileWriter` is flushed to the object store and replaced.
179    pub async fn write<T>(&mut self, data: T) -> Result<(), Box<dyn std::error::Error>>
180    where
181        T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
182    {
183        if !self.should_write::<T>() {
184            return Ok(());
185        }
186
187        let path = self.get_writer_path(&data)?;
188
189        // Create a new FileWriter if one does not exist.
190        if !self.writers.contains_key(&path) {
191            self.create_writer::<T>(path.clone(), &data)?;
192        }
193
194        // Encode the data into a RecordBatch using T's encoding logic.
195        let batch = T::encode_batch(&T::metadata(&data), &[data])?;
196
197        // Write the RecordBatch to the appropriate FileWriter.
198        if let Some(writer) = self.writers.get_mut(&path) {
199            let should_rotate = writer.write_record_batch(&batch)?;
200            if should_rotate {
201                self.rotate_writer(&path).await?;
202            }
203        }
204
205        Ok(())
206    }
207
208    /// Flushes and rotates `FileWriter` associated with `key`.
209    /// TODO: Fix error type to handle arrow error and object store error
210    async fn rotate_writer(
211        &mut self,
212        path: &FileWriterPath,
213    ) -> Result<(), Box<dyn std::error::Error>> {
214        let mut writer = self.writers.remove(path).unwrap();
215        let bytes = writer.take_buffer()?;
216        self.store.put(&path.path, bytes.into()).await?;
217        let new_path = self.regen_writer_path(path)?;
218        self.writers.insert(new_path, writer);
219        Ok(())
220    }
221
222    /// Creates (and inserts) a new `FileWriter` for type T.
223    fn create_writer<T>(&mut self, path: FileWriterPath, data: &T) -> Result<(), ArrowError>
224    where
225        T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
226    {
227        let schema = if self.per_instrument_types.contains(T::path_prefix()) {
228            let metadata = T::metadata(data);
229            T::get_schema(Some(metadata))
230        } else {
231            T::get_schema(None)
232        };
233
234        let writer = FeatherBuffer::new(&schema, self.rotation_config.clone())?;
235        self.writers.insert(path, writer);
236        Ok(())
237    }
238
239    /// Flushes all active `FeatherBuffers` by writing any remaining buffered bytes to the object store.
240    ///
241    /// Note: This is not called automatically and must be called by the client.
242    /// It is expected that no other writes are performed after this.
243    pub async fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
244        for (path, mut writer) in self.writers.drain() {
245            let bytes = writer.take_buffer()?;
246            self.store.put(&path.path, bytes.into()).await?;
247        }
248        Ok(())
249    }
250
251    /// Determines whether type T should be written, based on the inclusion filter.
252    fn should_write<T: CatalogPathPrefix>(&self) -> bool {
253        self.included_types.as_ref().is_none_or(|included| {
254            let path = T::path_prefix();
255            included.contains(path)
256        })
257    }
258
259    fn regen_writer_path(
260        &self,
261        path: &FileWriterPath,
262    ) -> Result<FileWriterPath, Box<dyn std::error::Error>> {
263        let type_str = path.type_str.clone();
264        let instrument_id = path.instrument_id.clone();
265        let timestamp = self.clock.borrow().timestamp_ns();
266        // Note: Path removes prefixing slashes
267        let mut path = Path::from(self.base_path.clone());
268        if let Some(ref instrument_id) = instrument_id {
269            path = path.child(type_str.clone());
270            path = path.child(format!("{instrument_id}_{timestamp}.feather"));
271        } else {
272            path = path.child(format!("{type_str}_{timestamp}.feather"));
273        }
274
275        Ok(FileWriterPath {
276            path,
277            type_str,
278            instrument_id,
279        })
280    }
281
282    /// Generates a key for a `FileWriter` based on type T and optional instrument ID.
283    fn get_writer_path<T>(&self, data: &T) -> Result<FileWriterPath, Box<dyn std::error::Error>>
284    where
285        T: EncodeToRecordBatch + CatalogPathPrefix,
286    {
287        let type_str = T::path_prefix();
288        let instrument_id = self.per_instrument_types.contains(type_str).then(|| {
289            let metadata = T::metadata(data);
290            metadata
291                .get(KEY_INSTRUMENT_ID)
292                .cloned()
293                .expect("Data {type_str} expected instrument_id metadata for per instrument writer")
294        });
295
296        let timestamp = self.clock.borrow().timestamp_ns();
297        let mut path = Path::from(self.base_path.clone());
298        if let Some(ref instrument_id) = instrument_id {
299            path = path.child(type_str);
300            path = path.child(format!("{instrument_id}_{timestamp}.feather"));
301        } else {
302            path = path.child(format!("{type_str}_{timestamp}.feather"));
303        }
304
305        Ok(FileWriterPath {
306            path,
307            type_str: type_str.to_string(),
308            instrument_id,
309        })
310    }
311}
312
313#[cfg(test)]
314mod tests {
315    use std::{io::Cursor, sync::Arc};
316
317    use datafusion::arrow::ipc::reader::StreamReader;
318    use nautilus_common::clock::TestClock;
319    use nautilus_model::{
320        data::{Data, QuoteTick, TradeTick},
321        enums::AggressorSide,
322        identifiers::{InstrumentId, TradeId},
323        types::{Price, Quantity},
324    };
325    use nautilus_serialization::arrow::{
326        ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch,
327    };
328    use object_store::{ObjectStore, local::LocalFileSystem};
329    use tempfile::TempDir;
330
331    use super::*;
332
333    #[tokio::test]
334    async fn test_writer_manager_keys() {
335        // Create a temporary directory for base path
336        let temp_dir = TempDir::new().unwrap();
337        let base_path = temp_dir.path().to_str().unwrap().to_string();
338
339        // Create a LocalFileSystem based object store using the temp directory
340        let local_fs = LocalFileSystem::new_with_prefix(temp_dir.path()).unwrap();
341        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);
342
343        // Create a test clock
344        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));
345        let timestamp = clock.borrow().timestamp_ns();
346
347        let quote_type_str = QuoteTick::path_prefix();
348
349        let mut per_instrument = HashSet::new();
350        per_instrument.insert(quote_type_str.to_string());
351
352        let mut manager = FeatherWriter::new(
353            base_path.clone(),
354            store,
355            clock,
356            RotationConfig::NoRotation,
357            None,
358            Some(per_instrument),
359        );
360
361        let instrument_id = "AAPL.AAPL";
362        // Write a dummy value
363        let quote = QuoteTick::new(
364            InstrumentId::from(instrument_id),
365            Price::from("100.0"),
366            Price::from("100.0"),
367            Quantity::from("100.0"),
368            Quantity::from("100.0"),
369            UnixNanos::from(1000000000000000000),
370            UnixNanos::from(1000000000000000000),
371        );
372
373        let trade = TradeTick::new(
374            InstrumentId::from(instrument_id),
375            Price::from("100.0"),
376            Quantity::from("100.0"),
377            AggressorSide::Buyer,
378            TradeId::from("1"),
379            UnixNanos::from(1000000000000000000),
380            UnixNanos::from(1000000000000000000),
381        );
382
383        manager.write(quote).await.unwrap();
384        manager.write(trade).await.unwrap();
385
386        // Check keys and paths for quotes and trades
387        let path = manager.get_writer_path(&quote).unwrap();
388        let expected_path = Path::from(format!(
389            "{base_path}/quotes/{instrument_id}_{timestamp}.feather"
390        ));
391        assert_eq!(path.path, expected_path);
392        assert!(manager.writers.contains_key(&path));
393        let writer = manager.writers.get(&path).unwrap();
394        assert!(writer.size > 0);
395
396        let path = manager.get_writer_path(&trade).unwrap();
397        let expected_path = Path::from(format!("{base_path}/trades_{timestamp}.feather"));
398        assert_eq!(path.path, expected_path);
399        assert!(manager.writers.contains_key(&path));
400        let writer = manager.writers.get(&path).unwrap();
401        assert!(writer.size > 0);
402    }
403
404    #[test]
405    fn test_file_writer_round_trip() {
406        let instrument_id = "AAPL.AAPL";
407        // Write a dummy value.
408        let quote = QuoteTick::new(
409            InstrumentId::from(instrument_id),
410            Price::from("100.0"),
411            Price::from("100.0"),
412            Quantity::from("100.0"),
413            Quantity::from("100.0"),
414            UnixNanos::from(100),
415            UnixNanos::from(100),
416        );
417        let metadata = QuoteTick::metadata(&quote);
418        let schema = QuoteTick::get_schema(Some(metadata.clone()));
419        let batch = QuoteTick::encode_batch(&QuoteTick::metadata(&quote), &[quote]).unwrap();
420
421        let mut writer = FeatherBuffer::new(&schema, RotationConfig::NoRotation).unwrap();
422        writer.write_record_batch(&batch).unwrap();
423
424        let buffer = writer.take_buffer().unwrap();
425        let mut reader = StreamReader::try_new(Cursor::new(buffer.as_slice()), None).unwrap();
426
427        let read_metadata = reader.schema().metadata().clone();
428        assert_eq!(read_metadata, metadata);
429
430        let read_batch = reader.next().unwrap().unwrap();
431        assert_eq!(read_batch.column(0), batch.column(0));
432
433        let decoded = QuoteTick::decode_data_batch(&metadata, batch).unwrap();
434        assert_eq!(decoded[0], Data::from(quote));
435    }
436
437    #[tokio::test]
438    async fn test_round_trip() {
439        // Create a temporary directory for base path
440        let temp_dir = TempDir::new_in(".").unwrap();
441        let base_path = temp_dir.path().to_str().unwrap().to_string();
442
443        // Create a LocalFileSystem based object store using the temp directory
444        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
445        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);
446
447        // Create a test clock
448        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));
449
450        let quote_type_str = QuoteTick::path_prefix();
451        let trade_type_str = TradeTick::path_prefix();
452
453        let mut per_instrument = HashSet::new();
454        per_instrument.insert(quote_type_str.to_string());
455        per_instrument.insert(trade_type_str.to_string());
456
457        let mut manager = FeatherWriter::new(
458            base_path.clone(),
459            store,
460            clock,
461            RotationConfig::NoRotation,
462            None,
463            Some(per_instrument),
464        );
465
466        let instrument_id = "AAPL.AAPL";
467        // Write a dummy value.
468        let quote = QuoteTick::new(
469            InstrumentId::from(instrument_id),
470            Price::from("100.0"),
471            Price::from("100.0"),
472            Quantity::from("100.0"),
473            Quantity::from("100.0"),
474            UnixNanos::from(100),
475            UnixNanos::from(100),
476        );
477
478        let trade = TradeTick::new(
479            InstrumentId::from(instrument_id),
480            Price::from("100.0"),
481            Quantity::from("100.0"),
482            AggressorSide::Buyer,
483            TradeId::from("1"),
484            UnixNanos::from(100),
485            UnixNanos::from(100),
486        );
487
488        manager.write(quote).await.unwrap();
489        manager.write(trade).await.unwrap();
490
491        let paths = manager.writers.keys().cloned().collect::<Vec<_>>();
492        assert_eq!(paths.len(), 2);
493
494        // Flush data
495        manager.flush().await.unwrap();
496
497        // Read files from the temporary directory
498        let mut recovered_quotes = Vec::new();
499        let mut recovered_trades = Vec::new();
500        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
501        for path in paths {
502            let path_str = local_fs.path_to_filesystem(&path.path).unwrap();
503            let buffer = std::fs::File::open(&path_str).unwrap();
504            let reader = StreamReader::try_new(buffer, None).unwrap();
505            let metadata = reader.schema().metadata().clone();
506            for batch in reader {
507                let batch = batch.unwrap();
508                if path_str.to_str().unwrap().contains("quotes") {
509                    let decoded = QuoteTick::decode_data_batch(&metadata, batch).unwrap();
510                    recovered_quotes.extend(decoded);
511                } else if path_str.to_str().unwrap().contains("trades") {
512                    let decoded = TradeTick::decode_data_batch(&metadata, batch).unwrap();
513                    recovered_trades.extend(decoded);
514                }
515            }
516        }
517
518        // Assert that the recovered data matches the written data
519        assert_eq!(recovered_quotes.len(), 1, "Expected one QuoteTick record");
520        assert_eq!(recovered_trades.len(), 1, "Expected one TradeTick record");
521
522        // Check key fields to ensure the data round-tripped correctly
523        assert_eq!(recovered_quotes[0], Data::from(quote));
524        assert_eq!(recovered_trades[0], Data::from(trade));
525    }
526}