// nautilus_persistence/backend/session.rs

use std::{collections::HashMap, sync::Arc, vec::IntoIter};

use compare::Compare;
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::kmerge_batch::{EagerStream, ElementBatchIter, KMerge};

/// Comparator that orders [`Data`] elements by their `ts_init` timestamp.
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // Reverse the natural ordering so the underlying max-heap pops the
        // element with the smallest `ts_init` first (min-heap behavior).
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

/// A k-way merge over eagerly driven batch streams, ordered by `ts_init`.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

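/// Provides a DataFusion session for registering data sources and executing
/// queries on them. A queried source must contain a timestamp column named
/// `ts_init`, and each Parquet file is expected to be pre-sorted on it.
///
/// A minimal usage sketch (the file path here is hypothetical):
///
/// ```no_run
/// use nautilus_model::data::QuoteTick;
/// use nautilus_persistence::backend::session::DataBackendSession;
///
/// let mut session = DataBackendSession::new(5_000);
/// session
///     .add_file::<QuoteTick>("quotes", "/path/to/quotes.parquet", None)
///     .unwrap();
/// // Iterate all registered data in `ts_init` order.
/// for data in session.get_query_result() {
///     // process `data` ...
/// }
/// ```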
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: std::collections::HashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] with the given `chunk_size`.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        // Disable file-scan repartitioning and prefer the existing sort order
        // so DataFusion preserves the pre-sorted `ts_init` ordering of files.
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: std::collections::HashSet::new(),
        }
    }

    /// Registers an object store with the session's DataFusion context.
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

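    /// Parses `uri`, builds the corresponding object store, and registers it
    /// with the session. For remote schemes the store is registered under the
    /// `scheme://host` base URL; local paths need no registration.
    ///
    /// A minimal sketch, assuming S3 credentials; the bucket and option values
    /// are placeholders (option keys follow the `object_store` crate naming):
    ///
    /// ```no_run
    /// use std::collections::HashMap;
    /// use nautilus_persistence::backend::session::DataBackendSession;
    ///
    /// let mut session = DataBackendSession::new(5_000);
    /// let mut storage_options = HashMap::new();
    /// storage_options.insert("aws_access_key_id".to_string(), "<key-id>".to_string());
    /// storage_options.insert("aws_secret_access_key".to_string(), "<secret>".to_string());
    /// session
    ///     .register_object_store_from_uri("s3://my-bucket", Some(storage_options))
    ///     .unwrap();
    /// ```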
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<std::collections::HashMap<String, String>>,
    ) -> anyhow::Result<()> {
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        let parsed_uri = Url::parse(uri)?;

        // Only remote schemes need an object store registered with DataFusion;
        // register it under the `scheme://host` base URL.
        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "azure" | "abfs" | "http" | "https"
        ) {
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

    /// Encodes `data` into a single record batch (tagged with `metadata`) and
    /// writes it to `stream`.
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &HashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        let record_batch = T::encode_batch(metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Registers the Parquet file at `file_path` as `table_name` and queues a
    /// stream of its contents, decoded as `T`. The file must contain data
    /// sorted by `ts_init`. If no `sql_query` is given, the whole table is
    /// selected in `ts_init` order. Re-registering an existing table name is
    /// a no-op.
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Declare the file as pre-sorted on `ts_init` so DataFusion can
            // satisfy the ORDER BY without an additional sort.
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream);
        }

        Ok(())
    }

    fn add_batch_stream<T>(&mut self, stream: SendableRecordBatchStream)
    where
        T: DecodeDataFromRecordBatch + Into<Data>,
    {
        // Decode each record batch into `Data` elements using the schema
        // metadata; any stream or decode error aborts the session.
        let transform = stream.map(|result| match result {
            Ok(batch) => T::decode_data_batch(batch.schema().metadata(), batch)
                .unwrap()
                .into_iter(),
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

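    /// Consumes all queued batch streams and returns a [`QueryResult`] that
    /// k-merges them into a single iterator ordered by `ts_init`.
    ///
    /// A sketch of the intended call pattern (file paths hypothetical):
    ///
    /// ```ignore
    /// session.add_file::<QuoteTick>("quotes", "quotes.parquet", None)?;
    /// session.add_file::<TradeTick>("trades", "trades.parquet", None)?;
    /// // Quotes and trades are interleaved by `ts_init`:
    /// let merged: Vec<Data> = session.get_query_result().collect();
    /// ```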
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

    /// Clears all registered tables and queued streams, and replaces the
    /// session context with a fresh one using the same configuration.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

// Note: intended to be used on a single thread (or externally synchronized);
// the manual impl only allows the session to be moved across threads.
unsafe impl Send for DataBackendSession {}

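/// Builds a `SELECT` query on `table`, combining the optional `where_clause`
/// with `ts_init` bounds, always ordered by `ts_init`.
///
/// For example (the timestamp values are arbitrary):
///
/// ```no_run
/// use nautilus_core::UnixNanos;
/// use nautilus_persistence::backend::session::build_query;
///
/// let query = build_query(
///     "quotes",
///     Some(UnixNanos::from(100)),
///     Some(UnixNanos::from(200)),
///     None,
/// );
/// // => "SELECT * FROM quotes WHERE ts_init >= 100 AND ts_init <= 200 ORDER BY ts_init"
/// ```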
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    let mut query = format!("SELECT * FROM {table}");

    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    query.push_str(" ORDER BY ts_init");

    query
}

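/// Query result holding a [`QueryResult`] merge, an accumulator, and an
/// FFI-ready chunk of [`Data`] elements.
///
/// The iterator never returns `None`; an empty chunk signals exhaustion. A
/// hypothetical consumption loop:
///
/// ```ignore
/// while let Some(chunk) = query_result.next() {
///     if chunk.is_empty() {
///         break; // merge exhausted
///     }
///     // process up to `size` elements ...
/// }
/// ```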
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] from `result`, yielding chunks of at
    /// most `size` elements.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Stores `data` as the current chunk, dropping any previous chunk first,
    /// and returns a [`CVec`] handle to it for FFI.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Reclaims and drops the current chunk, if any.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            // SAFETY: the raw parts were produced by `CVec::from(Vec<Data>)`
            // in `set_chunk`, so rebuilding the `Vec` restores unique
            // ownership and drops the elements exactly once.
            let data: Vec<Data> =
                unsafe { Vec::from_raw_parts(ptr.cast::<nautilus_model::data::Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        // Pull up to `size` elements from the merge into the accumulator.
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // Hand off the accumulated elements; an empty vector signals that the
        // underlying merge is exhausted (this iterator never returns `None`).
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        // Free any chunk still held on the FFI side and release merge state.
        self.drop_chunk();
        self.result.clear();
    }
}

// Note: intended to be used on a single thread (or externally synchronized);
// the `pyclass` above is marked `unsendable` accordingly.
unsafe impl Send for DataQueryResult {}