From 83aa46593ab5fbe8908e309c749f7912113177c1 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 11 Feb 2021 11:39:58 +0100 Subject: [PATCH] Use bit vectors to improve performance --- src/lib.rs | 149 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 120 insertions(+), 29 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3cb1eac..12e32e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use std::ops::{Index, Range}; +use std::ops::{BitOrAssign, Index, Range}; use std::str; #[cfg(feature = "serde")] @@ -101,12 +101,18 @@ struct SegmentState<'a> { data: &'a Segmenter, text: Ascii<'a>, search: &'a mut Search, + offset: usize, } impl<'a> SegmentState<'a> { fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self { search.clear(); - Self { data, text, search } + Self { + data, + text, + search, + offset: 0, + } } /// Returns a list of words that is the best segmentation of `text` @@ -114,14 +120,15 @@ impl<'a> SegmentState<'a> { let (mut start, mut end) = (0, 0); while end < self.text.len() { end = self.text.len().min(end + SEGMENT_SIZE); + self.offset = start; self.search(0, start..end, None); - let mut splits = &self.search.best[0][..]; + let mut limit = usize::MAX; if end < self.text.len() { - splits = &splits[..splits.len().saturating_sub(5)]; + limit = 5; } - for &split in splits { + for split in self.search.best[0].decode(self.offset).take(limit) { self.search.result.push(self.text[start..split].into()); start = split; } @@ -143,30 +150,22 @@ impl<'a> SegmentState<'a> { let pair = (split..end, start..split); let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) { - Some((score, splits)) => { - (*score, &self.search.split_cache[splits.start..splits.end]) - } + Some((score, suffix_splits)) => (*score, *suffix_splits), None => { let suffix_score = self.search(level + 1, split..end, Some(start..split)); - - let start = self.search.split_cache.len(); - self.search - .split_cache - .extend(&self.search.best[level + 1][..]); - let end = self.search.split_cache.len(); - self.search.memo.insert(pair, (suffix_score, start..end)); - - (suffix_score, &self.search.split_cache[start..end]) + let suffix_splits = self.search.best[level + 1]; + self.search.memo.insert(pair, (suffix_score, suffix_splits)); + (suffix_score, suffix_splits) } }; let score = prefix_score + suffix_score; if score > best { best = score; - let splits = &mut self.search.best[level]; - splits.clear(); - splits.push(split); - splits.extend(suffix_splits); + let new_splits = &mut self.search.best[level]; + new_splits.clear(); + new_splits.set(split - self.offset); + *new_splits |= suffix_splits; } } @@ -176,9 +175,8 @@ impl<'a> SegmentState<'a> { #[derive(Clone)] pub struct Search { - memo: HashMap)>, - split_cache: Vec, - best: Vec>, + memo: HashMap, + best: [BitVec; SEGMENT_SIZE], result: Vec, } @@ -186,8 +184,7 @@ impl Default for Search { fn default() -> Self { Self { memo: HashMap::default(), - split_cache: Vec::with_capacity(32), - best: vec![vec![]; SEGMENT_SIZE], + best: [BitVec::default(); SEGMENT_SIZE], result: Vec::new(), } } @@ -196,7 +193,6 @@ impl Default for Search { impl Search { fn clear(&mut self) { self.memo.clear(); - self.split_cache.clear(); for inner in self.best.iter_mut() { inner.clear(); } @@ -204,6 +200,70 @@ impl Search { } } +#[derive(Clone, Copy, Default)] +struct BitVec([u64; 4]); + +impl BitVec { + fn set(&mut self, mut bit: usize) { + debug_assert!(bit < 256); + let mut idx = 3; + while bit > 63 { + idx -= 1; + bit -= 64; + } + self.0[idx] |= 1 << bit; + } + + fn clear(&mut self) { + self.0.iter_mut().for_each(|dst| { + *dst = 0; + }); + } + + fn decode(&self, offset: usize) -> Splits { + Splits { + vec: self.0, + offset, + idx: 3, + } + } +} + +impl BitOrAssign for BitVec { + fn bitor_assign(&mut self, rhs: Self) { + self.0 + .iter_mut() + .zip(rhs.0.iter()) + .for_each(|(dst, src)| *dst |= *src); + } +} + +struct Splits { + vec: [u64; 4], + offset: usize, + idx: usize, +} + +impl Iterator for Splits { + type Item = usize; + + fn next(&mut self) -> Option { + while self.idx > 0 && self.vec[self.idx] == 0 { + self.idx -= 1; + } + + let cur = self.vec[self.idx]; + if cur == 0 { + return None; + } + + let trailing = cur.trailing_zeros(); + let next = Some(self.offset + (3 - self.idx) * 64 + trailing as usize); + self.vec[self.idx] -= 1 << trailing; + next + } +} + type MemoKey = (Range, Range); #[derive(Debug)] @@ -251,10 +311,41 @@ const SEGMENT_SIZE: usize = 250; #[cfg(test)] pub mod tests { + use super::*; + #[test] fn test_clean() { - super::Ascii::new("Can't buy me love!").unwrap_err(); - let text = super::Ascii::new("cantbuymelove").unwrap(); + Ascii::new("Can't buy me love!").unwrap_err(); + let text = Ascii::new("cantbuymelove").unwrap(); assert_eq!(&text[0..text.len()], "cantbuymelove"); } + + #[test] + fn bitvec() { + let mut splits = BitVec::default(); + assert_eq!(splits.decode(0).collect::>(), vec![]); + + splits.set(1); + assert_eq!(splits.decode(0).collect::>(), vec![1]); + + splits.set(5); + assert_eq!(splits.decode(10).collect::>(), vec![11, 15]); + + splits.set(64); + assert_eq!(splits.decode(0).collect::>(), vec![1, 5, 64]); + + splits.set(255); + assert_eq!(splits.decode(0).collect::>(), vec![1, 5, 64, 255]); + + let mut new = BitVec::default(); + new.set(3); + new.set(16); + new.set(128); + + splits |= new; + assert_eq!( + splits.decode(0).collect::>(), + vec![1, 3, 5, 16, 64, 128, 255] + ); + } }