Skip to main content

nautilus_persistence/backend/
binary_heap.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! A priority queue implemented with a binary heap.
//!
//! Vendored from the `binary-heap-plus` crate, which depends on the unmaintained
//! `compare` crate. Distributed here under MIT (see `licenses/` directory).
//! Original source: <https://github.com/sekineh/binary-heap-plus-rs>
21
22#![deny(unsafe_op_in_unsafe_fn)]
23
24use std::{
25    fmt,
26    mem::{ManuallyDrop, swap},
27    ops::{Deref, DerefMut},
28    ptr,
29};
30
31use super::compare::Compare;
32
/// A comparator-driven priority queue backed by a binary heap.
///
/// The item that the comparator `C` ranks greatest is kept at the root, so a
/// natural-order comparator gives a max-heap and a reversed comparator gives a
/// min-heap.
pub struct BinaryHeap<T, C> {
    /// Heap storage in implicit-tree layout: the children of index `i` are at
    /// `2 * i + 1` and `2 * i + 2`.
    data: Vec<T>,
    /// Comparator that decides the relative order of two items.
    cmp: C,
}
41
42impl<T, C: Compare<T>> BinaryHeap<T, C> {
43    /// Creates a `BinaryHeap` from a `Vec` and comparator.
44    pub fn from_vec_cmp(vec: Vec<T>, cmp: C) -> Self {
45        let mut heap = Self { data: vec, cmp };
46        if !heap.data.is_empty() {
47            heap.rebuild();
48        }
49        heap
50    }
51
52    /// Returns a mutable reference to the greatest item in the binary heap, or
53    /// `None` if it is empty.
54    pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T, C>> {
55        if self.is_empty() {
56            None
57        } else {
58            Some(PeekMut {
59                heap: self,
60                sift: false,
61            })
62        }
63    }
64
65    /// Removes the greatest item from the binary heap and returns it, or `None`
66    /// if it is empty.
67    pub fn pop(&mut self) -> Option<T> {
68        self.data.pop().map(|mut item| {
69            if !self.is_empty() {
70                swap(&mut item, &mut self.data[0]);
71                // SAFETY: !self.is_empty() means that self.len() > 0
72                unsafe { self.sift_down_to_bottom(0) };
73            }
74            item
75        })
76    }
77
78    /// Pushes an item onto the binary heap.
79    pub fn push(&mut self, item: T) {
80        let old_len = self.len();
81        self.data.push(item);
82        // SAFETY: Since we pushed a new item it means that
83        //  old_len = self.len() - 1 < self.len()
84        unsafe { self.sift_up(0, old_len) };
85    }
86
87    /// Sifts an element up towards the root until heap property is restored.
88    ///
89    /// # Safety
90    ///
91    /// The caller must guarantee that `pos < self.len()`.
92    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
93        // SAFETY: The caller guarantees that pos < self.len()
94        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
95
96        while hole.pos() > start {
97            let parent = (hole.pos() - 1) / 2;
98
99            // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
100            //  and so hole.pos() - 1 can't underflow.
101            if self
102                .cmp
103                .compares_le(hole.element(), unsafe { hole.get(parent) })
104            {
105                break;
106            }
107
108            // SAFETY: Same as above
109            unsafe { hole.move_to(parent) };
110        }
111
112        hole.pos()
113    }
114
115    /// Sifts an element down within the range `[pos, end)`.
116    ///
117    /// # Safety
118    ///
119    /// The caller must guarantee that `pos < end <= self.len()`.
120    unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
121        // SAFETY: The caller guarantees that pos < end <= self.len().
122        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
123        let mut child = 2 * hole.pos() + 1;
124
125        while child <= end.saturating_sub(2) {
126            // SAFETY: child < end - 1 < self.len() and
127            //  child + 1 < end <= self.len(), so they're valid indexes.
128            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;
129
130            // SAFETY: child is now either the old child or the old child+1
131            if self
132                .cmp
133                .compares_ge(hole.element(), unsafe { hole.get(child) })
134            {
135                return;
136            }
137
138            // SAFETY: same as above.
139            unsafe { hole.move_to(child) };
140            child = 2 * hole.pos() + 1;
141        }
142
143        // SAFETY: && short circuit
144        if child == end - 1
145            && self
146                .cmp
147                .compares_lt(hole.element(), unsafe { hole.get(child) })
148        {
149            unsafe { hole.move_to(child) };
150        }
151    }
152
153    /// Sifts an element down until heap property is restored.
154    ///
155    /// # Safety
156    ///
157    /// The caller must guarantee that `pos < self.len()`.
158    unsafe fn sift_down(&mut self, pos: usize) {
159        let len = self.len();
160        // SAFETY: pos < len is guaranteed by the caller
161        unsafe { self.sift_down_range(pos, len) };
162    }
163
164    /// Take an element at `pos` and move it all the way down the heap,
165    /// then sift it up to its position.
166    ///
167    /// # Safety
168    ///
169    /// The caller must guarantee that `pos < self.len()`.
170    unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
171        let end = self.len();
172        let start = pos;
173
174        // SAFETY: The caller guarantees that pos < self.len().
175        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
176        let mut child = 2 * hole.pos() + 1;
177
178        while child <= end.saturating_sub(2) {
179            // SAFETY: child < end - 1 < self.len() and
180            //  child + 1 < end <= self.len(), so they're valid indexes.
181            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;
182
183            // SAFETY: Same as above
184            unsafe { hole.move_to(child) };
185            child = 2 * hole.pos() + 1;
186        }
187
188        if child == end - 1 {
189            // SAFETY: child == end - 1 < self.len(), so it's a valid index
190            unsafe { hole.move_to(child) };
191        }
192        pos = hole.pos();
193        drop(hole);
194
195        // SAFETY: pos is the position in the hole
196        unsafe { self.sift_up(start, pos) };
197    }
198
199    /// Rebuilds the heap from an unordered vector.
200    fn rebuild(&mut self) {
201        let mut n = self.len() / 2;
202        while n > 0 {
203            n -= 1;
204            // SAFETY: n starts from self.len() / 2 and goes down to 0.
205            unsafe { self.sift_down(n) };
206        }
207    }
208}
209
210impl<T, C> BinaryHeap<T, C> {
211    /// Returns the length of the binary heap.
212    #[must_use]
213    pub fn len(&self) -> usize {
214        self.data.len()
215    }
216
217    /// Checks if the binary heap is empty.
218    #[must_use]
219    pub fn is_empty(&self) -> bool {
220        self.len() == 0
221    }
222
223    /// Drops all items from the binary heap.
224    pub fn clear(&mut self) {
225        self.data.clear();
226    }
227}
228
229impl<T: fmt::Debug, C> fmt::Debug for BinaryHeap<T, C> {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.debug_list().entries(self.data.iter()).finish()
232    }
233}
234
235impl<T: Clone, C: Clone> Clone for BinaryHeap<T, C> {
236    fn clone(&self) -> Self {
237        Self {
238            data: self.data.clone(),
239            cmp: self.cmp.clone(),
240        }
241    }
242}
243
244/// Structure wrapping a mutable reference to the greatest item on a
245/// `BinaryHeap`.
246pub struct PeekMut<'a, T: 'a, C: 'a + Compare<T>> {
247    heap: &'a mut BinaryHeap<T, C>,
248    sift: bool,
249}
250
251impl<T: fmt::Debug, C: Compare<T>> fmt::Debug for PeekMut<'_, T, C> {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_tuple("PeekMut").field(&self.heap.data[0]).finish()
254    }
255}
256
257impl<T, C: Compare<T>> Drop for PeekMut<'_, T, C> {
258    fn drop(&mut self) {
259        if self.sift {
260            // SAFETY: PeekMut is only instantiated for non-empty heaps.
261            unsafe { self.heap.sift_down(0) };
262        }
263    }
264}
265
266impl<T, C: Compare<T>> Deref for PeekMut<'_, T, C> {
267    type Target = T;
268    fn deref(&self) -> &T {
269        debug_assert!(!self.heap.is_empty());
270        // SAFETY: PeekMut is only instantiated for non-empty heaps
271        unsafe { self.heap.data.get_unchecked(0) }
272    }
273}
274
275impl<T, C: Compare<T>> DerefMut for PeekMut<'_, T, C> {
276    fn deref_mut(&mut self) -> &mut T {
277        debug_assert!(!self.heap.is_empty());
278        self.sift = true;
279        // SAFETY: PeekMut is only instantiated for non-empty heaps
280        unsafe { self.heap.data.get_unchecked_mut(0) }
281    }
282}
283
284impl<'a, T, C: Compare<T>> PeekMut<'a, T, C> {
285    /// Removes the peeked value from the heap and returns it.
286    pub fn pop(mut this: Self) -> T {
287        let value = this.heap.pop().unwrap();
288        this.sift = false;
289        value
290    }
291}
292
/// A slice with one logically vacant slot (the "hole").
///
/// The value that belongs in the hole is held separately in `elt`, so the slot
/// at `pos` must not be read: it was either moved from or holds a bitwise
/// duplicate of another slot.
struct Hole<'a, T>
where
    T: 'a,
{
    /// The slice that contains the hole.
    data: &'a mut [T],
    /// The element taken out of the slice; `ManuallyDrop` keeps it from being
    /// dropped here, since it is written back into the slice instead.
    elt: ManuallyDrop<T>,
    /// Current index of the hole inside `data`.
    pos: usize,
}
300
301impl<'a, T> Hole<'a, T> {
302    /// Create a new `Hole` at index `pos`.
303    ///
304    /// # Safety
305    ///
306    /// `pos` must be within the data slice.
307    #[inline]
308    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
309        debug_assert!(pos < data.len());
310        // SAFETY: pos should be inside the slice
311        let elt = unsafe { ptr::read(data.get_unchecked(pos)) };
312        Hole {
313            data,
314            elt: ManuallyDrop::new(elt),
315            pos,
316        }
317    }
318
319    #[inline]
320    fn pos(&self) -> usize {
321        self.pos
322    }
323
324    /// Returns a reference to the element removed.
325    #[inline]
326    fn element(&self) -> &T {
327        &self.elt
328    }
329
330    /// Returns a reference to the element at `index`.
331    ///
332    /// # Safety
333    ///
334    /// `index` must be within the data slice and not equal to pos.
335    #[inline]
336    unsafe fn get(&self, index: usize) -> &T {
337        debug_assert!(index != self.pos);
338        debug_assert!(index < self.data.len());
339        unsafe { self.data.get_unchecked(index) }
340    }
341
342    /// Move hole to new location.
343    ///
344    /// # Safety
345    ///
346    /// `index` must be within the data slice and not equal to pos.
347    #[inline]
348    unsafe fn move_to(&mut self, index: usize) {
349        debug_assert!(index != self.pos);
350        debug_assert!(index < self.data.len());
351        unsafe {
352            let ptr = self.data.as_mut_ptr();
353            let index_ptr: *const _ = ptr.add(index);
354            let hole_ptr = ptr.add(self.pos);
355            ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
356        }
357        self.pos = index;
358    }
359}
360
361impl<T> Drop for Hole<'_, T> {
362    #[inline]
363    fn drop(&mut self) {
364        // Fill the hole again
365        unsafe {
366            let pos = self.pos;
367            ptr::copy_nonoverlapping(
368                ptr::from_ref(&*self.elt),
369                self.data.get_unchecked_mut(pos),
370                1,
371            );
372        }
373    }
374}
375
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use rstest::rstest;

    use super::*;

    /// Orders `i32` values naturally, so the heap pops the largest value first.
    struct MaxComparator;

    impl Compare<i32> for MaxComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            Ord::cmp(a, b)
        }
    }

    /// Orders `i32` values in reverse, so the heap pops the smallest value first.
    struct MinComparator;

    impl Compare<i32> for MinComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            Ord::cmp(b, a)
        }
    }

    #[rstest]
    fn test_max_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(Vec::new(), MaxComparator);
        for value in [3, 1, 5] {
            heap.push(value);
        }

        for expected in [Some(5), Some(3), Some(1), None] {
            assert_eq!(heap.pop(), expected);
        }
    }

    #[rstest]
    fn test_min_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(Vec::new(), MinComparator);
        for value in [3, 1, 5] {
            heap.push(value);
        }

        for expected in [Some(1), Some(3), Some(5), None] {
            assert_eq!(heap.pop(), expected);
        }
    }

    #[rstest]
    fn test_peek_mut() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);

        // Lowering the root through the guard must re-sift the heap once the
        // guard goes out of scope.
        if let Some(mut top) = heap.peek_mut() {
            *top = 0;
        }

        assert_eq!(heap.pop(), Some(2));
    }

    #[rstest]
    fn test_peek_mut_pop() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);

        if let Some(top) = heap.peek_mut() {
            assert_eq!(PeekMut::pop(top), 5);
        }

        for expected in [Some(2), Some(1)] {
            assert_eq!(heap.pop(), expected);
        }
    }

    #[rstest]
    fn test_clear() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 2, 3], MaxComparator);
        assert!(!heap.is_empty());

        heap.clear();

        assert!(heap.is_empty());
        assert_eq!(heap.len(), 0);
    }

    #[rstest]
    fn test_from_vec() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![3, 1, 4, 1, 5, 9, 2, 6], MaxComparator);

        // Popping until exhaustion must yield the items from greatest to least.
        let sorted: Vec<i32> = std::iter::from_fn(|| heap.pop()).collect();

        assert_eq!(sorted, vec![9, 6, 5, 4, 3, 2, 1, 1]);
    }
}
470}