nautilus_persistence/backend/session.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};

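/// Compares [`ElementBatchIter`] heads by their `ts_init` value.
///
/// The ordering is reversed so that the k-merge max-heap yields the element
/// with the smallest `ts_init` first.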
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // Max heap ordering must be reversed
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

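/// A k-merge over the registered batch streams, yielding data in ascending `ts_init` order.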
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and run queries against them. A
/// query returns a chunk of Arrow record batches, which is decoded and converted
/// into a `Vec` of data by types that implement [`DecodeDataFromRecordBatch`].
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] instance.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

    /// Registers an object store with the session context.
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

    /// Registers an object store with the session context from a URI, with optional storage options.
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        // Create object store from URI using the Rust implementation
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        // Parse the URI to get the base URL for registration
        let parsed_uri = Url::parse(uri)?;

        // Register the object store with the session
        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
        ) {
            // For cloud storage, register with the base URL (scheme + netloc)
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

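    // Example (a sketch): registering a cloud object store before adding remote files.
    // The URI below is illustrative only; `storage_options` may carry backend-specific
    // settings handled by `create_object_store_from_path`.
    //
    //     let mut session = DataBackendSession::new(10_000);
    //     session.register_object_store_from_uri("s3://my-bucket", None)?;

    /// Encodes `data` into an Arrow record batch with the given `metadata` and
    /// writes it to `stream`.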
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        // Convert AHashMap to HashMap for Arrow compatibility
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Queries a file for its records. The caller must specify `T` to indicate
    /// the kind of data expected from this query.
    ///
    /// `table_name`: Logical table name assigned to this file. Queries against this file should
    /// address it by its table name.
    /// `file_path`: Path to the file.
    /// `sql_query`: A custom SQL query to retrieve records from the file. If no query is provided,
    /// a default query "SELECT * FROM <`table_name`> ORDER BY ts_init" is run.
    ///
    /// # Safety
    ///
    /// The file data must be ordered by `ts_init` in ascending order for this
    /// to work correctly.
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        // Check if table is already registered to avoid duplicates
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Register the table only if it doesn't exist
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            // Only add batch stream for newly registered tables to avoid duplicates
            let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    /// Consumes the registered queries and returns a [`QueryResult`].
    ///
    /// Passes the output of the queries through a [`KMerge`], which merges them
    /// in ascending order of `ts_init`. The returned [`QueryResult`] is an
    /// iterator over the merged data.
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }
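
    // Example flow (a sketch, not taken from the codebase): register one or more
    // Parquet files and consume the merged stream. `QuoteTick` is assumed here to
    // implement `DecodeDataFromRecordBatch`; the table names and file paths are
    // illustrative only.
    //
    //     let mut session = DataBackendSession::new(10_000);
    //     session.add_file::<QuoteTick>("quotes_2024_01", "quotes-2024-01.parquet", None)?;
    //     session.add_file::<QuoteTick>("quotes_2024_02", "quotes-2024-02.parquet", None)?;
    //     for data in session.get_query_result() {
    //         // Elements arrive in ascending `ts_init` order across both files
    //     }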

    /// Clears all registered tables and batch streams.
    ///
    /// This is useful when the underlying files have changed and we need to
    /// re-register tables with updated data.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        // Create a new session context to completely reset the DataFusion state
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

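/// Builds a `SELECT` query for `table`, optionally filtered by a custom `where_clause`
/// and an inclusive `ts_init` range, always ordered by `ts_init` ascending.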
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    // Add where clause if provided
    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    // Add start condition if provided
    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    // Add end condition if provided
    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    // Build base query
    let mut query = format!("SELECT * FROM {table}");

    // Add WHERE clause if there are conditions
    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    // Add ORDER BY clause
    query.push_str(" ORDER BY ts_init");

    query
}

#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Sets a new `CVec`-backed chunk from `data`.
    ///
    /// It also drops the previously allocated chunk.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Chunks generated by iteration must be dropped after use, otherwise
    /// they will leak memory. The current chunk is held by the reader;
    /// drop it if it exists and reset the field.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
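        // Accumulate up to `size` elements from the merged stream; if the stream
        // is exhausted early, the accumulator is returned as-is (possibly empty).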
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // TODO: consider using drain here if perf is unchanged
        // Some(self.acc.drain(0..).collect())
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
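        // Free any outstanding chunk and clear the remaining merge state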
        self.drop_chunk();
        self.result.clear();
    }
}
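
// A minimal test sketch for `build_query`, added for illustration. It assumes
// `UnixNanos` converts from `u64` and displays as the raw nanosecond value,
// which `build_query` itself relies on when formatting the conditions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_query_with_range_and_where_clause() {
        let query = build_query(
            "quotes",
            Some(UnixNanos::from(100_u64)),
            Some(UnixNanos::from(200_u64)),
            Some("instrument_id = 'EUR/USD.SIM'"),
        );
        assert_eq!(
            query,
            "SELECT * FROM quotes WHERE instrument_id = 'EUR/USD.SIM' AND ts_init >= 100 AND ts_init <= 200 ORDER BY ts_init"
        );
    }

    #[test]
    fn build_query_without_filters() {
        let query = build_query("quotes", None, None, None);
        assert_eq!(query, "SELECT * FROM quotes ORDER BY ts_init");
    }
}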