// nautilus_persistence/backend/session.rs

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};

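/// Comparator that orders [`ElementBatchIter`] heads by their `ts_init` timestamp.
///
/// The comparison is reversed so that the k-merge heap (assumed to pop its maximum
/// element) yields data in ascending `ts_init` order.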
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

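/// The result of a backend query: a k-way merge over all queued batch streams,
/// ordered by `ts_init`.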
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

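/// Provides a DataFusion session for registering Parquet files as tables and querying
/// them with SQL.
///
/// Decoded record batches from each query are merged into a single stream of `Data`
/// ordered by `ts_init`.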
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
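    /// Creates a new [`DataBackendSession`] with its own multi-threaded Tokio runtime.
    ///
    /// `chunk_size` is the number of `Data` elements gathered per chunk when iterating
    /// a [`DataQueryResult`]. File-scan repartitioning is disabled so that DataFusion
    /// preserves the existing `ts_init` sort order of each file.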
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

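    /// Registers an [`ObjectStore`] (e.g. S3, GCS, Azure) with the session context
    /// under the given base `url`, so that remote Parquet paths can be resolved.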
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

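    /// Builds an object store from `uri` (with optional backend-specific `storage_options`)
    /// and registers it under the URI's `scheme://host` base URL when the scheme refers
    /// to a remote store; local paths need no registration.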
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        let parsed_uri = Url::parse(uri)?;

        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
        ) {
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

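    /// Encodes `data` into a single Arrow record batch using `metadata` and writes it
    /// to the given [`WriteStream`].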
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

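    /// Registers the Parquet file at `file_path` as table `table_name` and queues a
    /// query whose decoded `Data` stream will be merged into the session's result.
    ///
    /// If `sql_query` is `None`, a default `SELECT * FROM <table_name> ORDER BY ts_init`
    /// is used; any custom query should likewise be ordered by `ts_init`, since the
    /// k-merge assumes each individual stream is already sorted. A table name that has
    /// already been registered is skipped.
    ///
    /// A minimal usage sketch (hypothetical file path; assumes `QuoteTick` satisfies
    /// the decode bound, not a compiled doctest):
    ///
    /// ```ignore
    /// use nautilus_model::data::QuoteTick;
    ///
    /// let mut session = DataBackendSession::new(5_000);
    /// session.add_file::<QuoteTick>("quotes", "/path/to/quotes.parquet", None)?;
    /// let result = session.get_query_result();
    /// ```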
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

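    /// Decodes each incoming record batch into `Data` using the batch's schema metadata
    /// and pushes the resulting stream onto `batch_streams`, eagerly driven on the
    /// session runtime. Panics if a batch cannot be fetched or decoded.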
    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

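    /// Drains all queued batch streams into a [`QueryResult`], which merges them into a
    /// single iterator of `Data` sorted by `ts_init`.
    ///
    /// Files added after this call will only appear in a subsequent result.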
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

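    /// Clears all registered tables and pending batch streams, and replaces the
    /// DataFusion context with a freshly configured one so table names can be reused.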
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

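// SAFETY: Presumed sound on the assumption that the session is only accessed from a
// single thread at a time (e.g. when driven from the Python interpreter's thread).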
#[allow(unsafe_code)]
unsafe impl Send for DataBackendSession {}

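/// Builds a `SELECT * FROM <table> ... ORDER BY ts_init` SQL query, filtering on the
/// `ts_init` range `[start, end]` and an optional extra `where_clause`, all joined
/// with `AND`.
///
/// For example (values are illustrative):
///
/// ```ignore
/// let q = build_query("quotes", Some(UnixNanos::from(100u64)), None, Some("instrument_id = 'X'"));
/// assert_eq!(q, "SELECT * FROM quotes WHERE instrument_id = 'X' AND ts_init >= 100 ORDER BY ts_init");
/// ```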
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    let mut query = format!("SELECT * FROM {table}");

    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    query.push_str(" ORDER BY ts_init");

    query
}

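/// Wraps a [`QueryResult`] for consumption across the FFI/Python boundary, handing out
/// the merged data in chunks of up to `size` elements as C-compatible vectors ([`CVec`]).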
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
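    /// Creates a new [`DataQueryResult`] over `result`, yielding chunks of up to `size`
    /// elements per iteration.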
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

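    /// Stores `data` as the current chunk (dropping any previous chunk first) and
    /// returns it as a [`CVec`] whose memory remains owned by this result until the
    /// next `set_chunk` or `drop_chunk` call.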
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

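    /// Reclaims and drops the current chunk, if any, by rebuilding the `Vec<Data>` from
    /// its raw parts. Basic invariants on the pointer, length, and capacity are asserted
    /// first to fail fast on memory corruption.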
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

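// Iteration yields chunks of up to `size` elements; note that an exhausted result
// returns `Some` with an empty vector rather than `None`, so callers are expected to
// stop on an empty chunk.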
impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

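// Dropping the result frees the last handed-out chunk and clears any remaining
// buffered data in the underlying k-merge.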
impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}

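// SAFETY: Presumed sound under the same single-threaded-access assumption as
// `DataBackendSession`; the pyclass above is marked `unsendable` on the Python side.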
#[allow(unsafe_code)]
unsafe impl Send for DataQueryResult {}