1use std::{sync::Arc, vec::IntoIter};
17
18use binary_heap_plus::{BinaryHeap, PeekMut};
19use compare::Compare;
20use futures::{Stream, StreamExt};
21use tokio::{
22 runtime::Runtime,
23 sync::mpsc::{self, Receiver},
24 task::JoinHandle,
25};
26
27pub struct EagerStream<T> {
28 rx: Receiver<T>,
29 task: JoinHandle<()>,
30 runtime: Arc<Runtime>,
31}
32
33impl<T> EagerStream<T> {
34 pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
35 where
36 S: Stream<Item = T> + Send + 'static,
37 T: Send + 'static,
38 {
39 let _guard = runtime.enter();
40 let (tx, rx) = mpsc::channel(1);
41
42 let task = tokio::spawn(async move {
43 stream
44 .for_each(|item| async {
45 let _ = tx.send(item).await;
46 })
47 .await;
48 });
49
50 Self { rx, task, runtime }
51 }
52}
53
54impl<T> Iterator for EagerStream<T> {
55 type Item = T;
56
57 fn next(&mut self) -> Option<Self::Item> {
58 self.runtime.block_on(self.rx.recv())
59 }
60}
61
62impl<T> Drop for EagerStream<T> {
63 fn drop(&mut self) {
64 self.rx.close();
65 self.task.abort();
66 }
67}
68
/// Cursor over one batched source: the current head element, the rest of the
/// batch that element came from, and the batches not yet opened.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Head element of this source; comparators order sources by this value.
    pub item: T,
    /// Remaining elements of the batch `item` was taken from.
    batch: I::Item,
    /// Batches of the source that have not been opened yet.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element of `iter`.
    ///
    /// Empty batches at the front are consumed and skipped. Returns `None`
    /// when `iter` ends without any batch producing an element.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        // Borrow the source so the unopened batches stay available afterwards.
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut batch| batch.next().map(|item| (item, batch)))?;
        Some(Self { item, batch, iter })
    }
}
98
99pub struct KMerge<I, T, C>
100where
101 I: Iterator<Item = IntoIter<T>>,
102{
103 heap: BinaryHeap<ElementBatchIter<I, T>, C>,
104}
105
106impl<I, T, C> KMerge<I, T, C>
107where
108 I: Iterator<Item = IntoIter<T>>,
109 C: Compare<ElementBatchIter<I, T>>,
110{
111 pub fn new(cmp: C) -> Self {
113 Self {
114 heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
115 }
116 }
117
118 pub fn push_iter(&mut self, s: I) {
119 if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
120 self.heap.push(heap_elem);
121 }
122 }
123
124 pub fn clear(&mut self) {
125 self.heap.clear();
126 }
127}
128
129impl<I, T, C> Iterator for KMerge<I, T, C>
130where
131 I: Iterator<Item = IntoIter<T>>,
132 C: Compare<ElementBatchIter<I, T>>,
133{
134 type Item = T;
135
136 fn next(&mut self) -> Option<Self::Item> {
137 match self.heap.peek_mut() {
138 Some(mut heap_elem) => {
139 match heap_elem.batch.next() {
141 Some(mut item) => {
144 std::mem::swap(&mut item, &mut heap_elem.item);
145 Some(item)
146 }
147 None => loop {
150 if let Some(mut batch) = heap_elem.iter.next() {
151 match batch.next() {
152 Some(mut item) => {
153 heap_elem.batch = batch;
154 std::mem::swap(&mut item, &mut heap_elem.item);
155 break Some(item);
156 }
157 None => continue,
159 }
160 } else {
161 let ElementBatchIter {
162 item,
163 batch: _,
164 iter: _,
165 } = PeekMut::pop(heap_elem);
166 break Some(item);
167 }
168 },
169 }
170 }
171 None => None,
172 }
173 }
174}
175
#[cfg(test)]
mod tests {

    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Comparator on the head element with the natural order inverted, so the
    /// merges in these tests are expected to come out ascending.
    struct OrdComparator;

    impl<S> Compare<ElementBatchIter<S, i32>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<i32>>,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, i32>,
            r: &ElementBatchIter<S, i32>,
        ) -> std::cmp::Ordering {
            // Swapping the operands yields the reversed ordering.
            r.item.cmp(&l.item)
        }
    }

    impl<S> Compare<ElementBatchIter<S, u64>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<u64>>,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, u64>,
            r: &ElementBatchIter<S, u64>,
        ) -> std::cmp::Ordering {
            // Swapping the operands yields the reversed ordering.
            r.item.cmp(&l.item)
        }
    }

    /// Pushes every source (a list of batches) into a fresh `KMerge`, in
    /// order, and collects the merged output.
    fn merge_i32(sources: Vec<Vec<Vec<i32>>>) -> Vec<i32> {
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        for source in sources {
            let batches: Vec<IntoIter<i32>> =
                source.into_iter().map(IntoIterator::into_iter).collect();
            kmerge.push_iter(batches.into_iter());
        }
        kmerge.collect()
    }

    #[rstest]
    fn test1() {
        let values = merge_i32(vec![
            vec![vec![1, 2, 3], vec![7, 8, 9]],
            vec![vec![4, 5, 6]],
        ]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test2() {
        let values = merge_i32(vec![
            vec![vec![1, 2, 6], vec![7, 8, 9]],
            vec![vec![3, 4, 5, 6]],
        ]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test3() {
        let values = merge_i32(vec![
            vec![vec![1, 4, 7], vec![24, 35, 56]],
            vec![vec![2, 4, 8]],
            vec![vec![3, 5, 9], vec![12, 12, 90]],
        ]);
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    #[rstest]
    fn test5() {
        // The middle batch of the first source is empty and must be skipped.
        let values = merge_i32(vec![
            vec![vec![1, 3, 5], vec![], vec![7, 9, 11]],
            vec![vec![2, 4, 6]],
        ]);
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    /// Batches whose concatenation is sorted ascending.
    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    /// Generates up to 100 sorted values and cuts them into consecutive
    /// non-empty batches at up to 10 random boundaries; an empty value list
    /// becomes a single empty batch.
    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
            flat_vec.sort_unstable();
            let total_len = flat_vec.len();

            if total_len == 0 {
                Just(SortedNestedVec(vec![vec![]])).boxed()
            } else {
                prop::collection::vec(0..=total_len, 0..=10)
                    .prop_map(move |mut cuts| {
                        // Always cover the full range, then normalise the cut points.
                        cuts.extend([0, total_len]);
                        cuts.sort_unstable();
                        cuts.dedup();

                        let chunks = cuts
                            .windows(2)
                            .map(|bounds| flat_vec[bounds[0]..bounds[1]].to_vec())
                            .collect();
                        SortedNestedVec(chunks)
                    })
                    .boxed()
            }
        })
    }

    proptest! {
        #[rstest]
        fn prop_kmerge_equivalent_to_sort(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
        ) {
            // Reference result: flatten everything and sort.
            let mut sorted_data: Vec<u64> = all_data
                .iter()
                .flat_map(|stream| stream.0.iter().flatten().copied())
                .collect();
            sorted_data.sort_unstable();

            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            for SortedNestedVec(batches) in all_data {
                kmerge.push_iter(batches.into_iter().map(std::iter::IntoIterator::into_iter));
            }
            let merged_data: Vec<u64> = kmerge.collect();

            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
        }

        #[rstest]
        fn prop_kmerge_preserves_sort_order(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            for SortedNestedVec(batches) in all_data {
                kmerge.push_iter(batches.into_iter().map(std::iter::IntoIterator::into_iter));
            }
            let merged_data: Vec<u64> = kmerge.collect();

            for pair in merged_data.windows(2) {
                prop_assert!(pair[0] <= pair[1], "Merged data should be sorted");
            }
        }

        #[rstest]
        fn prop_kmerge_handles_empty_iterators(
            data in sorted_nested_vec_strategy(),
            empty_count in 0usize..=5
        ) {
            let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            kmerge_without_empty
                .push_iter(data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter));

            let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            kmerge_with_empty.push_iter(data.0.into_iter().map(std::iter::IntoIterator::into_iter));
            for _ in 0..empty_count {
                // A source without any batch at all.
                let no_batches: Vec<Vec<u64>> = Vec::new();
                kmerge_with_empty
                    .push_iter(no_batches.into_iter().map(std::iter::IntoIterator::into_iter));
            }

            let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
            let result_without_empty: Vec<u64> = kmerge_without_empty.collect();

            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
        }
    }
}