Created July 26, 2017 10:25
-
-
Save dodheim/c66453dfd7ba82d559204ab625656edc to your computer and use it in GitHub Desktop.
Rust solution for /r/dailyprogrammer challenge #294 [easy]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(ascii_ctype, attr_literals, const_fn, iterator_for_each, repr_align)] | |
#![cfg_attr(not(debug_assertions), feature(core_intrinsics))] | |
#![cfg_attr(test, feature(test))] | |
#[macro_use] | |
extern crate lazy_static; | |
extern crate rayon; | |
extern crate simd; | |
#[cfg(test)] | |
extern crate test; | |
use std::mem::transmute; | |
use rayon::prelude::*; | |
use simd::{i8x16, u8x16}; | |
mod letter_counter { | |
use super::*; | |
use std::iter::Cloned; | |
use std::slice::Iter; | |
#[repr(C, align(16))] | |
pub union LetterCounter { | |
raw: RawLetterCounter, | |
i8x16s: [i8x16; 2], | |
u8x16s: [u8x16; 2] | |
} | |
impl LetterCounter { | |
pub fn new(letters: &str) -> Self { | |
Self::new_with_wildcards(letters, 0) | |
} | |
pub fn iter(&self) -> Cloned<Iter<u8>> { | |
self.as_raw().1.iter().cloned() | |
} | |
pub(crate) const fn score_values() -> Self { | |
// a b c d e f g h i j k l m n o p q r s t u v w x y z | |
Self::from_raw(0, [1, 3, 3, 2, 1, 4, 2, 4, 1, 8, 5, 1, 3, 1, 1, 3, 10, 1, 1, 1, 1, 4, 4, 8, 4, 10]) | |
} | |
pub(crate) const fn from_raw(wildcards: u32, letter_counts: [u8; 26]) -> Self { | |
Self { raw: RawLetterCounter(wildcards, letter_counts, 0) } | |
} | |
pub(crate) fn new_with_wildcards(letters: &str, wildcards: i32) -> Self { | |
assert!(std::ascii::AsciiExt::is_ascii_lowercase(letters)); | |
let mut res = Self::from_raw(wildcards as u32, [0; 26]); | |
letters.bytes() | |
.map(|l| (l - b'a') as usize) | |
.for_each(|i| { | |
unsafe { #![cfg(not(debug_assertions))] | |
// is_ascii_lowercase check doesn't make the optimizer elide bounds checks even though it could in theory | |
std::intrinsics::assume(i < 26); | |
} | |
res.as_raw_mut().1[i] += 1; | |
}); | |
res | |
} | |
pub(crate) fn wildcards(&self) -> i32 { | |
self.as_raw().0 as i32 | |
} | |
pub(crate) fn extract_i8x16(&self, block: XmmBlock) -> i8x16 { | |
unsafe { self.i8x16s[block as usize] } | |
} | |
pub(crate) fn extract_u8x16(&self, block: XmmBlock) -> u8x16 { | |
unsafe { self.u8x16s[block as usize] } | |
} | |
fn as_raw(&self) -> &RawLetterCounter { | |
unsafe { &self.raw } | |
} | |
fn as_raw_mut(&mut self) -> &mut RawLetterCounter { | |
unsafe { &mut self.raw } | |
} | |
} | |
impl Eq for LetterCounter { } | |
impl PartialEq for LetterCounter { | |
fn eq(&self, other: &Self) -> bool { | |
let beq = |b| self.extract_i8x16(b).eq(other.extract_i8x16(b)); | |
(beq(XmmBlock::First) & beq(XmmBlock::Second)).all() | |
} | |
} | |
impl<'a> IntoIterator for &'a LetterCounter { | |
type Item = u8; | |
type IntoIter = Cloned<Iter<'a, u8>>; | |
fn into_iter(self) -> Self::IntoIter { | |
self.as_raw().1.into_iter().cloned() | |
} | |
} | |
#[repr(C)] | |
#[derive(Copy, Clone)] | |
pub(crate) struct RawLetterCounter(u32, [u8; 26], u16); | |
#[repr(usize)] | |
#[derive(Copy, Clone)] | |
pub(crate) enum XmmBlock { First = 0, Second = 1 } | |
} | |
pub use letter_counter::*; | |
pub struct Rack(LetterCounter); | |
impl Rack { | |
pub fn new(letters: &str) -> Self { | |
let len = letters.len(); | |
let i = letters.find('?').unwrap_or(len); | |
Rack(LetterCounter::new_with_wildcards(&letters[..i], (len - i) as i32)) | |
} | |
pub fn letter_counts(&self) -> &LetterCounter { | |
&self.0 | |
} | |
pub fn wildcards(&self) -> i32 { | |
self.0.wildcards() | |
} | |
const fn full() -> Self { | |
Rack(LetterCounter::from_raw(0, [u8::max_value(); 26])) | |
} | |
} | |
pub struct Word<'v>(LetterCounter, &'v str); | |
impl<'v> Word<'v> { | |
pub fn new<S: ?Sized + AsRef<str>>(letters: &'v S) -> Self { | |
let letters = letters.as_ref(); | |
Word(LetterCounter::new(letters), letters) | |
} | |
pub fn letter_counts(&self) -> &LetterCounter { | |
&self.0 | |
} | |
pub fn value(&self) -> &'v str { | |
self.1 | |
} | |
} | |
pub trait WordScorer { | |
fn scrabble(word: &Word, rack: &Rack) -> bool; | |
fn score(word: &Word, rack: &Rack) -> i32; | |
fn longest<'ws, 'w: 'ws>(words: &'ws [Word<'w>], rack: &Rack) -> Option<&'w str> { | |
words.par_iter() | |
.filter(|&w| Self::scrabble(w, rack)) | |
.max_by_key(|&w| w.value().len()) | |
.map(|w| w.value()) | |
} | |
fn highest<'ws, 'w: 'ws>(words: &'ws [Word<'w>], rack: &Rack) -> Option<&'w str> { | |
words.par_iter() | |
.filter(|&w| Self::scrabble(w, rack)) | |
.max_by_key(|&w| Self::score(w, rack)) | |
.map(|w| w.value()) | |
} | |
} | |
pub enum RegularScorer { } | |
impl WordScorer for RegularScorer { | |
fn scrabble(word: &Word, rack: &Rack) -> bool { | |
let wildcards_needed = word.letter_counts().iter() | |
.zip(rack.letter_counts().iter()) | |
.map(|(wc, rc)| wc.saturating_sub(rc)) | |
.sum::<u8>() as i32; | |
wildcards_needed <= rack.wildcards() | |
} | |
fn score(word: &Word, rack: &Rack) -> i32 { | |
word.letter_counts().iter() | |
.zip(rack.letter_counts().iter()) | |
.map(|(wc, rc)| std::cmp::min(wc, rc)) | |
.zip(LetterCounter::score_values().into_iter()) | |
.map(|(lc, sv)| lc * sv) | |
.sum::<u8>() as i32 | |
} | |
} | |
pub enum SseScorer { } | |
impl SseScorer { | |
fn process_and_sum_blocks<F>(word: &Word, rack: &Rack, process_block: F) -> i32 | |
where F: Fn(u8x16, u8x16, XmmBlock) -> u8x16 { | |
use simd::x86::sse2::*; | |
let pb = |b| process_block(word.letter_counts().extract_u8x16(b), rack.letter_counts().extract_u8x16(b), b); | |
let x = (pb(XmmBlock::First) + pb(XmmBlock::Second)).sad(u8x16::splat(0)); | |
(x.extract(0) + x.extract(1)) as i32 | |
} | |
} | |
impl WordScorer for SseScorer { | |
fn scrabble(word: &Word, rack: &Rack) -> bool { | |
use simd::x86::sse2::*; | |
Self::process_and_sum_blocks(word, rack, |w, r, _| w.subs(r)) <= rack.wildcards() | |
} | |
fn score(word: &Word, rack: &Rack) -> i32 { | |
use simd::x86::sse2::*; | |
use simd::x86::ssse3::*; | |
let score_vals = LetterCounter::score_values(); | |
Self::process_and_sum_blocks(word, rack, |w, r, b| unsafe { | |
transmute(w.min(r).maddubs(score_vals.extract_i8x16(b))) | |
}) | |
} | |
} | |
mod runners { | |
use super::*; | |
#[derive(Default)] | |
#[must_use] | |
pub struct TestFlags { flags: u64, index: u8 } | |
impl TestFlags { | |
pub fn flags(&self) -> u64 { | |
self.flags | |
} | |
pub fn count(&self) -> u8 { | |
self.index | |
} | |
pub fn passed(&self) -> u8 { | |
self.flags.count_ones() as u8 | |
} | |
#[cfg(test)] | |
pub fn all_passed(&self) -> bool { | |
self.passed() == self.count() | |
} | |
pub fn flag(&mut self, pass: bool) -> &mut Self { | |
debug_assert!(self.index < 64); | |
self.flags |= (pass as u64) << self.index; | |
self.index += 1; | |
self | |
} | |
pub fn with_flag(mut self, pass: bool) -> Self { | |
self.flag(pass); | |
self | |
} | |
} | |
macro_rules! runners { | |
( $($name: ident: $ws: ident;)+ ) => { $( | |
#[inline(always)] | |
pub fn $name(words: &[Word]) -> TestFlags { | |
letter_counts::<$ws>(words) | |
} | |
)+ } | |
} | |
runners!( | |
letter_counts_ssco: SseScorer; | |
letter_counts_rsco: RegularScorer; | |
); | |
pub(crate) fn load_words() -> Vec<Word<'static>> { | |
let mut res = Vec::new(); | |
load_words_into(&mut res); | |
res | |
} | |
pub(crate) fn load_words_into(words: &mut Vec<Word<'static>>) { | |
lazy_static! { | |
static ref DATA: Vec<String> = { | |
use std::io::{BufRead, BufReader}; | |
BufReader::new(std::fs::File::open(r"C:\Code\enable1.txt").expect("problem opening dictionary file")) | |
.lines() | |
.collect::<Result<Vec<_>, _>>() | |
.expect("problem reading dictionary file") | |
}; | |
} | |
DATA.par_iter() | |
.map(Word::new) | |
.collect_into(words); | |
} | |
#[inline(never)] | |
fn letter_counts<WS: WordScorer>(words: &[Word]) -> TestFlags { | |
let test_scrabble = |word, rack| WS::scrabble(&Word::new(word), &Rack::new(rack)); | |
let test_score = |word, rack| WS::score(&Word::new(word), &Rack::new(rack)); | |
let test_longest = |rack| WS::longest(words, &Rack::new(rack)).expect("test_longest returned nothing"); | |
let test_highest = |rack| WS::highest(words, &Rack::new(rack)).expect("test_highest returned nothing"); | |
let mut res = FIRST_20_RAW_LCS.iter() | |
.zip(words.iter()) | |
.fold(TestFlags::default(), |acc, (lcs, w)| acc.with_flag(lcs == w.letter_counts())); | |
res.flag( test_scrabble("daily", "ladilmy")) | |
.flag(!test_scrabble("eerie", "eerriin")) | |
.flag( test_scrabble("program", "orrpgma")) | |
.flag(!test_scrabble("program", "orppgma")) | |
.flag( test_scrabble("pizzazz", "pizza??")) | |
.flag(!test_scrabble("pizzazz", "piizza?")) | |
.flag( test_scrabble("program", "a??????")) | |
.flag(!test_scrabble("program", "b??????")); | |
res.flag(test_score("program", "progaaf????") == 8); | |
{ | |
let full_rack = Rack::full(); | |
let scored_words = words.par_iter() | |
.map(|w| (WS::score(w, &full_rack), w.value())) | |
.collect::<Vec<_>>(); | |
let &s = scored_words.par_iter().max_by_key(|&s| s.0).expect("scored_words is empty"); | |
res.flag(scored_words.into_par_iter().map(|s| s.0).sum::<i32>() == 2_551_763) | |
.flag(s.0 == 51) | |
.flag(s.1 == "razzamatazzes"); | |
} | |
res.flag( test_longest("dcthoyueorza") == "coauthored") | |
.flag(test_longest("uruqrnytrois") == "turquois") | |
.flag(test_longest("rryqeiaegicgeo??") == "greengrocery") | |
.flag(test_longest("udosjanyuiuebr??") == "subordinately") | |
.flag(test_longest("vaakojeaietg????????") == "ovolactovegetarian"); | |
res.with_flag( test_highest("dcthoyueorza") == "zydeco") | |
.with_flag(test_highest("uruqrnytrois") == "squinty") | |
.with_flag(test_highest("rryqeiaegicgeo??") == "reacquiring") | |
.with_flag(test_highest("udosjanyuiuebr??") == "jaybirds") | |
.with_flag(test_highest("vaakojeaietg????????") == "straightjacketed") | |
} | |
static FIRST_20_RAW_LCS: [LetterCounter; 20] = [ | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [2, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0]), | |
LetterCounter::from_raw(0, [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | |
]; | |
} | |
fn main() { | |
use runners::*; | |
let words = load_words(); | |
let print_lcs = |name, flags: TestFlags| println!("{}: {} ({:3$b})", name, flags.passed(), flags.flags(), flags.count() as usize); | |
print_lcs("letter_counts_ssco", letter_counts_ssco(&words)); | |
print_lcs("letter_counts_rsco", letter_counts_rsco(&words)); | |
} | |
// Layout checks, correctness tests and nightly benchmarks. The dictionary-based tests
// need the file that `load_words` reads.
#[cfg(test)]
mod tests {
    use test::{Bencher, black_box};
    use runners::*;

    #[test]
    fn test_lettercounter_traits() {
        use super::*;
        // `transmute` only compiles when both sizes match, so these two bindings act as
        // static assertions that each type is exactly 32 bytes.
        let _raw: RawLetterCounter = unsafe { transmute([0u8; 32]) };
        let _counter: LetterCounter = unsafe { transmute([0u8; 32]) };
        assert_eq!(16, std::mem::align_of::<LetterCounter>());
    }

    // Loads the dictionary and runs both scorers once before the measured benchmarks.
    #[bench]
    fn bench_00_warmup(_: &mut Bencher) {
        let dictionary = load_words();
        let _ = black_box(letter_counts_ssco(&dictionary));
        let _ = black_box(letter_counts_rsco(&dictionary));
    }

    // For each runner, generates a benchmark (reload words + run all checks per iteration)
    // and a test asserting that every check passes.
    macro_rules! bench_tests {
        ( $($runner: ident: ($bench: ident, $check: ident);)+ ) => { $(
            #[bench]
            fn $bench(bencher: &mut Bencher) {
                let mut buffer = Vec::new();
                bencher.iter(|| {
                    let buffer = black_box(&mut buffer);
                    load_words_into(buffer);
                    $runner(buffer).passed()
                });
            }

            #[test]
            fn $check() {
                let words = load_words();
                assert!($runner(&words).all_passed());
            }
        )+ }
    }

    bench_tests!(
        letter_counts_ssco: (bench_01_letter_counts_ssco, test_letter_counts_ssco);
        letter_counts_rsco: (bench_02_letter_counts_rsco, test_letter_counts_rsco);
    );
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.