Use bit vectors to improve performance
This commit is contained in:
parent
9dd1cf089d
commit
83aa46593a
149
src/lib.rs
149
src/lib.rs
|
@ -1,4 +1,4 @@
|
||||||
use std::ops::{Index, Range};
|
use std::ops::{BitOrAssign, Index, Range};
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
|
@ -101,12 +101,18 @@ struct SegmentState<'a> {
|
||||||
data: &'a Segmenter,
|
data: &'a Segmenter,
|
||||||
text: Ascii<'a>,
|
text: Ascii<'a>,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
|
offset: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> SegmentState<'a> {
|
impl<'a> SegmentState<'a> {
|
||||||
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
|
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
|
||||||
search.clear();
|
search.clear();
|
||||||
Self { data, text, search }
|
Self {
|
||||||
|
data,
|
||||||
|
text,
|
||||||
|
search,
|
||||||
|
offset: 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a list of words that is the best segmentation of `text`
|
/// Returns a list of words that is the best segmentation of `text`
|
||||||
|
@ -114,14 +120,15 @@ impl<'a> SegmentState<'a> {
|
||||||
let (mut start, mut end) = (0, 0);
|
let (mut start, mut end) = (0, 0);
|
||||||
while end < self.text.len() {
|
while end < self.text.len() {
|
||||||
end = self.text.len().min(end + SEGMENT_SIZE);
|
end = self.text.len().min(end + SEGMENT_SIZE);
|
||||||
|
self.offset = start;
|
||||||
self.search(0, start..end, None);
|
self.search(0, start..end, None);
|
||||||
|
|
||||||
let mut splits = &self.search.best[0][..];
|
let mut limit = usize::MAX;
|
||||||
if end < self.text.len() {
|
if end < self.text.len() {
|
||||||
splits = &splits[..splits.len().saturating_sub(5)];
|
limit = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
for &split in splits {
|
for split in self.search.best[0].decode(self.offset).take(limit) {
|
||||||
self.search.result.push(self.text[start..split].into());
|
self.search.result.push(self.text[start..split].into());
|
||||||
start = split;
|
start = split;
|
||||||
}
|
}
|
||||||
|
@ -143,30 +150,22 @@ impl<'a> SegmentState<'a> {
|
||||||
|
|
||||||
let pair = (split..end, start..split);
|
let pair = (split..end, start..split);
|
||||||
let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
|
let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
|
||||||
Some((score, splits)) => {
|
Some((score, suffix_splits)) => (*score, *suffix_splits),
|
||||||
(*score, &self.search.split_cache[splits.start..splits.end])
|
|
||||||
}
|
|
||||||
None => {
|
None => {
|
||||||
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
||||||
|
let suffix_splits = self.search.best[level + 1];
|
||||||
let start = self.search.split_cache.len();
|
self.search.memo.insert(pair, (suffix_score, suffix_splits));
|
||||||
self.search
|
(suffix_score, suffix_splits)
|
||||||
.split_cache
|
|
||||||
.extend(&self.search.best[level + 1][..]);
|
|
||||||
let end = self.search.split_cache.len();
|
|
||||||
self.search.memo.insert(pair, (suffix_score, start..end));
|
|
||||||
|
|
||||||
(suffix_score, &self.search.split_cache[start..end])
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let score = prefix_score + suffix_score;
|
let score = prefix_score + suffix_score;
|
||||||
if score > best {
|
if score > best {
|
||||||
best = score;
|
best = score;
|
||||||
let splits = &mut self.search.best[level];
|
let new_splits = &mut self.search.best[level];
|
||||||
splits.clear();
|
new_splits.clear();
|
||||||
splits.push(split);
|
new_splits.set(split - self.offset);
|
||||||
splits.extend(suffix_splits);
|
*new_splits |= suffix_splits;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,9 +175,8 @@ impl<'a> SegmentState<'a> {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Search {
|
pub struct Search {
|
||||||
memo: HashMap<MemoKey, (f64, Range<usize>)>,
|
memo: HashMap<MemoKey, (f64, BitVec)>,
|
||||||
split_cache: Vec<usize>,
|
best: [BitVec; SEGMENT_SIZE],
|
||||||
best: Vec<Vec<usize>>,
|
|
||||||
result: Vec<String>,
|
result: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,8 +184,7 @@ impl Default for Search {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
memo: HashMap::default(),
|
memo: HashMap::default(),
|
||||||
split_cache: Vec::with_capacity(32),
|
best: [BitVec::default(); SEGMENT_SIZE],
|
||||||
best: vec![vec![]; SEGMENT_SIZE],
|
|
||||||
result: Vec::new(),
|
result: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,7 +193,6 @@ impl Default for Search {
|
||||||
impl Search {
|
impl Search {
|
||||||
fn clear(&mut self) {
|
fn clear(&mut self) {
|
||||||
self.memo.clear();
|
self.memo.clear();
|
||||||
self.split_cache.clear();
|
|
||||||
for inner in self.best.iter_mut() {
|
for inner in self.best.iter_mut() {
|
||||||
inner.clear();
|
inner.clear();
|
||||||
}
|
}
|
||||||
|
@ -204,6 +200,70 @@ impl Search {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fixed-capacity set of up to 256 bit positions, stored as four `u64` words.
///
/// Position `n` lives in word `3 - n / 64` at shift `n % 64`, so position 0 is
/// the lowest bit of the *last* word and position 255 is the highest bit of the
/// first word. The type is `Copy`, so it can be stored and combined by value.
#[derive(Clone, Copy, Debug, Default)]
struct BitVec([u64; 4]);

impl BitVec {
    /// Marks position `bit` as present.
    ///
    /// `bit` must be below 256; this is checked with `debug_assert!`, and an
    /// out-of-range position still panics on the array index in any build.
    fn set(&mut self, bit: usize) {
        debug_assert!(bit < 256);
        // Words are stored in reverse order: the last word holds positions 0..64.
        self.0[3 - bit / 64] |= 1u64 << (bit % 64);
    }

    /// Removes every position from the set.
    fn clear(&mut self) {
        self.0 = [0; 4];
    }

    /// Returns an iterator over the stored positions in ascending order, each
    /// shifted up by `offset`.
    ///
    /// The iterator works on a copy of the words, so `self` is not borrowed
    /// while iterating.
    fn decode(&self, offset: usize) -> Splits {
        Splits {
            vec: self.0,
            offset,
            idx: 3,
        }
    }
}

impl BitOrAssign for BitVec {
    /// Adds every position present in `rhs` to `self` (set union).
    fn bitor_assign(&mut self, rhs: Self) {
        for (dst, src) in self.0.iter_mut().zip(rhs.0.iter()) {
            *dst |= *src;
        }
    }
}

/// Iterator over the positions stored in a [`BitVec`], produced by
/// [`BitVec::decode`].
#[derive(Clone, Debug)]
struct Splits {
    /// Private copy of the words; bits are cleared as they are yielded.
    vec: [u64; 4],
    /// Value added to every yielded position.
    offset: usize,
    /// Index of the word currently being drained; starts at 3 and only decreases.
    idx: usize,
}

impl Iterator for Splits {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Skip words that are already drained, moving towards higher positions.
        while self.idx > 0 && self.vec[self.idx] == 0 {
            self.idx -= 1;
        }

        let cur = self.vec[self.idx];
        if cur == 0 {
            return None;
        }

        let trailing = cur.trailing_zeros() as usize;
        // `cur & (cur - 1)` clears exactly the lowest set bit.
        self.vec[self.idx] = cur & (cur - 1);
        Some(self.offset + (3 - self.idx) * 64 + trailing)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Yielded bits are cleared from `vec`, so the bits still set are
        // exactly the items left to produce.
        let remaining = self.vec.iter().map(|word| word.count_ones() as usize).sum();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Splits {}
|
||||||
|
|
||||||
type MemoKey = (Range<usize>, Range<usize>);
|
type MemoKey = (Range<usize>, Range<usize>);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -251,10 +311,41 @@ const SEGMENT_SIZE: usize = 250;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod tests {
|
pub mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_clean() {
|
fn test_clean() {
|
||||||
super::Ascii::new("Can't buy me love!").unwrap_err();
|
Ascii::new("Can't buy me love!").unwrap_err();
|
||||||
let text = super::Ascii::new("cantbuymelove").unwrap();
|
let text = Ascii::new("cantbuymelove").unwrap();
|
||||||
assert_eq!(&text[0..text.len()], "cantbuymelove");
|
assert_eq!(&text[0..text.len()], "cantbuymelove");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn bitvec() {
    // Collects the positions stored in `bits`, each shifted by `offset`.
    fn positions(bits: &BitVec, offset: usize) -> Vec<usize> {
        bits.decode(offset).collect()
    }

    // A fresh bit vector holds no positions.
    let mut splits = BitVec::default();
    assert_eq!(positions(&splits, 0), Vec::<usize>::new());

    splits.set(1);
    assert_eq!(positions(&splits, 0), vec![1]);

    // The decode offset is added to every stored position.
    splits.set(5);
    assert_eq!(positions(&splits, 10), vec![11, 15]);

    // Positions at and beyond the first 64-bit word boundary.
    splits.set(64);
    assert_eq!(positions(&splits, 0), vec![1, 5, 64]);

    splits.set(255);
    assert_eq!(positions(&splits, 0), vec![1, 5, 64, 255]);

    // `|=` merges the positions of both vectors, still decoded in ascending order.
    let mut other = BitVec::default();
    for bit in [3, 16, 128] {
        other.set(bit);
    }
    splits |= other;
    assert_eq!(positions(&splits, 0), vec![1, 3, 5, 16, 64, 128, 255]);
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue