Use more efficient segmentation strategy

Based on the triangular matrix approach as explained here:

https://towardsdatascience.com/fast-word-segmentation-for-noisy-text-2c2c41f9e8da

Use iteration rather than recursion to segment the input forwards
rather than backwards and use a `Vec`-based memoization strategy
instead of relying on a `HashMap` of words. This version is about
4.8x faster, 100 lines of code less and should use much less memory.
This commit is contained in:
Dirkjan Ochtman 2021-05-27 13:50:23 +02:00
parent 541644a329
commit 85f4f94b53
2 changed files with 50 additions and 165 deletions

View File

@ -20,7 +20,7 @@ Corpus][corpus], as described by Thorsten Brants and Alex Franz, and
data **"may only be used for linguistic education and research"**, so for any data **"may only be used for linguistic education and research"**, so for any
other usage you should acquire a different data set. other usage you should acquire a different data set.
For the microbenchmark included in this repository, Instant Segment is ~17x For the microbenchmark included in this repository, Instant Segment is ~100x
faster than the Python implementation. Further optimizations are planned -- see faster than the Python implementation. Further optimizations are planned -- see
the [issues][issues]. The API has been carefully constructed so that multiple the [issues][issues]. The API has been carefully constructed so that multiple
segmentations can share the underlying state to allow parallel usage. segmentations can share the underlying state to allow parallel usage.

View File

@ -1,4 +1,4 @@
use std::ops::{BitOrAssign, Index, Range}; use std::ops::{Index, Range};
use std::str; use std::str;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
@ -112,96 +112,70 @@ struct SegmentState<'a> {
data: &'a Segmenter, data: &'a Segmenter,
text: Ascii<'a>, text: Ascii<'a>,
search: &'a mut Search, search: &'a mut Search,
offset: usize,
} }
impl<'a> SegmentState<'a> { impl<'a> SegmentState<'a> {
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self { fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
search.clear(); search.clear();
Self { Self { data, text, search }
data,
text,
search,
offset: 0,
}
} }
/// Returns a list of words that is the best segmentation of `text` fn run(self) {
fn run(mut self) { for end in 1..=self.text.len() {
let (mut start, mut end) = (0, 0); let start = end.saturating_sub(self.data.limit);
while end < self.text.len() { for split in start..end {
end = self.text.len().min(end + SEGMENT_SIZE); let (prev, prev_score) = match split {
self.offset = start; 0 => (None, 0.0),
self.search(0, start..end, None); _ => {
let prefix = self.search.candidates[split - 1];
let mut limit = usize::MAX; let word = &self.text[split - prefix.len as usize..split];
if end < self.text.len() { (Some(word), prefix.score)
limit = 5;
}
for split in self.search.best[0].decode(self.offset).take(limit) {
self.search.result.push(self.text[start..split].into());
start = split;
}
}
}
/// Score `word` in the context of `previous` word
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
if range.is_empty() {
self.search.best[level].clear();
return 0.0;
}
let mut best = f64::MIN;
for split in 1..(range.len().min(self.data.limit) + 1) {
let (start, split, end) = (range.start, range.start + split, range.end);
let previous = previous.clone().map(|range| &self.text[range]);
let prefix_score = self.data.score(&self.text[start..split], previous);
let key = (
(start - self.offset) as u8,
(split - self.offset) as u8,
(end - self.offset) as u8,
);
let (suffix_score, suffix_splits) = match self.search.memo.get(&key) {
Some((score, suffix_splits)) => (*score, *suffix_splits),
None => {
let suffix_score = self.search(level + 1, split..end, Some(start..split));
let suffix_splits = self.search.best[level + 1];
self.search.memo.insert(key, (suffix_score, suffix_splits));
(suffix_score, suffix_splits)
} }
}; };
let score = prefix_score + suffix_score; let word = &self.text[split..end];
if score > best { let score = self.data.score(word, prev) + prev_score;
best = score; match self.search.candidates.get_mut(end - 1) {
let new_splits = &mut self.search.best[level]; Some(cur) if cur.score < score => {
new_splits.clear(); cur.len = end - split;
new_splits.set(split - self.offset); cur.score = score;
*new_splits |= suffix_splits; }
None => self.search.candidates.push(Candidate {
len: end - split,
score,
}),
_ => {}
}
} }
} }
best let mut end = self.text.len();
let mut best = self.search.candidates[end - 1];
loop {
let word = &self.text[end - best.len as usize..end];
self.search.result.push(word.into());
end -= best.len as usize;
if end == 0 {
break;
}
best = self.search.candidates[end - 1];
}
self.search.result.reverse();
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Search { pub struct Search {
memo: HashMap<(u8, u8, u8), (f64, BitVec)>, candidates: Vec<Candidate>,
best: Box<[BitVec; SEGMENT_SIZE]>,
result: Vec<String>, result: Vec<String>,
} }
impl Search { impl Search {
fn clear(&mut self) { fn clear(&mut self) {
self.memo.clear(); self.candidates.clear();
for inner in self.best.iter_mut() {
inner.clear();
}
self.result.clear(); self.result.clear();
} }
@ -214,75 +188,16 @@ impl Search {
impl Default for Search { impl Default for Search {
fn default() -> Self { fn default() -> Self {
Self { Self {
memo: HashMap::default(), candidates: Vec::new(),
best: Box::new([BitVec::default(); SEGMENT_SIZE]),
result: Vec::new(), result: Vec::new(),
} }
} }
} }
#[derive(Clone, Copy, Default)] #[derive(Clone, Copy, Debug, Default)]
struct BitVec([u64; 4]); struct Candidate {
len: usize,
impl BitVec { score: f64,
fn set(&mut self, mut bit: usize) {
debug_assert!(bit < 256);
let mut idx = 3;
while bit > 63 {
idx -= 1;
bit -= 64;
}
self.0[idx] |= 1 << bit;
}
fn clear(&mut self) {
self.0.iter_mut().for_each(|dst| {
*dst = 0;
});
}
fn decode(&self, offset: usize) -> Splits {
Splits {
vec: self.0,
offset,
idx: 3,
}
}
}
impl BitOrAssign for BitVec {
fn bitor_assign(&mut self, rhs: Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(dst, src)| *dst |= *src);
}
}
struct Splits {
vec: [u64; 4],
offset: usize,
idx: usize,
}
impl Iterator for Splits {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
while self.idx > 0 && self.vec[self.idx] == 0 {
self.idx -= 1;
}
let cur = self.vec[self.idx];
if cur == 0 {
return None;
}
let trailing = cur.trailing_zeros();
let next = Some(self.offset + (3 - self.idx) * 64 + trailing as usize);
self.vec[self.idx] -= 1 << trailing;
next
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -326,7 +241,6 @@ impl std::fmt::Display for InvalidCharacter {
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>; type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
const DEFAULT_LIMIT: usize = 24; const DEFAULT_LIMIT: usize = 24;
const SEGMENT_SIZE: usize = 250;
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
@ -338,33 +252,4 @@ pub mod tests {
let text = Ascii::new("cantbuymelove").unwrap(); let text = Ascii::new("cantbuymelove").unwrap();
assert_eq!(&text[0..text.len()], "cantbuymelove"); assert_eq!(&text[0..text.len()], "cantbuymelove");
} }
#[test]
fn bitvec() {
let mut splits = BitVec::default();
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![]);
splits.set(1);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1]);
splits.set(5);
assert_eq!(splits.decode(10).collect::<Vec<_>>(), vec![11, 15]);
splits.set(64);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64]);
splits.set(255);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64, 255]);
let mut new = BitVec::default();
new.set(3);
new.set(16);
new.set(128);
splits |= new;
assert_eq!(
splits.decode(0).collect::<Vec<_>>(),
vec![1, 3, 5, 16, 64, 128, 255]
);
}
} }