// nautilus_persistence/backend/session.rs

use std::{collections::HashMap, sync::Arc, vec::IntoIter};

use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, GetTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};
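
/// Compares two batch iterators by the `ts_init` of their current element.
///
/// The ordering is reversed so that the k-merge heap yields the element with
/// the smallest `ts_init` first.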
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}
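
/// A stream of [`Data`] elements merged from multiple per-file batch streams,
/// ordered by `ts_init`.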
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;
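
/// A DataFusion-backed session for querying Parquet data files.
///
/// Each file is registered as a table and read as a stream of record batches
/// sorted by `ts_init`; [`DataBackendSession::get_query_result`] merges the
/// per-file streams into a single ordered [`QueryResult`].
///
/// Illustrative flow (a sketch only; `QuoteTick` stands in for any type that
/// implements the required decode traits):
///
/// ```ignore
/// let mut session = DataBackendSession::new(5_000);
/// session.add_file::<QuoteTick>("quotes", "quotes.parquet", None)?;
/// let result = session.get_query_result();
/// ```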
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
}

impl DataBackendSession {
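    /// Creates a new session with a dedicated multi-threaded Tokio runtime and
    /// a DataFusion context configured to preserve the existing `ts_init` sort
    /// order of the files it reads.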
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
        }
    }
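
    /// Encodes `data` into a single Arrow record batch with the given metadata
    /// and writes it to `stream`.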
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &HashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let record_batch = T::encode_batch(metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }
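
    /// Registers the Parquet file at `file_path` as `table_name` and queues its
    /// query results for merging.
    ///
    /// The file is expected to be sorted by `ts_init`. If `sql_query` is `None`,
    /// a default `SELECT * FROM {table_name} ORDER BY ts_init` query is used;
    /// a custom query should preserve that ordering (see [`build_query`]).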
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let parquet_options = ParquetReadOptions::<'_> {
            skip_metadata: Some(false),
            file_sort_order: vec![vec![Sort {
                expr: col("ts_init"),
                asc: true,
                nulls_first: false,
            }]],
            ..Default::default()
        };
        self.runtime.block_on(self.session_ctx.register_parquet(
            table_name,
            file_path,
            parquet_options,
        ))?;

        let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
        let sql_query = sql_query.unwrap_or(&default_query);
        let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;

        let batch_stream = self.runtime.block_on(query.execute_stream())?;

        self.add_batch_stream::<T>(batch_stream);
        Ok(())
    }
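
    /// Decodes each record batch from `stream` into `Data` elements and pushes
    /// the resulting iterator stream onto the pending batch streams.
    ///
    /// Panics if a batch cannot be read or decoded.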
    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }
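
    /// Consumes all registered batch streams and returns a [`QueryResult`] that
    /// merges them into a single stream ordered by `ts_init`.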
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }
}
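
// SAFETY: assumes the session is not accessed concurrently from multiple
// threads; `Send` is required to expose the type through the Python bindings.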
unsafe impl Send for DataBackendSession {}
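
/// Builds a `SELECT * FROM {table} ... ORDER BY ts_init` query string,
/// combining the optional `where_clause` and the optional `start`/`end`
/// bounds on `ts_init` into a single `WHERE` clause.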
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    let mut query = format!("SELECT * FROM {table}");

    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    query.push_str(" ORDER BY ts_init");

    query
}
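
/// Accumulates query results into fixed-size chunks and exposes each chunk to
/// the C/Python FFI layer as a [`CVec`], keeping ownership of the current
/// chunk until the next one is requested or the result is dropped.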
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
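    /// Creates a new result wrapper over `result`, yielding chunks of at most
    /// `size` elements.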
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }
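
    /// Releases the previous chunk (if any) and takes ownership of `data` as
    /// the current chunk, returning its [`CVec`] view for the FFI caller.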
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }
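
    /// Reclaims and drops the current chunk, if any, by rebuilding the backing
    /// `Vec<Data>` from the raw [`CVec`] parts.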
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            let data: Vec<Data> =
                unsafe { Vec::from_raw_parts(ptr.cast::<nautilus_model::data::Data>(), len, cap) };
            drop(data);
        }
    }
}
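
// Drains up to `size` elements from the merged query result per call; an
// exhausted result yields empty chunks.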
impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}
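
// Releases the chunk currently exposed over FFI and clears any remaining
// merge state when the result is dropped.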
impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}
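
// SAFETY: assumes the result is driven from a single thread at a time (the
// pyclass above is marked `unsendable` on the Python side).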
unsafe impl Send for DataQueryResult {}
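
// A minimal sketch (not part of the original module) exercising `build_query`
// above; only the cases that don't depend on how `UnixNanos` renders via
// `Display` are asserted here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_query_composes_where_clause_and_ordering() {
        // No filters: bare SELECT with the mandatory ORDER BY.
        assert_eq!(
            build_query("quote_ticks", None, None, None),
            "SELECT * FROM quote_ticks ORDER BY ts_init"
        );

        // A caller-supplied WHERE clause is wrapped between SELECT and ORDER BY.
        assert_eq!(
            build_query(
                "quote_ticks",
                None,
                None,
                Some("instrument_id = 'EUR/USD.SIM'")
            ),
            "SELECT * FROM quote_ticks WHERE instrument_id = 'EUR/USD.SIM' ORDER BY ts_init"
        );
    }
}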