Use more efficient segmentation strategy
Based on the triangular matrix approach as explained here: https://towardsdatascience.com/fast-word-segmentation-for-noisy-text-2c2c41f9e8da Use iteration rather than recursion to segment the input forwards rather than backwards and use a `Vec`-based memoization strategy instead of relying on a `HashMap` of words. This version is about 4.8x faster, 100 lines of code less and should use much less memory.
This commit is contained in:
parent
541644a329
commit
85f4f94b53
|
@ -20,7 +20,7 @@ Corpus][corpus], as described by Thorsten Brants and Alex Franz, and
|
||||||
data **"may only be used for linguistic education and research"**, so for any
|
data **"may only be used for linguistic education and research"**, so for any
|
||||||
other usage you should acquire a different data set.
|
other usage you should acquire a different data set.
|
||||||
|
|
||||||
For the microbenchmark included in this repository, Instant Segment is ~17x
|
For the microbenchmark included in this repository, Instant Segment is ~100x
|
||||||
faster than the Python implementation. Further optimizations are planned -- see
|
faster than the Python implementation. Further optimizations are planned -- see
|
||||||
the [issues][issues]. The API has been carefully constructed so that multiple
|
the [issues][issues]. The API has been carefully constructed so that multiple
|
||||||
segmentations can share the underlying state to allow parallel usage.
|
segmentations can share the underlying state to allow parallel usage.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::ops::{BitOrAssign, Index, Range};
|
use std::ops::{Index, Range};
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
|
@ -112,96 +112,70 @@ struct SegmentState<'a> {
|
||||||
data: &'a Segmenter,
|
data: &'a Segmenter,
|
||||||
text: Ascii<'a>,
|
text: Ascii<'a>,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
offset: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> SegmentState<'a> {
|
impl<'a> SegmentState<'a> {
|
||||||
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
|
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
|
||||||
search.clear();
|
search.clear();
|
||||||
Self {
|
Self { data, text, search }
|
||||||
data,
|
|
||||||
text,
|
|
||||||
search,
|
|
||||||
offset: 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a list of words that is the best segmentation of `text`
|
fn run(self) {
|
||||||
fn run(mut self) {
|
for end in 1..=self.text.len() {
|
||||||
let (mut start, mut end) = (0, 0);
|
let start = end.saturating_sub(self.data.limit);
|
||||||
while end < self.text.len() {
|
for split in start..end {
|
||||||
end = self.text.len().min(end + SEGMENT_SIZE);
|
let (prev, prev_score) = match split {
|
||||||
self.offset = start;
|
0 => (None, 0.0),
|
||||||
self.search(0, start..end, None);
|
_ => {
|
||||||
|
let prefix = self.search.candidates[split - 1];
|
||||||
|
let word = &self.text[split - prefix.len as usize..split];
|
||||||
|
(Some(word), prefix.score)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let mut limit = usize::MAX;
|
let word = &self.text[split..end];
|
||||||
if end < self.text.len() {
|
let score = self.data.score(word, prev) + prev_score;
|
||||||
limit = 5;
|
match self.search.candidates.get_mut(end - 1) {
|
||||||
}
|
Some(cur) if cur.score < score => {
|
||||||
|
cur.len = end - split;
|
||||||
for split in self.search.best[0].decode(self.offset).take(limit) {
|
cur.score = score;
|
||||||
self.search.result.push(self.text[start..split].into());
|
}
|
||||||
start = split;
|
None => self.search.candidates.push(Candidate {
|
||||||
}
|
len: end - split,
|
||||||
}
|
score,
|
||||||
}
|
}),
|
||||||
|
_ => {}
|
||||||
/// Score `word` in the context of `previous` word
|
|
||||||
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
|
|
||||||
if range.is_empty() {
|
|
||||||
self.search.best[level].clear();
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut best = f64::MIN;
|
|
||||||
for split in 1..(range.len().min(self.data.limit) + 1) {
|
|
||||||
let (start, split, end) = (range.start, range.start + split, range.end);
|
|
||||||
let previous = previous.clone().map(|range| &self.text[range]);
|
|
||||||
let prefix_score = self.data.score(&self.text[start..split], previous);
|
|
||||||
|
|
||||||
let key = (
|
|
||||||
(start - self.offset) as u8,
|
|
||||||
(split - self.offset) as u8,
|
|
||||||
(end - self.offset) as u8,
|
|
||||||
);
|
|
||||||
|
|
||||||
let (suffix_score, suffix_splits) = match self.search.memo.get(&key) {
|
|
||||||
Some((score, suffix_splits)) => (*score, *suffix_splits),
|
|
||||||
None => {
|
|
||||||
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
|
||||||
let suffix_splits = self.search.best[level + 1];
|
|
||||||
self.search.memo.insert(key, (suffix_score, suffix_splits));
|
|
||||||
(suffix_score, suffix_splits)
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let score = prefix_score + suffix_score;
|
|
||||||
if score > best {
|
|
||||||
best = score;
|
|
||||||
let new_splits = &mut self.search.best[level];
|
|
||||||
new_splits.clear();
|
|
||||||
new_splits.set(split - self.offset);
|
|
||||||
*new_splits |= suffix_splits;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
best
|
let mut end = self.text.len();
|
||||||
|
let mut best = self.search.candidates[end - 1];
|
||||||
|
loop {
|
||||||
|
let word = &self.text[end - best.len as usize..end];
|
||||||
|
self.search.result.push(word.into());
|
||||||
|
|
||||||
|
end -= best.len as usize;
|
||||||
|
if end == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
best = self.search.candidates[end - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
self.search.result.reverse();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Search {
|
pub struct Search {
|
||||||
memo: HashMap<(u8, u8, u8), (f64, BitVec)>,
|
candidates: Vec<Candidate>,
|
||||||
best: Box<[BitVec; SEGMENT_SIZE]>,
|
|
||||||
result: Vec<String>,
|
result: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Search {
|
impl Search {
|
||||||
fn clear(&mut self) {
|
fn clear(&mut self) {
|
||||||
self.memo.clear();
|
self.candidates.clear();
|
||||||
for inner in self.best.iter_mut() {
|
|
||||||
inner.clear();
|
|
||||||
}
|
|
||||||
self.result.clear();
|
self.result.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,75 +188,16 @@ impl Search {
|
||||||
impl Default for Search {
|
impl Default for Search {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
memo: HashMap::default(),
|
candidates: Vec::new(),
|
||||||
best: Box::new([BitVec::default(); SEGMENT_SIZE]),
|
|
||||||
result: Vec::new(),
|
result: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Default)]
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
struct BitVec([u64; 4]);
|
struct Candidate {
|
||||||
|
len: usize,
|
||||||
impl BitVec {
|
score: f64,
|
||||||
fn set(&mut self, mut bit: usize) {
|
|
||||||
debug_assert!(bit < 256);
|
|
||||||
let mut idx = 3;
|
|
||||||
while bit > 63 {
|
|
||||||
idx -= 1;
|
|
||||||
bit -= 64;
|
|
||||||
}
|
|
||||||
self.0[idx] |= 1 << bit;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clear(&mut self) {
|
|
||||||
self.0.iter_mut().for_each(|dst| {
|
|
||||||
*dst = 0;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode(&self, offset: usize) -> Splits {
|
|
||||||
Splits {
|
|
||||||
vec: self.0,
|
|
||||||
offset,
|
|
||||||
idx: 3,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BitOrAssign for BitVec {
|
|
||||||
fn bitor_assign(&mut self, rhs: Self) {
|
|
||||||
self.0
|
|
||||||
.iter_mut()
|
|
||||||
.zip(rhs.0.iter())
|
|
||||||
.for_each(|(dst, src)| *dst |= *src);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Splits {
|
|
||||||
vec: [u64; 4],
|
|
||||||
offset: usize,
|
|
||||||
idx: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Iterator for Splits {
|
|
||||||
type Item = usize;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
while self.idx > 0 && self.vec[self.idx] == 0 {
|
|
||||||
self.idx -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let cur = self.vec[self.idx];
|
|
||||||
if cur == 0 {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let trailing = cur.trailing_zeros();
|
|
||||||
let next = Some(self.offset + (3 - self.idx) * 64 + trailing as usize);
|
|
||||||
self.vec[self.idx] -= 1 << trailing;
|
|
||||||
next
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -326,7 +241,6 @@ impl std::fmt::Display for InvalidCharacter {
|
||||||
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
|
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
|
||||||
|
|
||||||
const DEFAULT_LIMIT: usize = 24;
|
const DEFAULT_LIMIT: usize = 24;
|
||||||
const SEGMENT_SIZE: usize = 250;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod tests {
|
pub mod tests {
|
||||||
|
@ -338,33 +252,4 @@ pub mod tests {
|
||||||
let text = Ascii::new("cantbuymelove").unwrap();
|
let text = Ascii::new("cantbuymelove").unwrap();
|
||||||
assert_eq!(&text[0..text.len()], "cantbuymelove");
|
assert_eq!(&text[0..text.len()], "cantbuymelove");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn bitvec() {
|
|
||||||
let mut splits = BitVec::default();
|
|
||||||
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![]);
|
|
||||||
|
|
||||||
splits.set(1);
|
|
||||||
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1]);
|
|
||||||
|
|
||||||
splits.set(5);
|
|
||||||
assert_eq!(splits.decode(10).collect::<Vec<_>>(), vec![11, 15]);
|
|
||||||
|
|
||||||
splits.set(64);
|
|
||||||
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64]);
|
|
||||||
|
|
||||||
splits.set(255);
|
|
||||||
assert_eq!(splits.decode(0).collect::<Vec<_>>(), vec![1, 5, 64, 255]);
|
|
||||||
|
|
||||||
let mut new = BitVec::default();
|
|
||||||
new.set(3);
|
|
||||||
new.set(16);
|
|
||||||
new.set(128);
|
|
||||||
|
|
||||||
splits |= new;
|
|
||||||
assert_eq!(
|
|
||||||
splits.decode(0).collect::<Vec<_>>(),
|
|
||||||
vec![1, 3, 5, 16, 64, 128, 255]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue