nautilus_persistence/backend/
kmerge_batch.rs1use std::{sync::Arc, vec::IntoIter};
17
18use binary_heap_plus::{BinaryHeap, PeekMut};
19use compare::Compare;
20use futures::{Stream, StreamExt};
21use tokio::{
22 runtime::Runtime,
23 sync::mpsc::{self, Receiver},
24 task::JoinHandle,
25};
26
27pub struct EagerStream<T> {
28 rx: Receiver<T>,
29 task: JoinHandle<()>,
30 runtime: Arc<Runtime>,
31}
32
33impl<T> EagerStream<T> {
34 pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
35 where
36 S: Stream<Item = T> + Send + 'static,
37 T: Send + 'static,
38 {
39 let _guard = runtime.enter();
40 let (tx, rx) = mpsc::channel(1);
41 let task = tokio::spawn(async move {
42 stream
43 .for_each(|item| async {
44 let _ = tx.send(item).await;
45 })
46 .await;
47 });
48
49 Self { rx, task, runtime }
50 }
51}
52
53impl<T> Iterator for EagerStream<T> {
54 type Item = T;
55
56 fn next(&mut self) -> Option<Self::Item> {
57 self.runtime.block_on(self.rx.recv())
58 }
59}
60
61impl<T> Drop for EagerStream<T> {
62 fn drop(&mut self) {
63 self.rx.close();
64 self.task.abort();
65 }
66}
67
/// One input stream of a k-way merge, positioned at its current head element.
///
/// A stream is a sequence of batches (`iter`). The entry keeps the head element in `item`,
/// the rest of the batch it came from in `batch`, and the not-yet-opened batches in `iter`.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// The stream's current head element (the next one this entry will yield).
    pub item: T,
    /// Remaining elements of the batch that `item` was taken from.
    batch: I::Item,
    /// Source of the stream's subsequent batches.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds an entry positioned at the first element of `iter`, skipping empty batches.
    ///
    /// Returns `None` when `iter` yields no elements at all: either no batches, or only
    /// empty ones.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        while let Some(mut batch) = iter.next() {
            if let Some(item) = batch.next() {
                return Some(Self { item, batch, iter });
            }
            // This batch was empty; try the next one.
        }
        None
    }
}
97
98pub struct KMerge<I, T, C>
99where
100 I: Iterator<Item = IntoIter<T>>,
101{
102 heap: BinaryHeap<ElementBatchIter<I, T>, C>,
103}
104
105impl<I, T, C> KMerge<I, T, C>
106where
107 I: Iterator<Item = IntoIter<T>>,
108 C: Compare<ElementBatchIter<I, T>>,
109{
110 pub fn new(cmp: C) -> Self {
112 Self {
113 heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
114 }
115 }
116
117 pub fn push_iter(&mut self, s: I) {
118 if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
119 self.heap.push(heap_elem);
120 }
121 }
122
123 pub fn clear(&mut self) {
124 self.heap.clear();
125 }
126}
127
128impl<I, T, C> Iterator for KMerge<I, T, C>
129where
130 I: Iterator<Item = IntoIter<T>>,
131 C: Compare<ElementBatchIter<I, T>>,
132{
133 type Item = T;
134
135 fn next(&mut self) -> Option<Self::Item> {
136 match self.heap.peek_mut() {
137 Some(mut heap_elem) => {
138 match heap_elem.batch.next() {
140 Some(mut item) => {
143 std::mem::swap(&mut item, &mut heap_elem.item);
144 Some(item)
145 }
146 None => loop {
149 if let Some(mut batch) = heap_elem.iter.next() {
150 match batch.next() {
151 Some(mut item) => {
152 heap_elem.batch = batch;
153 std::mem::swap(&mut item, &mut heap_elem.item);
154 break Some(item);
155 }
156 None => continue,
158 }
159 } else {
160 let ElementBatchIter {
161 item,
162 batch: _,
163 iter: _,
164 } = PeekMut::pop(heap_elem);
165 break Some(item);
166 }
167 },
168 }
169 }
170 None => None,
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, empty_shrinker};
    use quickcheck_macros::quickcheck;
    use rstest::rstest;

    use super::*;

    /// Orders heap entries by their head item, reversed so the max-heap yields the
    /// smallest item first (ascending merge).
    struct OrdComparator;

    impl<S, T> Compare<ElementBatchIter<S, T>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<T>>,
        T: Ord,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, T>,
            r: &ElementBatchIter<S, T>,
        ) -> std::cmp::Ordering {
            l.item.cmp(&r.item).reverse()
        }
    }

    /// Two streams whose ranges interleave at batch granularity.
    #[rstest]
    fn test1() {
        let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// Duplicate values across streams are all preserved.
    #[rstest]
    fn test2() {
        let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    /// Three streams, with duplicates both within and across streams.
    #[rstest]
    fn test3() {
        let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
        let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
        let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);
        kmerge.push_iter(iter_c);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    /// An empty batch in the middle of a stream is skipped.
    #[rstest]
    fn test5() {
        let iter_a = vec![
            vec![1, 3, 5].into_iter(),
            vec![].into_iter(),
            vec![7, 9, 11].into_iter(),
        ]
        .into_iter();
        let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    /// Streams with no batches, or only empty batches, contribute nothing.
    #[rstest]
    fn test_empty_streams_are_skipped() {
        let only_empty_batches: Vec<IntoIter<i32>> =
            vec![Vec::new().into_iter(), Vec::new().into_iter()];
        let no_batches: Vec<IntoIter<i32>> = Vec::new();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(only_empty_batches.into_iter());
        kmerge.push_iter(no_batches.into_iter());

        assert_eq!(kmerge.next(), None);
    }

    /// `clear` drops every pending item.
    #[rstest]
    fn test_clear_discards_pending_items() {
        let iter_a = vec![vec![1, 2].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);

        assert_eq!(kmerge.next(), Some(1));
        kmerge.clear();
        assert_eq!(kmerge.next(), None);
    }

    /// A sorted sequence of `u64`, split into consecutive (possibly empty) batches.
    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    impl Arbitrary for SortedNestedVec {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut values: Vec<u64> = Arbitrary::arbitrary(g);
            values.sort_unstable();

            let mut batches = Vec::new();
            let mut start = 0;
            while start < values.len() {
                // Chunk size lies in `0..=remaining`. A size of zero yields an empty batch,
                // which exercises the empty-batch skipping path.
                let chunk_size: usize = Arbitrary::arbitrary(g);
                let chunk_size = chunk_size % (values.len() - start + 1);
                let end = start + chunk_size;
                batches.push(values[start..end].to_vec());
                start = end;
            }

            Self(batches)
        }

        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            empty_shrinker()
        }
    }

    /// Merging sorted streams equals sorting the concatenation of all their items.
    #[quickcheck]
    fn prop_test(all_data: Vec<SortedNestedVec>) -> bool {
        let mut expected: Vec<u64> = all_data
            .iter()
            .flat_map(|stream| stream.0.iter().flatten().copied())
            .collect();
        expected.sort_unstable();

        let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
        for stream in all_data {
            kmerge.push_iter(stream.0.into_iter().map(IntoIterator::into_iter));
        }
        let merged: Vec<u64> = kmerge.collect();

        merged == expected
    }
}