// nautilus_persistence/backend/binary_heap.rs
#![deny(unsafe_op_in_unsafe_fn)]

use std::{
    fmt,
    mem::{ManuallyDrop, swap},
    ops::{Deref, DerefMut},
    ptr,
};

use super::compare::Compare;

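/// A priority queue implemented as a binary heap over a `Vec<T>`, ordered by a
/// caller-supplied [`Compare`] implementation rather than `T: Ord`.
///
/// The element that compares greatest sits at the root (`data[0]`) and is the
/// one returned by `pop` and exposed by `peek_mut`.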
pub struct BinaryHeap<T, C> {
    data: Vec<T>,
    cmp: C,
}

impl<T, C: Compare<T>> BinaryHeap<T, C> {
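    /// Creates a heap from an existing vector and comparator, re-establishing
    /// the heap order in place when the vector is non-empty.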
    pub fn from_vec_cmp(vec: Vec<T>, cmp: C) -> Self {
        let mut heap = Self { data: vec, cmp };
        if !heap.data.is_empty() {
            heap.rebuild();
        }
        heap
    }

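    /// Returns a mutable reference to the greatest item in the heap, or `None`
    /// if it is empty.
    ///
    /// If the value is mutated through the returned [`PeekMut`], the heap
    /// order is restored when the guard is dropped.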
    pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T, C>> {
        if self.is_empty() {
            None
        } else {
            Some(PeekMut {
                heap: self,
                sift: false,
            })
        }
    }

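    /// Removes the greatest item from the heap and returns it, or `None` if
    /// the heap is empty.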
    pub fn pop(&mut self) -> Option<T> {
        self.data.pop().map(|mut item| {
            if !self.is_empty() {
                swap(&mut item, &mut self.data[0]);
                unsafe { self.sift_down_to_bottom(0) };
            }
            item
        })
    }

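    /// Pushes an item onto the heap, then sifts it up from the last position
    /// to restore the heap order.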
    pub fn push(&mut self, item: T) {
        let old_len = self.len();
        self.data.push(item);
        unsafe { self.sift_up(0, old_len) };
    }

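    /// Sifts the element at `pos` up towards `start` until its parent compares
    /// greater than or equal to it, returning the element's final position.
    ///
    /// # Safety
    ///
    /// `pos` must be a valid index into `self.data`.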
    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };

        while hole.pos() > start {
            let parent = (hole.pos() - 1) / 2;

            if self
                .cmp
                .compares_le(hole.element(), unsafe { hole.get(parent) })
            {
                break;
            }

            unsafe { hole.move_to(parent) };
        }

        hole.pos()
    }

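    /// Sifts the element at `pos` down within `data[..end]`, swapping it with
    /// the greater (per the comparator) of its two children at each level.
    ///
    /// # Safety
    ///
    /// `pos` must be a valid index into `self.data` and `end` must not exceed
    /// `self.len()`.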
    unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        let mut child = 2 * hole.pos() + 1;

        while child <= end.saturating_sub(2) {
            // `compares_le(left, right)` cast to usize is 0 or 1: step to the
            // right child when it compares greater than or equal to the left.
            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;

            if self
                .cmp
                .compares_ge(hole.element(), unsafe { hole.get(child) })
            {
                return;
            }

            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }

        if child == end - 1
            && self
                .cmp
                .compares_lt(hole.element(), unsafe { hole.get(child) })
        {
            unsafe { hole.move_to(child) };
        }
    }

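    /// Sifts the element at `pos` down over the full length of the heap.
    ///
    /// # Safety
    ///
    /// `pos` must be a valid index into `self.data`.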
    unsafe fn sift_down(&mut self, pos: usize) {
        let len = self.len();
        unsafe { self.sift_down_range(pos, len) };
    }

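    /// Moves the element at `pos` all the way down to a leaf, always following
    /// the greater child, then sifts it back up to its final position.
    ///
    /// For elements expected to land near the bottom (as in `pop`, where the
    /// former last leaf is placed at the root), this typically needs fewer
    /// comparisons than a plain sift-down.
    ///
    /// # Safety
    ///
    /// `pos` must be a valid index into `self.data`.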
    unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
        let end = self.len();
        let start = pos;

        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
        let mut child = 2 * hole.pos() + 1;

        while child <= end.saturating_sub(2) {
            child += unsafe { self.cmp.compares_le(hole.get(child), hole.get(child + 1)) } as usize;

            unsafe { hole.move_to(child) };
            child = 2 * hole.pos() + 1;
        }

        if child == end - 1 {
            unsafe { hole.move_to(child) };
        }
        pos = hole.pos();
        drop(hole);

        unsafe { self.sift_up(start, pos) };
    }

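    /// Re-establishes the heap order over the whole vector by sifting down
    /// every node that has at least one child, from the last parent towards
    /// the root.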
    fn rebuild(&mut self) {
        let mut n = self.len() / 2;
        while n > 0 {
            n -= 1;
            unsafe { self.sift_down(n) };
        }
    }
}

impl<T, C> BinaryHeap<T, C> {
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn clear(&mut self) {
        self.data.clear();
    }
}

impl<T: fmt::Debug, C> fmt::Debug for BinaryHeap<T, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_list().entries(self.data.iter()).finish()
    }
}

impl<T: Clone, C: Clone> Clone for BinaryHeap<T, C> {
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            cmp: self.cmp.clone(),
        }
    }
}

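/// A proxy wrapping a mutable reference to the greatest item on a
/// [`BinaryHeap`], returned by [`BinaryHeap::peek_mut`].
///
/// If the item is mutated through `DerefMut`, dropping the guard sifts the new
/// value down so the heap order is restored.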
pub struct PeekMut<'a, T: 'a, C: 'a + Compare<T>> {
    heap: &'a mut BinaryHeap<T, C>,
    sift: bool,
}

impl<T: fmt::Debug, C: Compare<T>> fmt::Debug for PeekMut<'_, T, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("PeekMut").field(&self.heap.data[0]).finish()
    }
}

impl<T, C: Compare<T>> Drop for PeekMut<'_, T, C> {
    fn drop(&mut self) {
        if self.sift {
            unsafe { self.heap.sift_down(0) };
        }
    }
}

impl<T, C: Compare<T>> Deref for PeekMut<'_, T, C> {
    type Target = T;
    fn deref(&self) -> &T {
        debug_assert!(!self.heap.is_empty());
        unsafe { self.heap.data.get_unchecked(0) }
    }
}

impl<T, C: Compare<T>> DerefMut for PeekMut<'_, T, C> {
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.heap.is_empty());
        self.sift = true;
        unsafe { self.heap.data.get_unchecked_mut(0) }
    }
}

impl<'a, T, C: Compare<T>> PeekMut<'a, T, C> {
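    /// Removes the peeked value from the heap and returns it.
    ///
    /// `sift` is cleared because the underlying `BinaryHeap::pop` has already
    /// restored the heap order, so the `Drop` impl has nothing left to do.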
    pub fn pop(mut this: Self) -> T {
        let value = this.heap.pop().unwrap();
        this.sift = false;
        value
    }
}

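/// A hole in a slice: an index whose element has been moved out into `elt`.
///
/// On drop, the held element is written back into the current hole position,
/// leaving the slice fully initialized again.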
struct Hole<'a, T: 'a> {
    data: &'a mut [T],
    elt: ManuallyDrop<T>,
    pos: usize,
}

impl<'a, T> Hole<'a, T> {
    #[inline]
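    /// Creates a new `Hole` at `pos`, moving `data[pos]` into temporary
    /// storage.
    ///
    /// # Safety
    ///
    /// `pos` must be a valid index into `data`.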
    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
        debug_assert!(pos < data.len());
        let elt = unsafe { ptr::read(data.get_unchecked(pos)) };
        Hole {
            data,
            elt: ManuallyDrop::new(elt),
            pos,
        }
    }

    #[inline]
    fn pos(&self) -> usize {
        self.pos
    }

    #[inline]
    fn element(&self) -> &T {
        &self.elt
    }

    #[inline]
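    /// Returns a reference to the element at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be within bounds and must not be the current hole
    /// position.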
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        unsafe { self.data.get_unchecked(index) }
    }

    #[inline]
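    /// Moves the hole to `index`: the element at `index` is shifted into the
    /// old hole position, and `index` becomes the new hole.
    ///
    /// # Safety
    ///
    /// `index` must be within bounds and must not be the current hole
    /// position.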
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        unsafe {
            let ptr = self.data.as_mut_ptr();
            let index_ptr: *const _ = ptr.add(index);
            let hole_ptr = ptr.add(self.pos);
            ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
        }
        self.pos = index;
    }
}

impl<T> Drop for Hole<'_, T> {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            // Fill the hole again with the element that was moved out in `new`.
            let pos = self.pos;
            ptr::copy_nonoverlapping(
                ptr::from_ref(&*self.elt),
                self.data.get_unchecked_mut(pos),
                1,
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use rstest::rstest;

    use super::*;

    struct MaxComparator;

    impl Compare<i32> for MaxComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            a.cmp(b)
        }
    }

    struct MinComparator;

    impl Compare<i32> for MinComparator {
        fn compare(&self, a: &i32, b: &i32) -> Ordering {
            b.cmp(a)
        }
    }

    #[rstest]
    fn test_max_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![], MaxComparator);
        heap.push(3);
        heap.push(1);
        heap.push(5);

        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), None);
    }

    #[rstest]
    fn test_min_heap() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![], MinComparator);
        heap.push(3);
        heap.push(1);
        heap.push(5);

        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), Some(3));
        assert_eq!(heap.pop(), Some(5));
        assert_eq!(heap.pop(), None);
    }

    #[rstest]
    fn test_peek_mut() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);

        if let Some(mut val) = heap.peek_mut() {
            *val = 0;
        }

        assert_eq!(heap.pop(), Some(2));
    }

    #[rstest]
    fn test_peek_mut_pop() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 5, 2], MaxComparator);

        if let Some(val) = heap.peek_mut() {
            let popped = PeekMut::pop(val);
            assert_eq!(popped, 5);
        }

        assert_eq!(heap.pop(), Some(2));
        assert_eq!(heap.pop(), Some(1));
    }

    #[rstest]
    fn test_clear() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![1, 2, 3], MaxComparator);
        assert!(!heap.is_empty());

        heap.clear();
        assert!(heap.is_empty());
        assert_eq!(heap.len(), 0);
    }

    #[rstest]
    fn test_from_vec() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![3, 1, 4, 1, 5, 9, 2, 6], MaxComparator);
        let mut sorted = Vec::new();
        while let Some(v) = heap.pop() {
            sorted.push(v);
        }
        assert_eq!(sorted, vec![9, 6, 5, 4, 3, 2, 1, 1]);
    }
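
    // Additional sanity check: interleaving `push` and `pop` with `MaxComparator`
    // should always return the current maximum, exercising `sift_up` and
    // `sift_down_to_bottom` together.
    #[rstest]
    fn test_push_pop_interleaved() {
        let mut heap = BinaryHeap::from_vec_cmp(vec![4, 7], MaxComparator);
        heap.push(10);
        assert_eq!(heap.pop(), Some(10));
        heap.push(1);
        heap.push(8);
        assert_eq!(heap.pop(), Some(8));
        assert_eq!(heap.pop(), Some(7));
        assert_eq!(heap.pop(), Some(4));
        assert_eq!(heap.pop(), Some(1));
        assert_eq!(heap.pop(), None);
    }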
}