nautilus_persistence/backend/
kmerge_batch.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::Arc, vec::IntoIter};
17
18use binary_heap_plus::{BinaryHeap, PeekMut};
19use compare::Compare;
20use futures::{Stream, StreamExt};
21use tokio::{
22    runtime::Runtime,
23    sync::mpsc::{self, Receiver},
24    task::JoinHandle,
25};
26
27pub struct EagerStream<T> {
28    rx: Receiver<T>,
29    task: JoinHandle<()>,
30    runtime: Arc<Runtime>,
31}
32
33impl<T> EagerStream<T> {
34    pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
35    where
36        S: Stream<Item = T> + Send + 'static,
37        T: Send + 'static,
38    {
39        let _guard = runtime.enter();
40        let (tx, rx) = mpsc::channel(1);
41        let task = tokio::spawn(async move {
42            stream
43                .for_each(|item| async {
44                    let _ = tx.send(item).await;
45                })
46                .await;
47        });
48
49        Self { rx, task, runtime }
50    }
51}
52
53impl<T> Iterator for EagerStream<T> {
54    type Item = T;
55
56    fn next(&mut self) -> Option<Self::Item> {
57        self.runtime.block_on(self.rx.recv())
58    }
59}
60
61impl<T> Drop for EagerStream<T> {
62    fn drop(&mut self) {
63        self.rx.close();
64        self.task.abort();
65    }
66}
67
// TODO: Investigate implementing Iterator for ElementBatchIter
// to reduce next element duplication. May be difficult to make it peekable.
/// Cursor over a source of batches that always holds one buffered element.
///
/// `item` is the element at the front of the source, `batch` holds the
/// remainder of the batch that element came from, and `iter` yields the
/// batches that follow.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Element currently at the front of this source.
    pub item: T,
    /// Remaining elements of the batch `item` was taken from.
    batch: I::Item,
    /// Batches not yet started.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element of `iter`.
    ///
    /// Leading empty batches are skipped. Returns `None` when `iter` yields
    /// no batches, or only empty ones.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        // The first batch that produces an element supplies both the buffered
        // item and the remainder of that batch.
        let (item, batch) =
            iter.find_map(|mut candidate| candidate.next().map(|first| (first, candidate)))?;
        Some(Self { item, batch, iter })
    }
}
97
98pub struct KMerge<I, T, C>
99where
100    I: Iterator<Item = IntoIter<T>>,
101{
102    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
103}
104
105impl<I, T, C> KMerge<I, T, C>
106where
107    I: Iterator<Item = IntoIter<T>>,
108    C: Compare<ElementBatchIter<I, T>>,
109{
110    /// Creates a new [`KMerge`] instance.
111    pub fn new(cmp: C) -> Self {
112        Self {
113            heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
114        }
115    }
116
117    pub fn push_iter(&mut self, s: I) {
118        if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
119            self.heap.push(heap_elem);
120        }
121    }
122
123    pub fn clear(&mut self) {
124        self.heap.clear();
125    }
126}
127
128impl<I, T, C> Iterator for KMerge<I, T, C>
129where
130    I: Iterator<Item = IntoIter<T>>,
131    C: Compare<ElementBatchIter<I, T>>,
132{
133    type Item = T;
134
135    fn next(&mut self) -> Option<Self::Item> {
136        match self.heap.peek_mut() {
137            Some(mut heap_elem) => {
138                // Get next element from batch
139                match heap_elem.batch.next() {
140                    // Swap current heap element with new element
141                    // return the old element
142                    Some(mut item) => {
143                        std::mem::swap(&mut item, &mut heap_elem.item);
144                        Some(item)
145                    }
146                    // Otherwise get the next batch and the element from it
147                    // Unless the underlying iterator is exhausted
148                    None => loop {
149                        if let Some(mut batch) = heap_elem.iter.next() {
150                            match batch.next() {
151                                Some(mut item) => {
152                                    heap_elem.batch = batch;
153                                    std::mem::swap(&mut item, &mut heap_elem.item);
154                                    break Some(item);
155                                }
156                                // Get next batch from iterator
157                                None => continue,
158                            }
159                        } else {
160                            let ElementBatchIter {
161                                item,
162                                batch: _,
163                                iter: _,
164                            } = PeekMut::pop(heap_elem);
165                            break Some(item);
166                        }
167                    },
168                }
169            }
170            None => None,
171        }
172    }
173}
174
175////////////////////////////////////////////////////////////////////////////////
176// Tests
177////////////////////////////////////////////////////////////////////////////////
178#[cfg(test)]
179mod tests {
180
181    use proptest::prelude::*;
182    use rstest::rstest;
183
184    use super::*;
185
186    struct OrdComparator;
187    impl<S> Compare<ElementBatchIter<S, i32>> for OrdComparator
188    where
189        S: Iterator<Item = IntoIter<i32>>,
190    {
191        fn compare(
192            &self,
193            l: &ElementBatchIter<S, i32>,
194            r: &ElementBatchIter<S, i32>,
195        ) -> std::cmp::Ordering {
196            // Max heap ordering must be reversed
197            l.item.cmp(&r.item).reverse()
198        }
199    }
200
201    impl<S> Compare<ElementBatchIter<S, u64>> for OrdComparator
202    where
203        S: Iterator<Item = IntoIter<u64>>,
204    {
205        fn compare(
206            &self,
207            l: &ElementBatchIter<S, u64>,
208            r: &ElementBatchIter<S, u64>,
209        ) -> std::cmp::Ordering {
210            // Max heap ordering must be reversed
211            l.item.cmp(&r.item).reverse()
212        }
213    }
214
215    #[rstest]
216    fn test1() {
217        let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
218        let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
219        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
220        kmerge.push_iter(iter_a);
221        kmerge.push_iter(iter_b);
222
223        let values: Vec<i32> = kmerge.collect();
224        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
225    }
226
227    #[rstest]
228    fn test2() {
229        let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
230        let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
231        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
232        kmerge.push_iter(iter_a);
233        kmerge.push_iter(iter_b);
234
235        let values: Vec<i32> = kmerge.collect();
236        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
237    }
238
239    #[rstest]
240    fn test3() {
241        let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
242        let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
243        let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
244        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
245        kmerge.push_iter(iter_a);
246        kmerge.push_iter(iter_b);
247        kmerge.push_iter(iter_c);
248
249        let values: Vec<i32> = kmerge.collect();
250        assert_eq!(
251            values,
252            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
253        );
254    }
255
256    #[rstest]
257    fn test5() {
258        let iter_a = vec![
259            vec![1, 3, 5].into_iter(),
260            vec![].into_iter(),
261            vec![7, 9, 11].into_iter(),
262        ]
263        .into_iter();
264        let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
265        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
266        kmerge.push_iter(iter_a);
267        kmerge.push_iter(iter_b);
268
269        let values: Vec<i32> = kmerge.collect();
270        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
271    }
272
273    #[derive(Debug, Clone)]
274    struct SortedNestedVec(Vec<Vec<u64>>);
275
276    /// Strategy to generate nested vectors where each inner vector is sorted.
277    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
278        // Generate a vector of u64 values, then split into sorted chunks
279        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
280            flat_vec.sort_unstable();
281
282            // Generate chunk sizes that will split the sorted vector
283            let total_len = flat_vec.len();
284            if total_len == 0 {
285                return Just(SortedNestedVec(vec![vec![]])).boxed();
286            }
287
288            // Generate random chunk boundaries
289            prop::collection::vec(0..=total_len, 0..=10)
290                .prop_map(move |mut boundaries| {
291                    boundaries.push(0);
292                    boundaries.push(total_len);
293                    boundaries.sort_unstable();
294                    boundaries.dedup();
295
296                    let mut nested_vec = Vec::new();
297                    for window in boundaries.windows(2) {
298                        let start = window[0];
299                        let end = window[1];
300                        nested_vec.push(flat_vec[start..end].to_vec());
301                    }
302
303                    SortedNestedVec(nested_vec)
304                })
305                .boxed()
306        })
307    }
308
309    ////////////////////////////////////////////////////////////////////////////////
310    // Property-based testing
311    ////////////////////////////////////////////////////////////////////////////////
312
313    proptest! {
314        /// Property: K-way merge should produce the same result as sorting all data together
315        #[test]
316        fn prop_kmerge_equivalent_to_sort(
317            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
318        ) {
319            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
320
321            let copy_data = all_data.clone();
322            for stream in copy_data {
323                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
324                kmerge.push_iter(input);
325            }
326            let merged_data: Vec<u64> = kmerge.collect();
327
328            let mut sorted_data: Vec<u64> = all_data
329                .into_iter()
330                .flat_map(|stream| stream.0.into_iter().flatten())
331                .collect();
332            sorted_data.sort_unstable();
333
334            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
335            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
336        }
337
338        /// Property: K-way merge should preserve sortedness when inputs are sorted
339        #[test]
340        fn prop_kmerge_preserves_sort_order(
341            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
342        ) {
343            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
344
345            for stream in all_data {
346                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
347                kmerge.push_iter(input);
348            }
349            let merged_data: Vec<u64> = kmerge.collect();
350
351            // Check that the merged data is sorted
352            for window in merged_data.windows(2) {
353                prop_assert!(window[0] <= window[1], "Merged data should be sorted");
354            }
355        }
356
357        /// Property: Empty iterators should not affect the merge result
358        #[test]
359        fn prop_kmerge_handles_empty_iterators(
360            data in sorted_nested_vec_strategy(),
361            empty_count in 0usize..=5
362        ) {
363            let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
364            let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
365
366            // Add the actual data to both merges
367            let input_with_empty = data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter);
368            let input_without_empty = data.0.into_iter().map(std::iter::IntoIterator::into_iter);
369
370            kmerge_with_empty.push_iter(input_with_empty);
371            kmerge_without_empty.push_iter(input_without_empty);
372
373            // Add empty iterators to the first merge
374            for _ in 0..empty_count {
375                let empty_vec: Vec<Vec<u64>> = vec![];
376                let empty_input = empty_vec.into_iter().map(std::iter::IntoIterator::into_iter);
377                kmerge_with_empty.push_iter(empty_input);
378            }
379
380            let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
381            let result_without_empty: Vec<u64> = kmerge_without_empty.collect();
382
383            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
384        }
385    }
386}