nautilus_persistence/backend/session.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{collections::HashMap, sync::Arc, vec::IntoIter};

use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, GetTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};

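/// Compares [`ElementBatchIter`]s by the `ts_init` of their current element,
/// with the ordering reversed so that a max-heap based k-merge yields the
/// smallest timestamp first.
///
/// A minimal sketch of the same trick using the standard library's max-heap
/// and [`std::cmp::Reverse`] (illustrative only, not this crate's API):
///
/// ```ignore
/// use std::{cmp::Reverse, collections::BinaryHeap};
///
/// let mut heap = BinaryHeap::new();
/// for ts in [3_u64, 1, 2] {
///     heap.push(Reverse(ts)); // reversed ordering, as in `TsInitComparator`
/// }
/// assert_eq!(heap.pop(), Some(Reverse(1))); // smallest `ts_init` pops first
/// ```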
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // Reverse the ordering so the k-merge's max-heap yields the smallest `ts_init` first
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and run queries on them. A
/// query returns a chunk of Arrow record batches, which is decoded and
/// converted into a `Vec` of [`Data`] by types that implement
/// [`DecodeDataFromRecordBatch`].
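///
/// A minimal usage sketch; the data type and file path are illustrative:
///
/// ```ignore
/// use nautilus_model::data::QuoteTick;
///
/// let mut session = DataBackendSession::new(1000);
/// session.add_file::<QuoteTick>("quotes", "/path/to/quotes.parquet", None)?;
/// for data in session.get_query_result() {
///     // `data` items arrive in ascending `ts_init` order
/// }
/// ```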
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] instance.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime fails to build.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
        }
    }

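    /// Writes `data` to `stream` as a single Arrow record batch, encoded with
    /// the given `metadata`.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding the record batch or writing to the stream
    /// fails.
    ///
    /// A minimal sketch; `quotes`, `metadata`, and `writer` are hypothetical,
    /// and `writer` may be any [`WriteStream`] implementation:
    ///
    /// ```ignore
    /// // `quotes: Vec<QuoteTick>`, `metadata: HashMap<String, String>`
    /// DataBackendSession::write_data(&quotes, &metadata, &mut writer)?;
    /// ```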
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &HashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let record_batch = T::encode_batch(metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Queries a file for its records. The caller must specify `T` to indicate
    /// the kind of data expected from this query.
    ///
    /// `table_name`: The logical table name assigned to this file. Queries
    /// should address the file by this name.
    /// `file_path`: The path to the file.
    /// `sql_query`: A custom SQL query to retrieve records from the file. If no
    /// query is provided, the default query
    /// `SELECT * FROM <table_name> ORDER BY ts_init` is run.
    ///
    /// # Safety
    ///
    /// The file data must be ordered by `ts_init` in ascending order for this
    /// to work correctly.
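    ///
    /// A minimal sketch; the table name and file path are illustrative:
    ///
    /// ```ignore
    /// use nautilus_model::data::TradeTick;
    ///
    /// session.add_file::<TradeTick>("trades", "/path/to/trades.parquet", None)?;
    /// ```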
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let parquet_options = ParquetReadOptions::<'_> {
            skip_metadata: Some(false),
            file_sort_order: vec![vec![Sort {
                expr: col("ts_init"),
                asc: true,
                nulls_first: false,
            }]],
            ..Default::default()
        };
        self.runtime.block_on(self.session_ctx.register_parquet(
            table_name,
            file_path,
            parquet_options,
        ))?;

        let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
        let sql_query = sql_query.unwrap_or(&default_query);
        let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;

        let batch_stream = self.runtime.block_on(query.execute_stream())?;

        self.add_batch_stream::<T>(batch_stream);
        Ok(())
    }

    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            // Decode failures and stream errors are unrecoverable here, so panic
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    /// Consumes the registered queries and returns a [`QueryResult`].
    ///
    /// Passes the output of the queries through a [`KMerge`], which merges
    /// them in ascending order of `ts_init`. The returned [`QueryResult`] is
    /// an iterator over [`Data`] elements.
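    ///
    /// A sketch of draining the merged result; assumes files were previously
    /// registered with `add_file`:
    ///
    /// ```ignore
    /// let result = session.get_query_result();
    /// let all_data: Vec<Data> = result.collect();
    /// // `all_data` is sorted by `ts_init` across all registered files
    /// ```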
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }
}

// Note: Intended to be used on a single Python thread
unsafe impl Send for DataBackendSession {}

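/// Builds a SQL query string for `table`, optionally filtered by an inclusive
/// `ts_init` range and a custom `where_clause`, always ordered by `ts_init`.
///
/// For example, assuming `UnixNanos` displays as its raw nanosecond value:
///
/// ```ignore
/// let query = build_query("quotes", Some(UnixNanos::from(1_000)), None, None);
/// assert_eq!(query, "SELECT * FROM quotes WHERE ts_init >= 1000 ORDER BY ts_init");
/// ```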
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    // Add where clause if provided
    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    // Add start condition if provided
    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    // Add end condition if provided
    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    // Build base query
    let mut query = format!("SELECT * FROM {table}");

    // Add WHERE clause if there are conditions
    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    // Add ORDER BY clause
    query.push_str(" ORDER BY ts_init");

    query
}

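/// Accumulates the merged [`QueryResult`] into chunks of up to `size`
/// elements and exposes each chunk to FFI callers as a [`CVec`].
///
/// A sketch of chunked consumption; `session` is hypothetical:
///
/// ```ignore
/// let mut query_result = DataQueryResult::new(session.get_query_result(), 1000);
/// while let Some(chunk) = query_result.next() {
///     if chunk.is_empty() {
///         break; // exhausted: `next` always returns `Some`, never `None`
///     }
///     // process up to 1000 `Data` elements
/// }
/// ```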
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Sets a new [`CVec`]-backed chunk from `data`.
    ///
    /// Any previously allocated chunk is dropped first.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Drops the current chunk, if any, and resets the field.
    ///
    /// Chunks generated by iteration must be dropped after use, otherwise
    /// they will leak memory. The current chunk is held by the reader, so it
    /// is dropped here if it exists.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            // SAFETY: Reconstructs the `Vec<Data>` from the raw parts produced
            // by the `Vec<Data>` to `CVec` conversion in `set_chunk`
            let data: Vec<Data> =
                unsafe { Vec::from_raw_parts(ptr.cast::<nautilus_model::data::Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // TODO: Consider using drain here if perf is unchanged
        // Some(self.acc.drain(0..).collect())
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        // Always returns `Some`; an empty `Vec` indicates the underlying
        // result is exhausted
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}

// Note: Intended to be used on a single Python thread
unsafe impl Send for DataQueryResult {}