nautilus_persistence/backend/
kmerge_batch.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::Arc, vec::IntoIter};
17
18use binary_heap_plus::{BinaryHeap, PeekMut};
19use compare::Compare;
20use futures::{Stream, StreamExt};
21use tokio::{
22    runtime::Runtime,
23    sync::mpsc::{self, Receiver},
24    task::JoinHandle,
25};
26
27pub struct EagerStream<T> {
28    rx: Receiver<T>,
29    task: JoinHandle<()>,
30    runtime: Arc<Runtime>,
31}
32
33impl<T> EagerStream<T> {
34    pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
35    where
36        S: Stream<Item = T> + Send + 'static,
37        T: Send + 'static,
38    {
39        let _guard = runtime.enter();
40        let (tx, rx) = mpsc::channel(1);
41        let task = tokio::spawn(async move {
42            stream
43                .for_each(|item| async {
44                    let _ = tx.send(item).await;
45                })
46                .await;
47        });
48
49        Self { rx, task, runtime }
50    }
51}
52
53impl<T> Iterator for EagerStream<T> {
54    type Item = T;
55
56    fn next(&mut self) -> Option<Self::Item> {
57        self.runtime.block_on(self.rx.recv())
58    }
59}
60
61impl<T> Drop for EagerStream<T> {
62    fn drop(&mut self) {
63        self.rx.close();
64        self.task.abort();
65    }
66}
67
// TODO: Investigate implementing Iterator for ElementBatchIter
// to reduce next element duplication. May be difficult to make it peekable.
/// Cursor over an iterator of batches that always holds one already-extracted element.
///
/// `item` is the current head element, `batch` is the remainder of the batch `item` was
/// taken from, and `iter` yields the batches that follow. [`KMerge`] stores one of these per
/// input stream in its heap.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Current head element of the stream.
    pub item: T,
    /// Remaining elements of the batch that `item` was taken from.
    batch: I::Item,
    /// Source of the batches that follow `batch`.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element that `iter` can produce.
    ///
    /// Empty batches at the front of `iter` are skipped. Returns `None` when `iter` yields
    /// no batches at all or only empty ones.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        // `by_ref` keeps ownership of `iter` so that the batches after the first
        // non-empty one remain available to the returned cursor.
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut batch| batch.next().map(|item| (item, batch)))?;

        Some(Self { item, batch, iter })
    }
}
97
98pub struct KMerge<I, T, C>
99where
100    I: Iterator<Item = IntoIter<T>>,
101{
102    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
103}
104
105impl<I, T, C> KMerge<I, T, C>
106where
107    I: Iterator<Item = IntoIter<T>>,
108    C: Compare<ElementBatchIter<I, T>>,
109{
110    /// Creates a new [`KMerge`] instance.
111    pub fn new(cmp: C) -> Self {
112        Self {
113            heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
114        }
115    }
116
117    pub fn push_iter(&mut self, s: I) {
118        if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
119            self.heap.push(heap_elem);
120        }
121    }
122
123    pub fn clear(&mut self) {
124        self.heap.clear();
125    }
126}
127
128impl<I, T, C> Iterator for KMerge<I, T, C>
129where
130    I: Iterator<Item = IntoIter<T>>,
131    C: Compare<ElementBatchIter<I, T>>,
132{
133    type Item = T;
134
135    fn next(&mut self) -> Option<Self::Item> {
136        match self.heap.peek_mut() {
137            Some(mut heap_elem) => {
138                // Get next element from batch
139                match heap_elem.batch.next() {
140                    // Swap current heap element with new element
141                    // return the old element
142                    Some(mut item) => {
143                        std::mem::swap(&mut item, &mut heap_elem.item);
144                        Some(item)
145                    }
146                    // Otherwise get the next batch and the element from it
147                    // Unless the underlying iterator is exhausted
148                    None => loop {
149                        if let Some(mut batch) = heap_elem.iter.next() {
150                            match batch.next() {
151                                Some(mut item) => {
152                                    heap_elem.batch = batch;
153                                    std::mem::swap(&mut item, &mut heap_elem.item);
154                                    break Some(item);
155                                }
156                                // Get next batch from iterator
157                                None => continue,
158                            }
159                        } else {
160                            let ElementBatchIter {
161                                item,
162                                batch: _,
163                                iter: _,
164                            } = PeekMut::pop(heap_elem);
165                            break Some(item);
166                        }
167                    },
168                }
169            }
170            None => None,
171        }
172    }
173}
174
175////////////////////////////////////////////////////////////////////////////////
176// Tests
177////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {

    use quickcheck::{Arbitrary, empty_shrinker};
    use quickcheck_macros::quickcheck;
    use rstest::rstest;

    use super::*;

    /// Comparator that puts the entry with the smallest `item` at the top of the heap.
    struct OrdComparator;

    impl<S, T> Compare<ElementBatchIter<S, T>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<T>>,
        T: Ord,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, T>,
            r: &ElementBatchIter<S, T>,
        ) -> std::cmp::Ordering {
            // Max heap ordering must be reversed, hence the swapped operands
            r.item.cmp(&l.item)
        }
    }

    /// Pushes every stream (a list of batches) into a fresh `KMerge` and collects the output.
    fn merge_all(streams: Vec<Vec<Vec<i32>>>) -> Vec<i32> {
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        for stream in streams {
            let batches: Vec<IntoIter<i32>> = stream.into_iter().map(Vec::into_iter).collect();
            kmerge.push_iter(batches.into_iter());
        }
        kmerge.collect()
    }

    #[rstest]
    fn test1() {
        let stream_a = vec![vec![1, 2, 3], vec![7, 8, 9]];
        let stream_b = vec![vec![4, 5, 6]];

        let values = merge_all(vec![stream_a, stream_b]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test2() {
        let stream_a = vec![vec![1, 2, 6], vec![7, 8, 9]];
        let stream_b = vec![vec![3, 4, 5, 6]];

        let values = merge_all(vec![stream_a, stream_b]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test3() {
        let stream_a = vec![vec![1, 4, 7], vec![24, 35, 56]];
        let stream_b = vec![vec![2, 4, 8]];
        let stream_c = vec![vec![3, 5, 9], vec![12, 12, 90]];

        let values = merge_all(vec![stream_a, stream_b, stream_c]);
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    #[rstest]
    fn test5() {
        // The empty batch in the middle of the first stream must be skipped.
        let stream_a = vec![vec![1, 3, 5], vec![], vec![7, 9, 11]];
        let stream_b = vec![vec![2, 4, 6]];

        let values = merge_all(vec![stream_a, stream_b]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    /// Batches whose concatenation is sorted in ascending order.
    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    impl Arbitrary for SortedNestedVec {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            // Start from a random flat vector and sort it
            let mut values: Vec<u64> = Arbitrary::arbitrary(g);
            values.sort_unstable();

            // Cut the sorted values into consecutive chunks of random length. A chunk may be
            // empty, in which case the cursor does not advance on that round.
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < values.len() {
                let remaining = values.len() - start;
                let chunk_len = usize::arbitrary(g) % (remaining + 1);
                let end = start + chunk_len;
                chunks.push(values[start..end].to_vec());
                start = end;
            }

            Self(chunks)
        }

        // Shrinking is not implemented: failing inputs are reported as generated
        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            empty_shrinker()
        }
    }

    #[quickcheck]
    fn prop_test(all_data: Vec<SortedNestedVec>) -> bool {
        let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
        for stream in all_data.clone() {
            kmerge.push_iter(stream.0.into_iter().map(Vec::into_iter));
        }
        let merged_data: Vec<u64> = kmerge.collect();

        // Reference result: flatten everything and sort it
        let mut sorted_data: Vec<u64> = all_data
            .into_iter()
            .flat_map(|stream| stream.0.into_iter().flatten())
            .collect();
        sorted_data.sort_unstable();

        merged_data == sorted_data
    }
}
327}