Created
January 10, 2019 04:35
-
-
Save timvermeulen/f170a393816e493d3ac277c23d04b5bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::mem::ManuallyDrop; | |
use std::ops::{Deref, DerefMut}; | |
use std::ptr; | |
struct NonEmptyBinaryHeap<T> { | |
data: Vec<T>, | |
} | |
impl<T: Ord> NonEmptyBinaryHeap<T> { | |
fn with_root_and_capacity(root: T, capacity: usize) -> Self { | |
let mut vec = Vec::with_capacity(capacity); | |
vec.push(root); | |
NonEmptyBinaryHeap { data: vec } | |
} | |
fn sift_up(&mut self, pos: usize) { | |
unsafe { | |
let mut hole = Hole::new(&mut self.data, pos); | |
while hole.pos() > 0 { | |
let parent_index = (hole.pos() - 1) / 2; | |
if hole.element() <= hole.get(parent_index) { | |
break; | |
} | |
hole.move_to(parent_index); | |
} | |
} | |
} | |
fn sift_down(&mut self) { | |
let end = self.data.len(); | |
unsafe { | |
let mut hole = Hole::new(&mut self.data, 0); | |
let mut child = 1; | |
while child < end { | |
let right = child + 1; | |
if right < end && hole.get(child) < hole.get(right) { | |
child = right; | |
} | |
if hole.element() > hole.get(child) { | |
break; | |
} | |
hole.move_to(child); | |
child = 2 * hole.pos() + 1; | |
} | |
} | |
} | |
fn push(&mut self, element: T) { | |
let pos = self.data.len(); | |
self.data.push(element); | |
self.sift_up(pos); | |
} | |
fn peek_mut(&mut self) -> PeekMut<'_, T> { | |
PeekMut { | |
heap: self, | |
sift: false, | |
} | |
} | |
} | |
impl<T: Ord> Extend<T> for NonEmptyBinaryHeap<T> { | |
fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) { | |
iter.into_iter().for_each(|x| self.push(x)); | |
} | |
} | |
struct PeekMut<'a, T: Ord> { | |
heap: &'a mut NonEmptyBinaryHeap<T>, | |
sift: bool, | |
} | |
impl<T: Ord> Deref for PeekMut<'_, T> { | |
type Target = T; | |
fn deref(&self) -> &Self::Target { | |
unsafe { self.heap.data.get_unchecked(0) } | |
} | |
} | |
impl<T: Ord> DerefMut for PeekMut<'_, T> { | |
fn deref_mut(&mut self) -> &mut Self::Target { | |
self.sift = true; | |
unsafe { self.heap.data.get_unchecked_mut(0) } | |
} | |
} | |
impl<T: Ord> Drop for PeekMut<'_, T> { | |
fn drop(&mut self) { | |
if self.sift { | |
self.heap.sift_down(); | |
} | |
} | |
} | |
/// A "hole" in a slice: the element at `pos` has been moved out into `element`.
///
/// Sifting moves other elements into the hole instead of swapping them. `Drop` writes
/// `element` back into whichever slot the hole ends up in, so the slice is fully
/// restored even if a comparison panics partway through a sift.
struct Hole<'a, T> {
    data: &'a mut [T],
    /// The moved-out element. It is `ManuallyDrop` because it is written back, not dropped.
    element: ManuallyDrop<T>,
    /// Index of the slot whose contents are a stale duplicate until the hole is filled.
    pos: usize,
}

impl<'a, T> Hole<'a, T> {
    /// Creates a hole at `pos` by bitwise-reading the element out of `data[pos]`.
    ///
    /// # Safety
    ///
    /// `pos` must be in bounds; this is also enforced by the indexing below. The returned
    /// `Hole` must be dropped normally, not leaked. After a `move_to`, the slice holds
    /// a duplicated element until `drop` fills the hole.
    unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
        debug_assert!(pos < data.len());
        let elt = ptr::read(&data[pos]);
        Hole {
            data,
            element: ManuallyDrop::new(elt),
            pos,
        }
    }

    /// Current index of the hole.
    fn pos(&self) -> usize {
        self.pos
    }

    /// The element that was moved out of the slice.
    fn element(&self) -> &T {
        &self.element
    }

    /// Returns the element at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must not equal the hole position.
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        self.data.get_unchecked(index)
    }

    /// Moves the element at `index` into the hole. The hole then sits at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds and must not equal the hole position.
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        debug_assert!(index < self.data.len());
        // Derive both pointers from one raw base pointer. Taking the source via a shared
        // reborrow and then the destination via a fresh `&mut` reborrow of the whole
        // slice would invalidate the source pointer under the Stacked Borrows model.
        let base = self.data.as_mut_ptr();
        let index_ptr: *const T = base.add(index);
        let hole_ptr = base.add(self.pos);
        ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
        self.pos = index;
    }
}

impl<'a, T> Drop for Hole<'a, T> {
    /// Fills the hole with the moved-out element.
    fn drop(&mut self) {
        let pos = self.pos;
        // SAFETY: `pos` is in bounds (established by `new`, maintained by `move_to`).
        // Its slot holds a stale duplicate, so overwriting it without dropping is
        // correct. `element` lives outside `data`, so the two ranges cannot overlap.
        unsafe {
            ptr::copy_nonoverlapping(&*self.element, self.data.get_unchecked_mut(pos), 1);
        }
    }
}
/// Selection of the smallest items from a collection.
pub trait Smallest<T> {
    /// Returns at most `n` of the smallest items, ordered from smallest to largest.
    fn smallest(self, n: usize) -> Vec<T>;
}
impl<T, I> Smallest<T> for I | |
where | |
T: Ord, | |
I: IntoIterator<Item = T>, | |
{ | |
fn smallest(self, n: usize) -> Vec<T> { | |
let mut iter = self.into_iter(); | |
let first = match iter.next() { | |
Some(first) => first, | |
None => return Vec::new(), | |
}; | |
let mut heap = NonEmptyBinaryHeap::with_root_and_capacity(first, n); | |
heap.extend(iter.by_ref().take(n - 1)); | |
for x in iter { | |
let mut root = heap.peek_mut(); | |
if x < *root { | |
*root = x; | |
} | |
} | |
heap.data.sort_unstable(); | |
heap.data | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment