nautilus_persistence/backend/session.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};

/// Compares [`ElementBatchIter`]s by the `ts_init` of their current item.
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // The k-merge heap is a max-heap, so reverse the ordering to pop the
        // element with the smallest `ts_init` first
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

/// The result of a backend query: a k-merge over per-file streams that yields
/// [`Data`] in ascending `ts_init` order.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and run queries against them.
/// A query returns a chunk of Arrow record batches, which are decoded and
/// converted into a `Vec` of [`Data`] by types that implement
/// [`DecodeDataFromRecordBatch`].
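///
/// # Example
///
/// A minimal usage sketch; the data type and file path are illustrative:
///
/// ```ignore
/// let mut session = DataBackendSession::new(5_000);
/// session.add_file::<QuoteTick>("quote_tick", "quotes.parquet", None)?;
/// for data in session.get_query_result() {
///     // Elements arrive in ascending `ts_init` order
/// }
/// ```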
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] instance.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
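        // These options keep file scans unpartitioned and let the optimizer use
        // the declared sort order, preserving each file's ascending `ts_init`
        // ordering (which the downstream k-merge relies on)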
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

    /// Registers an object store with the session context.
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

    /// Registers an object store with the session context from a URI, with optional storage options.
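    ///
    /// # Example
    ///
    /// A minimal sketch; the bucket URI is illustrative:
    ///
    /// ```ignore
    /// session.register_object_store_from_uri("s3://my-bucket/nautilus", None)?;
    /// ```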
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        // Create object store from URI using the Rust implementation
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        // Parse the URI to get the base URL for registration
        let parsed_uri = Url::parse(uri)?;

        // Register the object store with the session
        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
        ) {
            // For cloud storage, register with the base URL (scheme + netloc)
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }
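        // Other schemes (e.g. local file paths) are served by DataFusion's
        // default object store registry, so no explicit registration is needed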

        Ok(())
    }

    /// Encodes `data` into a record batch with `metadata` and writes it to `stream`.
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        // Convert AHashMap to HashMap for Arrow compatibility
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Queries a file for its records. The caller must specify `T` to indicate
    /// the kind of data expected from this query.
    ///
    /// - `table_name`: Logical table name assigned to this file. Queries to this
    ///   file should address it by its table name.
    /// - `file_path`: Path to the file.
    /// - `sql_query`: A custom SQL query to retrieve records from the file. If no
    ///   query is provided, a default query `SELECT * FROM <table_name> ORDER BY ts_init` is run.
    ///
    /// # Safety
    ///
    /// The file data must be ordered by `ts_init` in ascending order for this
    /// to work correctly.
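    ///
    /// # Example
    ///
    /// A minimal sketch; the data type and file path are illustrative:
    ///
    /// ```ignore
    /// session.add_file::<QuoteTick>("quote_tick", "catalog/quote_tick/data.parquet", None)?;
    /// ```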
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        // Check if table is already registered to avoid duplicates
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Register the table only if it doesn't exist
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            // Only add batch stream for newly registered tables to avoid duplicates
            let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    /// Consumes the registered queries and returns a [`QueryResult`].
    ///
    /// The output of each query is passed through a k-merge that sorts the
    /// combined stream in ascending order of `ts_init`. `QueryResult` is an
    /// iterator that yields [`Data`] elements.
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

    /// Clears all registered tables and batch streams.
    ///
    /// This is useful when the underlying files have changed and we need to
    /// re-register tables with updated data.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        // Create a new session context to completely reset the DataFusion state
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

// SAFETY: DataBackendSession contains non-Send types but implements Send to satisfy
// PyO3 trait bounds. It must only be used on a single Python thread.
// WARNING: Actually sending this type across threads is undefined behavior.
#[allow(unsafe_code)]
unsafe impl Send for DataBackendSession {}

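/// Builds a SQL query for `table`, applying optional `ts_init` bounds and an
/// optional extra `where_clause`, ordered by `ts_init` ascending.
///
/// # Example
///
/// The table name below is illustrative:
///
/// ```ignore
/// let query = build_query("quote_tick", Some(UnixNanos::from(100)), None, None);
/// assert_eq!(query, "SELECT * FROM quote_tick WHERE ts_init >= 100 ORDER BY ts_init");
/// ```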
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    // Add where clause if provided
    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    // Add start condition if provided
    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    // Add end condition if provided
    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    // Build base query
    let mut query = format!("SELECT * FROM {table}");

    // Add WHERE clause if there are conditions
    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    // Add ORDER BY clause
    query.push_str(" ORDER BY ts_init");

    query
}

/// Wraps a [`QueryResult`], exposing its data as `CVec`-backed chunks of up to
/// `size` elements.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Sets a new `CVec`-backed chunk from `data`.
    ///
    /// Any previously allocated chunk is dropped first.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Drops the current chunk, if one exists, and resets the field.
    ///
    /// Chunks generated by iteration must be dropped after use, otherwise they
    /// leak memory; the current chunk is held by the reader.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

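        // Note: this never returns `None`; once the underlying stream is
        // exhausted it yields empty `Vec`s, which callers are expected to
        // treat as end-of-stream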
        // TODO: consider using drain here if perf is unchanged
        // Some(self.acc.drain(0..).collect())
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}

// SAFETY: DataQueryResult contains non-Send types but implements Send to satisfy
// PyO3 trait bounds. It must only be used on a single Python thread.
// WARNING: Actually sending this type across threads is undefined behavior.
#[allow(unsafe_code)]
unsafe impl Send for DataQueryResult {}