nautilus_persistence/backend/session.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{collections::HashMap, sync::Arc, vec::IntoIter};

use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};

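/// Compares [`ElementBatchIter`] elements by the `ts_init` of their current
/// item, reversed so the k-merge max-heap yields ascending time order.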
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // Max heap ordering must be reversed
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

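/// A k-merge over the registered batch streams, yielding `Data` elements in
/// ascending `ts_init` order.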
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and run queries against them.
/// A query returns a chunk of Arrow record batches, which is decoded and
/// converted into a `Vec` of [`Data`] by types that implement
/// [`DecodeDataFromRecordBatch`].
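///
/// # Example
///
/// A minimal usage sketch (the file path is illustrative, and `QuoteTick` is
/// assumed to implement the required decoding traits):
///
/// ```rust,no_run
/// use nautilus_model::data::QuoteTick;
/// use nautilus_persistence::backend::session::DataBackendSession;
///
/// let mut session = DataBackendSession::new(5_000);
/// session
///     .add_file::<QuoteTick>("quote_ticks", "quotes.parquet", None)
///     .unwrap();
/// for _data in session.get_query_result() {
///     // `Data` elements arrive in ascending `ts_init` order
/// }
/// ```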
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: std::collections::HashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] instance.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: std::collections::HashSet::new(),
        }
    }

    /// Registers an object store with the session context.
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

    /// Registers an object store with the session context from a URI, with
    /// optional storage options.
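    ///
    /// A minimal sketch (the bucket URI is illustrative; credentials are taken
    /// from `storage_options` or the environment):
    ///
    /// ```rust,no_run
    /// # use nautilus_persistence::backend::session::DataBackendSession;
    /// let mut session = DataBackendSession::new(1_000);
    /// session
    ///     .register_object_store_from_uri("s3://my-bucket", None)
    ///     .unwrap();
    /// ```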
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<std::collections::HashMap<String, String>>,
    ) -> anyhow::Result<()> {
        // Create object store from URI using the Rust implementation
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        // Parse the URI to get the base URL for registration
        let parsed_uri = Url::parse(uri)?;

        // Register the object store with the session
        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "azure" | "abfs" | "http" | "https"
        ) {
            // For cloud storage, register with the base URL (scheme + netloc)
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

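    /// Encodes `data` into an Arrow record batch using `metadata` and writes it
    /// to `stream`.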
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &HashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let record_batch = T::encode_batch(metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Queries a file for its records. The caller must specify `T` to indicate
    /// the kind of data expected from this query.
    ///
    /// - `table_name`: The logical table name assigned to this file. Subsequent
    ///   queries should address the file by this name.
    /// - `file_path`: The path to the file.
    /// - `sql_query`: An optional custom SQL query to retrieve records from the
    ///   file. If none is provided, the default query
    ///   `SELECT * FROM <table_name> ORDER BY ts_init` is run.
    ///
    /// # Safety
    ///
    /// The file data must be ordered by `ts_init` in ascending order for this
    /// to work correctly.
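    ///
    /// # Example
    ///
    /// A minimal sketch (the table name and file path are illustrative):
    ///
    /// ```rust,no_run
    /// # use nautilus_model::data::TradeTick;
    /// # use nautilus_persistence::backend::session::DataBackendSession;
    /// # let mut session = DataBackendSession::new(1_000);
    /// // Runs the default query: SELECT * FROM trades ORDER BY ts_init
    /// session
    ///     .add_file::<TradeTick>("trades", "trades.parquet", None)
    ///     .unwrap();
    /// ```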
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        // Check if table is already registered to avoid duplicates
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Register the table only if it doesn't exist
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            // Only add batch stream for newly registered tables to avoid duplicates
            let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    /// Consumes the registered queries and returns a [`QueryResult`].
    ///
    /// The output of each query is passed through a k-merge, which merges the
    /// streams in ascending order of `ts_init`. The returned [`QueryResult`]
    /// is an iterator that yields the merged `Data` elements.
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

    /// Clears all registered tables and batch streams.
    ///
    /// This is useful when the underlying files have changed and we need to
    /// re-register tables with updated data.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        // Create a new session context to completely reset the DataFusion state
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

// Note: Intended to be used on a single Python thread
unsafe impl Send for DataBackendSession {}

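/// Builds a `SELECT` query for `table`, with optional inclusive `ts_init`
/// bounds and an optional extra `where_clause`, always ordered by `ts_init`.
///
/// For example (hypothetical values):
///
/// ```text
/// build_query("quotes", Some(100.into()), Some(200.into()), Some("bid > 0"))
/// // => SELECT * FROM quotes WHERE bid > 0 AND ts_init >= 100 AND ts_init <= 200 ORDER BY ts_init
/// ```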
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    // Add where clause if provided
    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    // Add start condition if provided
    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    // Add end condition if provided
    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    // Build base query
    let mut query = format!("SELECT * FROM {table}");

    // Add WHERE clause if there are conditions
    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    // Add ORDER BY clause
    query.push_str(" ORDER BY ts_init");

    query
}

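/// A Python-facing wrapper that accumulates [`QueryResult`] output into chunks
/// of up to `size` elements, exposed over FFI as [`CVec`]s.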
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Sets a new `CVec`-backed chunk from `data`.
    ///
    /// Any previously allocated chunk is dropped first.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Drops the current chunk, if it exists, and resets the field.
    ///
    /// Chunks produced by iteration must be dropped after use, otherwise they
    /// leak memory. The current chunk is held by the reader, so it is released
    /// here.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            // Rebuild the `Vec` from its raw parts so the allocation is freed
            let data: Vec<Data> =
                unsafe { Vec::from_raw_parts(ptr.cast::<nautilus_model::data::Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // TODO: consider using drain here if perf is unchanged
        // Some(self.acc.drain(0..).collect())
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        // Note: always returns `Some`; callers detect exhaustion via an empty chunk
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}

// Note: Intended to be used on a single Python thread
unsafe impl Send for DataQueryResult {}