// nautilus_persistence/backend/session.rs

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};

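/// Compares two batch-iterator heads by their `ts_init`; the ordering is reversed so the
/// k-merge heap pops the element with the smallest `ts_init` first.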
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

/// A k-merge over eagerly buffered data streams, ordered by each element's `ts_init`.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session for registering data files as tables, querying them,
/// and streaming the decoded [`Data`] back in ascending `ts_init` order.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
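    /// Creates a new [`DataBackendSession`] with the given `chunk_size`, a dedicated
    /// multi-threaded Tokio runtime, and a DataFusion context configured to preserve
    /// the existing `ts_init` sort order of registered files.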
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

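    /// Registers an object store with the session context under the given base `url`.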
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

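    /// Creates an object store from `uri` (with optional `storage_options`) and, for
    /// remote schemes, registers it under the URI's scheme and host.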
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        let parsed_uri = Url::parse(uri)?;

        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
        ) {
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

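    /// Encodes `data` into a single record batch using `metadata` and writes it to `stream`.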
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

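    /// Registers the Parquet file at `file_path` as `table_name` (only once per table
    /// name), runs `sql_query` against it (defaulting to a full scan ordered by
    /// `ts_init`), and queues the resulting record batch stream for the k-merge.
    ///
    /// # Example
    ///
    /// A minimal sketch of typical usage, assuming a `QuoteTick` Parquet file written
    /// with a `ts_init` column (the type and path here are illustrative only):
    ///
    /// ```ignore
    /// let mut session = DataBackendSession::new(5_000);
    /// session.add_file::<QuoteTick>("quote_ticks", "/path/to/quote_ticks.parquet", None)?;
    /// let mut result = session.get_query_result();
    /// while let Some(data) = result.next() {
    ///     // Items arrive ordered by `ts_init` across all registered files
    /// }
    /// ```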
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Declare the file as already sorted by `ts_init` so DataFusion can skip re-sorting
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

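    /// Decodes each record batch into [`Data`] items and spawns the stream eagerly on the
    /// session runtime, panicking if the underlying record batch stream yields an error.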
    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

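    /// Drains all queued batch streams into a [`QueryResult`] k-merge, which yields the
    /// combined data in ascending `ts_init` order.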
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

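    /// Clears all registered tables and queued streams, and rebuilds the DataFusion
    /// session context with the same configuration.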
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

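/// Builds a `SELECT * FROM {table}` SQL string with optional `ts_init` bounds and an
/// optional extra `WHERE` clause, always ordered by `ts_init`.
///
/// For example (a sketch; `UnixNanos` values render as plain integers),
/// `build_query("quotes", Some(start), None, None)` produces
/// `"SELECT * FROM quotes WHERE ts_init >= 1000 ORDER BY ts_init"` when `start` is 1000.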
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    let mut query = format!("SELECT * FROM {table}");

    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    query.push_str(" ORDER BY ts_init");

    query
}

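/// Holds a [`QueryResult`] and exposes its data over FFI in fixed-size chunks, keeping
/// ownership of the current chunk as a [`CVec`] until the next chunk is set or the
/// result is dropped.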
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
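    /// Creates a new [`DataQueryResult`] wrapping `result` with the given chunk `size`.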
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

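    /// Drops any previously held chunk, converts `data` into a [`CVec`], then stores and
    /// returns the new chunk.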
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

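    /// Reclaims the current chunk (if any) by rebuilding the `Vec<Data>` from its raw
    /// parts and dropping it; the assertions guard against obviously corrupted chunks.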
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

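/// Yields up to `size` items per call; `next` always returns `Some`, so an empty vector
/// signals that the underlying query result is exhausted.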
impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // Hand the accumulated items to the caller, leaving an empty buffer for the next call
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}