Use bit vectors to improve performance

This commit is contained in:
Dirkjan Ochtman 2021-02-11 11:39:58 +01:00
parent 9dd1cf089d
commit 83aa46593a
1 changed file with 120 additions and 29 deletions

View File

@ -1,4 +1,4 @@
use std::ops::{Index, Range}; use std::ops::{BitOrAssign, Index, Range};
use std::str; use std::str;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
@ -101,12 +101,18 @@ struct SegmentState<'a> {
data: &'a Segmenter, data: &'a Segmenter,
text: Ascii<'a>, text: Ascii<'a>,
search: &'a mut Search, search: &'a mut Search,
offset: usize,
} }
impl<'a> SegmentState<'a> { impl<'a> SegmentState<'a> {
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self { fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
search.clear(); search.clear();
Self { data, text, search } Self {
data,
text,
search,
offset: 0,
}
} }
/// Returns a list of words that is the best segmentation of `text` /// Returns a list of words that is the best segmentation of `text`
@ -114,14 +120,15 @@ impl<'a> SegmentState<'a> {
let (mut start, mut end) = (0, 0); let (mut start, mut end) = (0, 0);
while end < self.text.len() { while end < self.text.len() {
end = self.text.len().min(end + SEGMENT_SIZE); end = self.text.len().min(end + SEGMENT_SIZE);
self.offset = start;
self.search(0, start..end, None); self.search(0, start..end, None);
let mut splits = &self.search.best[0][..]; let mut limit = usize::MAX;
if end < self.text.len() { if end < self.text.len() {
splits = &splits[..splits.len().saturating_sub(5)]; limit = 5;
} }
for &split in splits { for split in self.search.best[0].decode(self.offset).take(limit) {
self.search.result.push(self.text[start..split].into()); self.search.result.push(self.text[start..split].into());
start = split; start = split;
} }
@ -143,30 +150,22 @@ impl<'a> SegmentState<'a> {
let pair = (split..end, start..split); let pair = (split..end, start..split);
let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) { let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
Some((score, splits)) => { Some((score, suffix_splits)) => (*score, *suffix_splits),
(*score, &self.search.split_cache[splits.start..splits.end])
}
None => { None => {
let suffix_score = self.search(level + 1, split..end, Some(start..split)); let suffix_score = self.search(level + 1, split..end, Some(start..split));
let suffix_splits = self.search.best[level + 1];
let start = self.search.split_cache.len(); self.search.memo.insert(pair, (suffix_score, suffix_splits));
self.search (suffix_score, suffix_splits)
.split_cache
.extend(&self.search.best[level + 1][..]);
let end = self.search.split_cache.len();
self.search.memo.insert(pair, (suffix_score, start..end));
(suffix_score, &self.search.split_cache[start..end])
} }
}; };
let score = prefix_score + suffix_score; let score = prefix_score + suffix_score;
if score > best { if score > best {
best = score; best = score;
let splits = &mut self.search.best[level]; let new_splits = &mut self.search.best[level];
splits.clear(); new_splits.clear();
splits.push(split); new_splits.set(split - self.offset);
splits.extend(suffix_splits); *new_splits |= suffix_splits;
} }
} }
@ -176,9 +175,8 @@ impl<'a> SegmentState<'a> {
#[derive(Clone)] #[derive(Clone)]
pub struct Search { pub struct Search {
memo: HashMap<MemoKey, (f64, Range<usize>)>, memo: HashMap<MemoKey, (f64, BitVec)>,
split_cache: Vec<usize>, best: [BitVec; SEGMENT_SIZE],
best: Vec<Vec<usize>>,
result: Vec<String>, result: Vec<String>,
} }
@ -186,8 +184,7 @@ impl Default for Search {
fn default() -> Self { fn default() -> Self {
Self { Self {
memo: HashMap::default(), memo: HashMap::default(),
split_cache: Vec::with_capacity(32), best: [BitVec::default(); SEGMENT_SIZE],
best: vec![vec![]; SEGMENT_SIZE],
result: Vec::new(), result: Vec::new(),
} }
} }
@ -196,7 +193,6 @@ impl Default for Search {
impl Search { impl Search {
fn clear(&mut self) { fn clear(&mut self) {
self.memo.clear(); self.memo.clear();
self.split_cache.clear();
for inner in self.best.iter_mut() { for inner in self.best.iter_mut() {
inner.clear(); inner.clear();
} }
@ -204,6 +200,70 @@ impl Search {
} }
} }
/// Fixed-size set of 256 bit positions, stored as four 64-bit words.
///
/// The words are filled back to front: positions 0..=63 live in the last
/// word (`self.0[3]`), positions 64..=127 in `self.0[2]`, and so on down to
/// `self.0[0]` for positions 192..=255.
#[derive(Clone, Copy, Default)]
struct BitVec([u64; 4]);

impl BitVec {
    /// Marks position `bit` as set.
    ///
    /// `bit` must be below 256; this is only asserted in debug builds.
    fn set(&mut self, bit: usize) {
        debug_assert!(bit < 256);
        // Every full group of 64 positions moves one word towards the front.
        let word = 3 - bit / 64;
        self.0[word] |= 1u64 << (bit % 64);
    }

    /// Unsets every position.
    fn clear(&mut self) {
        self.0 = [0; 4];
    }

    /// Returns an iterator over the set positions, each shifted up by `offset`.
    ///
    /// The iterator works on a copy of the words, so `self` is left untouched.
    fn decode(&self, offset: usize) -> Splits {
        let words = self.0;
        Splits {
            vec: words,
            offset,
            idx: 3,
        }
    }
}

impl BitOrAssign for BitVec {
    /// Adds every position that is set in `rhs` to `self`.
    fn bitor_assign(&mut self, rhs: Self) {
        for (dst, src) in self.0.iter_mut().zip(rhs.0.iter()) {
            *dst |= *src;
        }
    }
}
/// Iterator that drains the set bits out of four 64-bit words.
///
/// `vec` holds the bits that have not been yielded yet, `offset` is added to
/// every yielded position, and `idx` is the word currently being drained.
/// Words are visited from `idx` down to 0; word `i` covers the 64 positions
/// starting at `(3 - i) * 64`, so starting from `idx == 3` the positions come
/// out in ascending order.
struct Splits {
    vec: [u64; 4],
    offset: usize,
    idx: usize,
}

impl Iterator for Splits {
    type Item = usize;

    /// Yields the next set position (plus `offset`), or `None` once every
    /// remaining word is zero.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let word = self.vec[self.idx];
            if word != 0 {
                let lowest = word.trailing_zeros() as usize;
                // `word & (word - 1)` drops exactly the lowest set bit.
                self.vec[self.idx] = word & (word - 1);
                return Some(self.offset + (3 - self.idx) * 64 + lowest);
            }
            if self.idx == 0 {
                return None;
            }
            self.idx -= 1;
        }
    }
}
type MemoKey = (Range<usize>, Range<usize>); type MemoKey = (Range<usize>, Range<usize>);
#[derive(Debug)] #[derive(Debug)]
@ -251,10 +311,41 @@ const SEGMENT_SIZE: usize = 250;
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*;
#[test] #[test]
fn test_clean() { fn test_clean() {
super::Ascii::new("Can't buy me love!").unwrap_err(); Ascii::new("Can't buy me love!").unwrap_err();
let text = super::Ascii::new("cantbuymelove").unwrap(); let text = Ascii::new("cantbuymelove").unwrap();
assert_eq!(&text[0..text.len()], "cantbuymelove"); assert_eq!(&text[0..text.len()], "cantbuymelove");
} }
#[test]
fn bitvec() {
    // Decode `bits` with the given offset into a plain vector for comparison.
    fn positions(bits: &BitVec, offset: usize) -> Vec<usize> {
        bits.decode(offset).collect()
    }

    // A fresh vector has no positions set.
    let mut splits = BitVec::default();
    assert_eq!(positions(&splits, 0), Vec::<usize>::new());

    splits.set(1);
    assert_eq!(positions(&splits, 0), vec![1]);

    // The offset is added to every decoded position.
    splits.set(5);
    assert_eq!(positions(&splits, 10), vec![11, 15]);

    // Positions beyond the first 64 land in other words but still decode in order.
    splits.set(64);
    assert_eq!(positions(&splits, 0), vec![1, 5, 64]);

    splits.set(255);
    assert_eq!(positions(&splits, 0), vec![1, 5, 64, 255]);

    // `|=` merges the positions of both vectors.
    let mut new = BitVec::default();
    for &bit in [3, 16, 128].iter() {
        new.set(bit);
    }
    splits |= new;
    assert_eq!(positions(&splits, 0), vec![1, 3, 5, 16, 64, 128, 255]);
}
} }