nautilus_persistence/backend/
kmerge_batch.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::Arc, vec::IntoIter};
17
18use binary_heap_plus::{BinaryHeap, PeekMut};
19use compare::Compare;
20use futures::{Stream, StreamExt};
21use tokio::{
22    runtime::Runtime,
23    sync::mpsc::{self, Receiver},
24    task::JoinHandle,
25};
26
27pub struct EagerStream<T> {
28    rx: Receiver<T>,
29    task: JoinHandle<()>,
30    runtime: Arc<Runtime>,
31}
32
33impl<T> EagerStream<T> {
34    pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
35    where
36        S: Stream<Item = T> + Send + 'static,
37        T: Send + 'static,
38    {
39        let _guard = runtime.enter();
40        let (tx, rx) = mpsc::channel(1);
41
42        let task = tokio::spawn(async move {
43            stream
44                .for_each(|item| async {
45                    let _ = tx.send(item).await;
46                })
47                .await;
48        });
49
50        Self { rx, task, runtime }
51    }
52}
53
54impl<T> Iterator for EagerStream<T> {
55    type Item = T;
56
57    fn next(&mut self) -> Option<Self::Item> {
58        self.runtime.block_on(self.rx.recv())
59    }
60}
61
62impl<T> Drop for EagerStream<T> {
63    fn drop(&mut self) {
64        self.rx.close();
65        self.task.abort();
66    }
67}
68
// TODO: Investigate implementing Iterator for ElementBatchIter
// to reduce next element duplication. May be difficult to make it peekable.
/// Cursor over an iterator of batches that always holds one already-extracted
/// element.
///
/// `item` is the current element, `batch` holds the remaining elements of the
/// batch it came from, and `iter` supplies the batches that follow.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// The current (already extracted) element.
    pub item: T,
    /// Remainder of the batch `item` was taken from.
    batch: I::Item,
    /// Source of subsequent batches.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element of `iter`.
    ///
    /// Leading empty batches are skipped. Returns `None` when `iter` is
    /// exhausted without any batch yielding an element.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        // Pull batches until one produces an element, keeping both the element
        // and the rest of its batch.
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut batch| batch.next().map(|item| (item, batch)))?;

        Some(Self { item, batch, iter })
    }
}
98
99pub struct KMerge<I, T, C>
100where
101    I: Iterator<Item = IntoIter<T>>,
102{
103    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
104}
105
106impl<I, T, C> KMerge<I, T, C>
107where
108    I: Iterator<Item = IntoIter<T>>,
109    C: Compare<ElementBatchIter<I, T>>,
110{
111    /// Creates a new [`KMerge`] instance.
112    pub fn new(cmp: C) -> Self {
113        Self {
114            heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
115        }
116    }
117
118    pub fn push_iter(&mut self, s: I) {
119        if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
120            self.heap.push(heap_elem);
121        }
122    }
123
124    pub fn clear(&mut self) {
125        self.heap.clear();
126    }
127}
128
129impl<I, T, C> Iterator for KMerge<I, T, C>
130where
131    I: Iterator<Item = IntoIter<T>>,
132    C: Compare<ElementBatchIter<I, T>>,
133{
134    type Item = T;
135
136    fn next(&mut self) -> Option<Self::Item> {
137        match self.heap.peek_mut() {
138            Some(mut heap_elem) => {
139                // Get next element from batch
140                match heap_elem.batch.next() {
141                    // Swap current heap element with new element
142                    // return the old element
143                    Some(mut item) => {
144                        std::mem::swap(&mut item, &mut heap_elem.item);
145                        Some(item)
146                    }
147                    // Otherwise get the next batch and the element from it
148                    // Unless the underlying iterator is exhausted
149                    None => loop {
150                        if let Some(mut batch) = heap_elem.iter.next() {
151                            match batch.next() {
152                                Some(mut item) => {
153                                    heap_elem.batch = batch;
154                                    std::mem::swap(&mut item, &mut heap_elem.item);
155                                    break Some(item);
156                                }
157                                // Get next batch from iterator
158                                None => continue,
159                            }
160                        } else {
161                            let ElementBatchIter {
162                                item,
163                                batch: _,
164                                iter: _,
165                            } = PeekMut::pop(heap_elem);
166                            break Some(item);
167                        }
168                    },
169                }
170            }
171            None => None,
172        }
173    }
174}
175
176////////////////////////////////////////////////////////////////////////////////
177// Tests
178////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {

    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Comparator ordering heap elements by their current `item`.
    struct OrdComparator;

    impl<S, T> Compare<ElementBatchIter<S, T>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<T>>,
        T: Ord,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, T>,
            r: &ElementBatchIter<S, T>,
        ) -> std::cmp::Ordering {
            // Max heap ordering must be reversed
            r.item.cmp(&l.item)
        }
    }

    /// Converts nested vectors into the iterator-of-batches shape taken by `push_iter`.
    fn batches<T>(nested: Vec<Vec<T>>) -> impl Iterator<Item = IntoIter<T>> {
        nested.into_iter().map(Vec::into_iter)
    }

    /// Pushes every stream (in order) into a fresh [`KMerge`] and collects the merged output.
    fn merge_all<T: Ord>(streams: Vec<Vec<Vec<T>>>) -> Vec<T> {
        let mut kmerge = KMerge::new(OrdComparator);
        for stream in streams {
            kmerge.push_iter(batches(stream));
        }
        kmerge.collect()
    }

    #[rstest]
    fn test1() {
        let values: Vec<i32> = merge_all(vec![
            vec![vec![1, 2, 3], vec![7, 8, 9]],
            vec![vec![4, 5, 6]],
        ]);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test2() {
        let values: Vec<i32> = merge_all(vec![
            vec![vec![1, 2, 6], vec![7, 8, 9]],
            vec![vec![3, 4, 5, 6]],
        ]);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test3() {
        let values: Vec<i32> = merge_all(vec![
            vec![vec![1, 4, 7], vec![24, 35, 56]],
            vec![vec![2, 4, 8]],
            vec![vec![3, 5, 9], vec![12, 12, 90]],
        ]);

        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    #[rstest]
    fn test5() {
        // The first input contains an empty batch between two populated ones.
        let values: Vec<i32> = merge_all(vec![
            vec![vec![1, 3, 5], vec![], vec![7, 9, 11]],
            vec![vec![2, 4, 6]],
        ]);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    /// Strategy to generate nested vectors where each inner vector is sorted.
    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
        // Generate a vector of u64 values, then split into sorted chunks
        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
            flat_vec.sort_unstable();
            let total_len = flat_vec.len();

            if total_len == 0 {
                // Nothing to split: a single empty chunk.
                Just(SortedNestedVec(vec![vec![]])).boxed()
            } else {
                // Generate random chunk boundaries
                prop::collection::vec(0..=total_len, 0..=10)
                    .prop_map(move |mut boundaries| {
                        // Always cover the full range, then normalise the cut points.
                        boundaries.extend([0, total_len]);
                        boundaries.sort_unstable();
                        boundaries.dedup();

                        let chunks = boundaries
                            .windows(2)
                            .map(|bounds| flat_vec[bounds[0]..bounds[1]].to_vec())
                            .collect();

                        SortedNestedVec(chunks)
                    })
                    .boxed()
            }
        })
    }

    ////////////////////////////////////////////////////////////////////////////////
    // Property-based testing
    ////////////////////////////////////////////////////////////////////////////////

    proptest! {
        /// Property: K-way merge should produce the same result as sorting all data together
        #[rstest]
        fn prop_kmerge_equivalent_to_sort(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
        ) {
            let mut sorted_data: Vec<u64> = all_data
                .iter()
                .flat_map(|stream| stream.0.iter().flatten().copied())
                .collect();
            sorted_data.sort_unstable();

            let merged_data: Vec<u64> =
                merge_all(all_data.into_iter().map(|stream| stream.0).collect());

            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
        }

        /// Property: K-way merge should preserve sortedness when inputs are sorted
        #[rstest]
        fn prop_kmerge_preserves_sort_order(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
        ) {
            let merged_data: Vec<u64> =
                merge_all(all_data.into_iter().map(|stream| stream.0).collect());

            // Check that the merged data is sorted
            for pair in merged_data.windows(2) {
                prop_assert!(pair[0] <= pair[1], "Merged data should be sorted");
            }
        }

        /// Property: Empty iterators should not affect the merge result
        #[rstest]
        fn prop_kmerge_handles_empty_iterators(
            data in sorted_nested_vec_strategy(),
            empty_count in 0usize..=5
        ) {
            // Merge of the data alone
            let result_without_empty: Vec<u64> = merge_all(vec![data.0.clone()]);

            // Merge of the same data followed by `empty_count` empty inputs
            let mut streams_with_empty = vec![data.0];
            streams_with_empty.extend(std::iter::repeat_with(Vec::new).take(empty_count));
            let result_with_empty: Vec<u64> = merge_all(streams_with_empty);

            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
        }
    }
}