mirror of
https://github.com/instant-labs/instant-segment.git
synced 2025-01-19 07:19:07 +00:00
Take an explicit search parameter
This commit is contained in:
parent
be0f8c0ed7
commit
13b29d183e
@ -8,5 +8,6 @@ benchmark_main!(benches);
|
||||
fn short(bench: &mut Bencher) {
|
||||
let segmenter = instant_segment::test_data::segmenter();
|
||||
let mut out = Vec::new();
|
||||
bench.iter(|| segmenter.segment("thisisatest", &mut out));
|
||||
let mut search = instant_segment::Search::default();
|
||||
bench.iter(|| segmenter.segment("thisisatest", &mut out, &mut search));
|
||||
}
|
||||
|
78
src/lib.rs
78
src/lib.rs
@ -54,9 +54,15 @@ impl Segmenter {
|
||||
/// Appends list of words that is the best segmentation of `text` to `out`
|
||||
///
|
||||
/// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
|
||||
/// returns `Err(InvalidCharacter)`.
|
||||
pub fn segment(&self, text: &str, out: &mut Vec<String>) -> Result<(), InvalidCharacter> {
|
||||
Ok(SegmentState::new(Ascii::new(text)?, &self, out).run())
|
||||
/// returns `Err(InvalidCharacter)`. The `search` parameter contains caches that are used
|
||||
/// segmentation; passing it in allows the callers to reuse the cache allocations.
|
||||
pub fn segment(
|
||||
&self,
|
||||
text: &str,
|
||||
out: &mut Vec<String>,
|
||||
search: &mut Search,
|
||||
) -> Result<(), InvalidCharacter> {
|
||||
Ok(SegmentState::new(Ascii::new(text)?, &self, out, search).run())
|
||||
}
|
||||
|
||||
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
||||
@ -94,21 +100,23 @@ impl Segmenter {
|
||||
struct SegmentState<'a> {
|
||||
data: &'a Segmenter,
|
||||
text: Ascii<'a>,
|
||||
memo: HashMap<MemoKey, (f64, Range<usize>)>,
|
||||
split_cache: Vec<usize>,
|
||||
result: &'a mut Vec<String>,
|
||||
best: Vec<Vec<usize>>,
|
||||
search: &'a mut Search,
|
||||
}
|
||||
|
||||
impl<'a> SegmentState<'a> {
|
||||
fn new(text: Ascii<'a>, data: &'a Segmenter, result: &'a mut Vec<String>) -> Self {
|
||||
fn new(
|
||||
text: Ascii<'a>,
|
||||
data: &'a Segmenter,
|
||||
result: &'a mut Vec<String>,
|
||||
search: &'a mut Search,
|
||||
) -> Self {
|
||||
search.clear();
|
||||
Self {
|
||||
data,
|
||||
text,
|
||||
memo: HashMap::default(),
|
||||
split_cache: Vec::new(),
|
||||
result,
|
||||
best: vec![vec![]; SEGMENT_SIZE],
|
||||
search,
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,7 +127,7 @@ impl<'a> SegmentState<'a> {
|
||||
end = self.text.len().min(end + SEGMENT_SIZE);
|
||||
self.search(0, start..end, None);
|
||||
|
||||
let mut splits = &self.best[0][..];
|
||||
let mut splits = &self.search.best[0][..];
|
||||
if end < self.text.len() {
|
||||
splits = &splits[..splits.len().saturating_sub(5)];
|
||||
}
|
||||
@ -134,7 +142,7 @@ impl<'a> SegmentState<'a> {
|
||||
/// Score `word` in the context of `previous` word
|
||||
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
|
||||
if range.is_empty() {
|
||||
self.best[level].clear();
|
||||
self.search.best[level].clear();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
@ -145,24 +153,28 @@ impl<'a> SegmentState<'a> {
|
||||
let prefix_score = self.data.score(&self.text[start..split], previous).log10();
|
||||
|
||||
let pair = (split..end, start..split);
|
||||
let (suffix_score, suffix_splits) = match self.memo.get(&pair) {
|
||||
Some((score, splits)) => (*score, &self.split_cache[splits.start..splits.end]),
|
||||
let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
|
||||
Some((score, splits)) => {
|
||||
(*score, &self.search.split_cache[splits.start..splits.end])
|
||||
}
|
||||
None => {
|
||||
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
||||
|
||||
let start = self.split_cache.len();
|
||||
self.split_cache.extend(&self.best[level + 1][..]);
|
||||
let end = self.split_cache.len();
|
||||
self.memo.insert(pair, (suffix_score, start..end));
|
||||
let start = self.search.split_cache.len();
|
||||
self.search
|
||||
.split_cache
|
||||
.extend(&self.search.best[level + 1][..]);
|
||||
let end = self.search.split_cache.len();
|
||||
self.search.memo.insert(pair, (suffix_score, start..end));
|
||||
|
||||
(suffix_score, &self.split_cache[start..end])
|
||||
(suffix_score, &self.search.split_cache[start..end])
|
||||
}
|
||||
};
|
||||
|
||||
let score = prefix_score + suffix_score;
|
||||
if score > best {
|
||||
best = score;
|
||||
let splits = &mut self.best[level];
|
||||
let splits = &mut self.search.best[level];
|
||||
splits.clear();
|
||||
splits.push(split);
|
||||
splits.extend(suffix_splits);
|
||||
@ -173,6 +185,32 @@ impl<'a> SegmentState<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Search {
|
||||
memo: HashMap<MemoKey, (f64, Range<usize>)>,
|
||||
split_cache: Vec<usize>,
|
||||
best: Vec<Vec<usize>>,
|
||||
}
|
||||
|
||||
impl Default for Search {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
memo: HashMap::default(),
|
||||
split_cache: Vec::with_capacity(32),
|
||||
best: vec![vec![]; SEGMENT_SIZE],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Search {
|
||||
fn clear(&mut self) {
|
||||
self.memo.clear();
|
||||
self.split_cache.clear();
|
||||
for inner in self.best.iter_mut() {
|
||||
inner.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MemoKey = (Range<usize>, Range<usize>);
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -1,23 +1,24 @@
|
||||
use crate::Segmenter;
|
||||
use crate::{Search, Segmenter};
|
||||
|
||||
/// Run a segmenter against the built-in test cases
|
||||
pub fn run(segmenter: &Segmenter) {
|
||||
let mut search = Search::default();
|
||||
for test in TEST_CASES.iter().copied() {
|
||||
assert_segments(segmenter, test);
|
||||
assert_segments(test, &mut search, segmenter);
|
||||
}
|
||||
assert_segments(segmenter, FAIL);
|
||||
assert_segments(FAIL, &mut search, segmenter);
|
||||
}
|
||||
|
||||
pub fn assert_segments(segmenter: &Segmenter, s: &[&str]) {
|
||||
pub fn assert_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) {
|
||||
let mut out = Vec::new();
|
||||
segmenter.segment(&s.join(""), &mut out).unwrap();
|
||||
segmenter.segment(&s.join(""), &mut out, search).unwrap();
|
||||
let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>();
|
||||
assert_eq!(cmp, s);
|
||||
}
|
||||
|
||||
pub fn check_segments(segmenter: &Segmenter, s: &[&str]) -> bool {
|
||||
pub fn check_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) -> bool {
|
||||
let mut out = Vec::new();
|
||||
match segmenter.segment(&s.join(""), &mut out) {
|
||||
match segmenter.segment(&s.join(""), &mut out, search) {
|
||||
Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(),
|
||||
Err(_) => false,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user