1use std::{sync::Arc, vec::IntoIter};
17
18use futures::{Stream, StreamExt};
19use tokio::{
20 runtime::Runtime,
21 sync::mpsc::{self, Receiver},
22 task::JoinHandle,
23};
24
25use super::{
26 binary_heap::{BinaryHeap, PeekMut},
27 compare::Compare,
28};
29
/// Synchronous iterator over the items of an async `Stream` that is driven by a background
/// task on a Tokio runtime.
///
/// The background task forwards items through a bounded `mpsc` channel, so it can run ahead
/// of the consumer only by the channel's small buffer.
pub struct EagerStream<T> {
    // Receiving end of the channel the background task sends items into.
    rx: Receiver<T>,
    // Handle of the background producer task; aborted when the `EagerStream` is dropped.
    task: JoinHandle<()>,
    // Runtime the producer was spawned on; also used to block on `rx` in `next`.
    runtime: Arc<Runtime>,
}
35
36impl<T> EagerStream<T> {
37 pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
38 where
39 S: Stream<Item = T> + Send + 'static,
40 T: Send + 'static,
41 {
42 let (tx, rx) = mpsc::channel(1);
43
44 let task = runtime.spawn(async move {
45 futures::pin_mut!(stream);
46 while let Some(item) = stream.next().await {
47 if tx.send(item).await.is_err() {
48 break;
49 }
50 }
51 });
52
53 Self { rx, task, runtime }
54 }
55}
56
57impl<T> Iterator for EagerStream<T> {
58 type Item = T;
59
60 fn next(&mut self) -> Option<Self::Item> {
61 self.runtime.block_on(self.rx.recv())
62 }
63}
64
65impl<T> Drop for EagerStream<T> {
66 fn drop(&mut self) {
67 self.rx.close();
68 self.task.abort();
69 }
70}
71
/// Cursor over one batched source, used as a heap entry by `KMerge`.
///
/// The source is an iterator of batches (`IntoIter<T>`). The cursor keeps the source's
/// current head element in `item`, the rest of that element's batch in `batch`, and the
/// batches that have not been started yet in `iter`.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Current head element of the source; comparators order heap entries by this value.
    pub item: T,
    /// Remaining elements of the batch that `item` was taken from.
    batch: I::Item,
    /// Batches of the source that have not been started yet.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element produced by `iter`.
    ///
    /// Empty batches at the front are skipped. Returns `None` when every batch is empty,
    /// including when `iter` yields no batches at all.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        while let Some(mut batch) = iter.next() {
            if let Some(item) = batch.next() {
                return Some(Self { item, batch, iter });
            }
        }
        None
    }
}
101
/// K-way merge over several sources, each of which yields its elements in batches
/// (`IntoIter<T>`).
///
/// Each source that still has elements is represented in `heap` by one `ElementBatchIter`
/// holding its current head element. The comparator `C` decides which head is emitted next.
/// The output is only globally ordered if each source is itself ordered under `C`.
pub struct KMerge<I, T, C>
where
    I: Iterator<Item = IntoIter<T>>,
{
    // One cursor per non-exhausted source, ordered by the comparator `C`.
    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
}
109impl<I, T, C> KMerge<I, T, C>
110where
111 I: Iterator<Item = IntoIter<T>>,
112 C: Compare<ElementBatchIter<I, T>>,
113{
114 pub fn new(cmp: C) -> Self {
116 Self {
117 heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
118 }
119 }
120
121 pub fn push_iter(&mut self, s: I) {
122 if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
123 self.heap.push(heap_elem);
124 }
125 }
126
127 pub fn clear(&mut self) {
128 self.heap.clear();
129 }
130}
131
132impl<I, T, C> Iterator for KMerge<I, T, C>
133where
134 I: Iterator<Item = IntoIter<T>>,
135 C: Compare<ElementBatchIter<I, T>>,
136{
137 type Item = T;
138
139 fn next(&mut self) -> Option<Self::Item> {
140 match self.heap.peek_mut() {
141 Some(mut heap_elem) => {
142 match heap_elem.batch.next() {
144 Some(mut item) => {
147 std::mem::swap(&mut item, &mut heap_elem.item);
148 Some(item)
149 }
150 None => loop {
153 if let Some(mut batch) = heap_elem.iter.next() {
154 match batch.next() {
155 Some(mut item) => {
156 heap_elem.batch = batch;
157 std::mem::swap(&mut item, &mut heap_elem.item);
158 break Some(item);
159 }
160 None => continue,
162 }
163 } else {
164 let ElementBatchIter {
165 item,
166 batch: _,
167 iter: _,
168 } = PeekMut::pop(heap_elem);
169 break Some(item);
170 }
171 },
172 }
173 }
174 None => None,
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    // Orders heap entries by their current `item`, reversed. The tests below expect ascending
    // merge output, so the heap surfaces the entry that is greatest under this ordering,
    // i.e. the smallest item.
    struct OrdComparator;
    impl<S> Compare<ElementBatchIter<S, i32>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<i32>>,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, i32>,
            r: &ElementBatchIter<S, i32>,
        ) -> std::cmp::Ordering {
            // Reversed so that smaller items win.
            l.item.cmp(&r.item).reverse()
        }
    }

    // Same comparator for the `u64` element type used by the property tests.
    impl<S> Compare<ElementBatchIter<S, u64>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<u64>>,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, u64>,
            r: &ElementBatchIter<S, u64>,
        ) -> std::cmp::Ordering {
            // Reversed so that smaller items win.
            l.item.cmp(&r.item).reverse()
        }
    }

    // Two sources whose value ranges do not interleave.
    #[rstest]
    fn test1() {
        let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    // Interleaving sources with a duplicate value (6) present in both.
    #[rstest]
    fn test2() {
        let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    // Three sources with duplicates both across sources (4) and within one batch (12).
    #[rstest]
    fn test3() {
        let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
        let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
        let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);
        kmerge.push_iter(iter_c);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    // An empty batch in the middle of a source must be skipped transparently.
    #[rstest]
    fn test5() {
        let iter_a = vec![
            vec![1, 3, 5].into_iter(),
            vec![].into_iter(),
            vec![7, 9, 11].into_iter(),
        ]
        .into_iter();
        let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    // One sorted source, represented as its list of batches.
    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    // Generates a single sorted source: up to 100 random `u64`s are sorted, then cut into
    // contiguous batches at random boundaries. The batches concatenate back to the sorted
    // vector.
    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
            flat_vec.sort_unstable();

            let total_len = flat_vec.len();
            // No elements: represent the source as one empty batch.
            if total_len == 0 {
                return Just(SortedNestedVec(vec![vec![]])).boxed();
            }

            // Up to 10 random cut points in `0..=total_len`.
            prop::collection::vec(0..=total_len, 0..=10)
                .prop_map(move |mut boundaries| {
                    // Always cut at both ends. After sort + dedup, every window is a
                    // non-empty `start..end` range and the windows cover the whole vector.
                    boundaries.push(0);
                    boundaries.push(total_len);
                    boundaries.sort_unstable();
                    boundaries.dedup();

                    let mut nested_vec = Vec::new();
                    for window in boundaries.windows(2) {
                        let start = window[0];
                        let end = window[1];
                        nested_vec.push(flat_vec[start..end].to_vec());
                    }

                    SortedNestedVec(nested_vec)
                })
                .boxed()
        })
    }

    proptest! {
        // NOTE(review): `#[rstest]` appears to act as the test-harness attribute for the
        // functions `proptest!` generates; a plain `#[test]` would likely suffice — confirm.
        #[rstest]
        // Merging up to 10 sorted sources must equal sorting all their elements together.
        fn prop_kmerge_equivalent_to_sort(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            // Feed a copy so `all_data` is still available to build the expected result.
            let copy_data = all_data.clone();
            for stream in copy_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            // Expected: every element of every source, flattened and sorted.
            let mut sorted_data: Vec<u64> = all_data
                .into_iter()
                .flat_map(|stream| stream.0.into_iter().flatten())
                .collect();
            sorted_data.sort_unstable();

            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
        }

        #[rstest]
        // The merged output of 1..=5 sorted sources is non-decreasing.
        fn prop_kmerge_preserves_sort_order(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            for stream in all_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            // Check each adjacent pair.
            for window in merged_data.windows(2) {
                prop_assert!(window[0] <= window[1], "Merged data should be sorted");
            }
        }

        #[rstest]
        // Pushing sources that have no batches at all must not change the merged output.
        fn prop_kmerge_handles_empty_iterators(
            data in sorted_nested_vec_strategy(),
            empty_count in 0usize..=5
        ) {
            let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            // The same data feeds both merges.
            let input_with_empty = data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter);
            let input_without_empty = data.0.into_iter().map(std::iter::IntoIterator::into_iter);

            kmerge_with_empty.push_iter(input_with_empty);
            kmerge_without_empty.push_iter(input_without_empty);

            for _ in 0..empty_count {
                // A source that yields zero batches.
                let empty_vec: Vec<Vec<u64>> = vec![];
                let empty_input = empty_vec.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge_with_empty.push_iter(empty_input);
            }

            let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
            let result_without_empty: Vec<u64> = kmerge_without_empty.collect();

            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
        }
    }
}