Use more efficient segmentation strategy

Based on the triangular matrix approach as explained here:

https://towardsdatascience.com/fast-word-segmentation-for-noisy-text-2c2c41f9e8da

Use iteration instead of recursion, segmenting the input forwards
rather than backwards, and use a `Vec`-based memoization strategy
instead of relying on a `HashMap` of words. This version is about
4.8x faster, roughly 100 lines shorter, and should use much less memory.
This commit is contained in:
Dirkjan Ochtman 2021-05-27 13:50:23 +02:00
parent 541644a329
commit 85f4f94b53
2 changed files with 50 additions and 165 deletions

View File

@ -20,7 +20,7 @@ Corpus][corpus], as described by Thorsten Brants and Alex Franz, and
data **"may only be used for linguistic education and research"**, so for any data **"may only be used for linguistic education and research"**, so for any
other usage you should acquire a different data set. other usage you should acquire a different data set.
For the microbenchmark included in this repository, Instant Segment is ~17x For the microbenchmark included in this repository, Instant Segment is ~100x
faster than the Python implementation. Further optimizations are planned -- see faster than the Python implementation. Further optimizations are planned -- see
the [issues][issues]. The API has been carefully constructed so that multiple the [issues][issues]. The API has been carefully constructed so that multiple
segmentations can share the underlying state to allow parallel usage. segmentations can share the underlying state to allow parallel usage.
@ -110,4 +110,4 @@ make test-python
[corpus]: [corpus]:
http://googleresearch.blogspot.com/2006/08/all-our-n-gram-are-belong-to-you.html http://googleresearch.blogspot.com/2006/08/all-our-n-gram-are-belong-to-you.html
[distributed]: https://catalog.ldc.upenn.edu/LDC2006T13 [distributed]: https://catalog.ldc.upenn.edu/LDC2006T13
[issues]: https://github.com/InstantDomainSearch/instant-segment/issues [issues]: https://github.com/InstantDomainSearch/instant-segment/issues

View File

@ -1,4 +1,4 @@
use std::ops::{BitOrAssign, Index, Range}; use std::ops::{Index, Range};
use std::str; use std::str;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
@ -112,96 +112,70 @@ struct SegmentState<'a> {
data: &'a Segmenter, data: &'a Segmenter,
text: Ascii<'a>, text: Ascii<'a>,
search: &'a mut Search, search: &'a mut Search,
offset: usize,
} }
impl<'a> SegmentState<'a> { impl<'a> SegmentState<'a> {
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self { fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
search.clear(); search.clear();
Self { Self { data, text, search }
data,
text,
search,
offset: 0,
}
} }
/// Returns a list of words that is the best segmentation of `text` fn run(self) {
fn run(mut self) { for end in 1..=self.text.len() {
let (mut start, mut end) = (0, 0); let start = end.saturating_sub(self.data.limit);
while end < self.text.len() { for split in start..end {
end = self.text.len().min(end + SEGMENT_SIZE); let (prev, prev_score) = match split {
self.offset = start; 0 => (None, 0.0),
self.search(0, start..end, None); _ => {
let prefix = self.search.candidates[split - 1];
let word = &self.text[split - prefix.len as usize..split];
(Some(word), prefix.score)
}
};
let mut limit = usize::MAX; let word = &self.text[split..end];
if end < self.text.len() { let score = self.data.score(word, prev) + prev_score;
limit = 5; match self.search.candidates.get_mut(end - 1) {
} Some(cur) if cur.score < score => {
cur.len = end - split;
for split in self.search.best[0].decode(self.offset).take(limit) { cur.score = score;
self.search.result.push(self.text[start..split].into()); }
start = split; None => self.search.candidates.push(Candidate {
} len: end - split,
} score,
} }),
_ => {}
/// Score `word` in the context of `previous` word
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
if range.is_empty() {
self.search.best[level].clear();
return 0.0;
}
let mut best = f64::MIN;
for split in 1..(range.len().min(self.data.limit) + 1) {
let (start, split, end) = (range.start, range.start + split, range.end);
let previous = previous.clone().map(|range| &self.text[range]);
let prefix_score = self.data.score(&self.text[start..split], previous);
let key = (
(start - self.offset) as u8,
(split - self.offset) as u8,
(end - self.offset) as u8,
);
let (suffix_score, suffix_splits) = match self.search.memo.get(&key) {
Some((score, suffix_splits)) => (*score, *suffix_splits),
None => {
let suffix_score = self.search(level + 1, split..end, Some(start..split));
let suffix_splits = self.search.best[level + 1];
self.search.memo.insert(key, (suffix_score, suffix_splits));
(suffix_score, suffix_splits)
} }
};
let score = prefix_score + suffix_score;
if score > best {
best = score;
let new_splits = &mut self.search.best[level];
new_splits.clear();
new_splits.set(split - self.offset);
*new_splits |= suffix_splits;
} }
} }
best let mut end = self.text.len();
let mut best = self.search.candidates[end - 1];
loop {
let word = &self.text[end - best.len as usize..end];
self.search.result.push(word.into());
end -= best.len as usize;
if end == 0 {
break;
}
best = self.search.candidates[end - 1];
}
self.search.result.reverse();
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Search { pub struct Search {
memo: HashMap<(u8, u8, u8), (f64, BitVec)>, candidates: Vec<Candidate>,
best: Box<[BitVec; SEGMENT_SIZE]>,
result: Vec<String>, result: Vec<String>,
} }
impl Search { impl Search {
fn clear(&mut self) { fn clear(&mut self) {
self.memo.clear(); self.candidates.clear();
for inner in self.best.iter_mut() {
inner.clear();
}
self.result.clear(); self.result.clear();
} }
@ -214,75 +188,16 @@ impl Search {
impl Default for Search { impl Default for Search {
fn default() -> Self { fn default() -> Self {
Self { Self {
memo: HashMap::default(), candidates: Vec::new(),
best: Box::new([BitVec::default(); SEGMENT_SIZE]),
result: Vec::new(), result: Vec::new(),
} }
} }
} }
#[derive(Clone, Copy, Default)] #[derive(Clone, Copy, Debug, Default)]
struct BitVec([u64; 4]); struct Candidate {
len: usize,
impl BitVec { score: f64,
fn set(&mut self, mut bit: usize) {
debug_assert!(bit < 256);
let mut idx = 3;
while bit > 63 {
idx -= 1;
bit -= 64;
}
self.0[idx] |= 1 << bit;
}
fn clear(&mut self) {
self.0.iter_mut().for_each(|dst| {
*dst = 0;
});
}
fn decode(&self, offset: usize) -> Splits {
Splits {
vec: self.0,
offset,
idx: 3,
}
}
}
impl BitOrAssign for BitVec {
fn bitor_assign(&mut self, rhs: Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(dst, src)| *dst |= *src);
}
}
struct Splits {
vec: [u64; 4],
offset: usize,
idx: usize,
}
impl Iterator for Splits {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
while self.idx > 0 && self.vec[self.idx] == 0 {
self.idx -= 1;
}
let cur = self.vec[self.idx];
if cur == 0 {
return None;
}
let trailing = cur.trailing_zeros();
let next = Some(self.offset + (3 - self.idx) * 64 + trailing as usize);
self.vec[self.idx] -= 1 << trailing;
next
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -326,7 +241,6 @@ impl std::fmt::Display for InvalidCharacter {
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>; type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
const DEFAULT_LIMIT: usize = 24; const DEFAULT_LIMIT: usize = 24;
const SEGMENT_SIZE: usize = 250;
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
@ -338,33 +252,4 @@ pub mod tests {
let text = Ascii::new("cantbuymelove").unwrap(); let text = Ascii::new("cantbuymelove").unwrap();
assert_eq!(&text[0..text.len()], "cantbuymelove"); assert_eq!(&text[0..text.len()], "cantbuymelove");
} }
#[test]
fn bitvec() {
let mut splits = BitVec::default();
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![]);
splits.set(1);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1]);
splits.set(5);
assert_eq!(splits.decode(10).collect::<Vec<_>>(), vec![11, 15]);
splits.set(64);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64]);
splits.set(255);
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64, 255]);
let mut new = BitVec::default();
new.set(3);
new.set(16);
new.set(128);
splits |= new;
assert_eq!(
splits.decode(0).collect::<Vec<_>>(),
vec![1, 3, 5, 16, 64, 128, 255]
);
}
} }