Take an explicit search parameter

This commit is contained in:
Dirkjan Ochtman 2021-02-10 11:48:24 +01:00
parent be0f8c0ed7
commit 13b29d183e
3 changed files with 68 additions and 28 deletions

View File

@ -8,5 +8,6 @@ benchmark_main!(benches);
fn short(bench: &mut Bencher) { fn short(bench: &mut Bencher) {
let segmenter = instant_segment::test_data::segmenter(); let segmenter = instant_segment::test_data::segmenter();
let mut out = Vec::new(); let mut out = Vec::new();
bench.iter(|| segmenter.segment("thisisatest", &mut out)); let mut search = instant_segment::Search::default();
bench.iter(|| segmenter.segment("thisisatest", &mut out, &mut search));
} }

View File

@ -54,9 +54,15 @@ impl Segmenter {
/// Appends list of words that is the best segmentation of `text` to `out` /// Appends list of words that is the best segmentation of `text` to `out`
/// ///
/// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise, /// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
/// returns `Err(InvalidCharacter)`. /// returns `Err(InvalidCharacter)`. The `search` parameter contains caches that are used
pub fn segment(&self, text: &str, out: &mut Vec<String>) -> Result<(), InvalidCharacter> { /// segmentation; passing it in allows the callers to reuse the cache allocations.
Ok(SegmentState::new(Ascii::new(text)?, &self, out).run()) pub fn segment(
&self,
text: &str,
out: &mut Vec<String>,
search: &mut Search,
) -> Result<(), InvalidCharacter> {
Ok(SegmentState::new(Ascii::new(text)?, &self, out, search).run())
} }
fn score(&self, word: &str, previous: Option<&str>) -> f64 { fn score(&self, word: &str, previous: Option<&str>) -> f64 {
@ -94,21 +100,23 @@ impl Segmenter {
struct SegmentState<'a> { struct SegmentState<'a> {
data: &'a Segmenter, data: &'a Segmenter,
text: Ascii<'a>, text: Ascii<'a>,
memo: HashMap<MemoKey, (f64, Range<usize>)>,
split_cache: Vec<usize>,
result: &'a mut Vec<String>, result: &'a mut Vec<String>,
best: Vec<Vec<usize>>, search: &'a mut Search,
} }
impl<'a> SegmentState<'a> { impl<'a> SegmentState<'a> {
fn new(text: Ascii<'a>, data: &'a Segmenter, result: &'a mut Vec<String>) -> Self { fn new(
text: Ascii<'a>,
data: &'a Segmenter,
result: &'a mut Vec<String>,
search: &'a mut Search,
) -> Self {
search.clear();
Self { Self {
data, data,
text, text,
memo: HashMap::default(),
split_cache: Vec::new(),
result, result,
best: vec![vec![]; SEGMENT_SIZE], search,
} }
} }
@ -119,7 +127,7 @@ impl<'a> SegmentState<'a> {
end = self.text.len().min(end + SEGMENT_SIZE); end = self.text.len().min(end + SEGMENT_SIZE);
self.search(0, start..end, None); self.search(0, start..end, None);
let mut splits = &self.best[0][..]; let mut splits = &self.search.best[0][..];
if end < self.text.len() { if end < self.text.len() {
splits = &splits[..splits.len().saturating_sub(5)]; splits = &splits[..splits.len().saturating_sub(5)];
} }
@ -134,7 +142,7 @@ impl<'a> SegmentState<'a> {
/// Score `word` in the context of `previous` word /// Score `word` in the context of `previous` word
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 { fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
if range.is_empty() { if range.is_empty() {
self.best[level].clear(); self.search.best[level].clear();
return 0.0; return 0.0;
} }
@ -145,24 +153,28 @@ impl<'a> SegmentState<'a> {
let prefix_score = self.data.score(&self.text[start..split], previous).log10(); let prefix_score = self.data.score(&self.text[start..split], previous).log10();
let pair = (split..end, start..split); let pair = (split..end, start..split);
let (suffix_score, suffix_splits) = match self.memo.get(&pair) { let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
Some((score, splits)) => (*score, &self.split_cache[splits.start..splits.end]), Some((score, splits)) => {
(*score, &self.search.split_cache[splits.start..splits.end])
}
None => { None => {
let suffix_score = self.search(level + 1, split..end, Some(start..split)); let suffix_score = self.search(level + 1, split..end, Some(start..split));
let start = self.split_cache.len(); let start = self.search.split_cache.len();
self.split_cache.extend(&self.best[level + 1][..]); self.search
let end = self.split_cache.len(); .split_cache
self.memo.insert(pair, (suffix_score, start..end)); .extend(&self.search.best[level + 1][..]);
let end = self.search.split_cache.len();
self.search.memo.insert(pair, (suffix_score, start..end));
(suffix_score, &self.split_cache[start..end]) (suffix_score, &self.search.split_cache[start..end])
} }
}; };
let score = prefix_score + suffix_score; let score = prefix_score + suffix_score;
if score > best { if score > best {
best = score; best = score;
let splits = &mut self.best[level]; let splits = &mut self.search.best[level];
splits.clear(); splits.clear();
splits.push(split); splits.push(split);
splits.extend(suffix_splits); splits.extend(suffix_splits);
@ -173,6 +185,32 @@ impl<'a> SegmentState<'a> {
} }
} }
pub struct Search {
memo: HashMap<MemoKey, (f64, Range<usize>)>,
split_cache: Vec<usize>,
best: Vec<Vec<usize>>,
}
impl Default for Search {
fn default() -> Self {
Self {
memo: HashMap::default(),
split_cache: Vec::with_capacity(32),
best: vec![vec![]; SEGMENT_SIZE],
}
}
}
impl Search {
fn clear(&mut self) {
self.memo.clear();
self.split_cache.clear();
for inner in self.best.iter_mut() {
inner.clear();
}
}
}
type MemoKey = (Range<usize>, Range<usize>); type MemoKey = (Range<usize>, Range<usize>);
#[derive(Debug)] #[derive(Debug)]

View File

@ -1,23 +1,24 @@
use crate::Segmenter; use crate::{Search, Segmenter};
/// Run a segmenter against the built-in test cases /// Run a segmenter against the built-in test cases
pub fn run(segmenter: &Segmenter) { pub fn run(segmenter: &Segmenter) {
let mut search = Search::default();
for test in TEST_CASES.iter().copied() { for test in TEST_CASES.iter().copied() {
assert_segments(segmenter, test); assert_segments(test, &mut search, segmenter);
} }
assert_segments(segmenter, FAIL); assert_segments(FAIL, &mut search, segmenter);
} }
pub fn assert_segments(segmenter: &Segmenter, s: &[&str]) { pub fn assert_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) {
let mut out = Vec::new(); let mut out = Vec::new();
segmenter.segment(&s.join(""), &mut out).unwrap(); segmenter.segment(&s.join(""), &mut out, search).unwrap();
let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>(); let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>();
assert_eq!(cmp, s); assert_eq!(cmp, s);
} }
pub fn check_segments(segmenter: &Segmenter, s: &[&str]) -> bool { pub fn check_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) -> bool {
let mut out = Vec::new(); let mut out = Vec::new();
match segmenter.segment(&s.join(""), &mut out) { match segmenter.segment(&s.join(""), &mut out, search) {
Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(), Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(),
Err(_) => false, Err(_) => false,
} }