mirror of
https://github.com/instant-labs/instant-segment.git
synced 2025-01-19 07:19:07 +00:00
Take an explicit search parameter
This commit is contained in:
parent
be0f8c0ed7
commit
13b29d183e
@ -8,5 +8,6 @@ benchmark_main!(benches);
|
|||||||
fn short(bench: &mut Bencher) {
|
fn short(bench: &mut Bencher) {
|
||||||
let segmenter = instant_segment::test_data::segmenter();
|
let segmenter = instant_segment::test_data::segmenter();
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::new();
|
||||||
bench.iter(|| segmenter.segment("thisisatest", &mut out));
|
let mut search = instant_segment::Search::default();
|
||||||
|
bench.iter(|| segmenter.segment("thisisatest", &mut out, &mut search));
|
||||||
}
|
}
|
||||||
|
78
src/lib.rs
78
src/lib.rs
@ -54,9 +54,15 @@ impl Segmenter {
|
|||||||
/// Appends list of words that is the best segmentation of `text` to `out`
|
/// Appends list of words that is the best segmentation of `text` to `out`
|
||||||
///
|
///
|
||||||
/// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
|
/// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
|
||||||
/// returns `Err(InvalidCharacter)`.
|
/// returns `Err(InvalidCharacter)`. The `search` parameter contains caches that are used
|
||||||
pub fn segment(&self, text: &str, out: &mut Vec<String>) -> Result<(), InvalidCharacter> {
|
/// during segmentation; passing it in allows callers to reuse the cache allocations.
|
||||||
Ok(SegmentState::new(Ascii::new(text)?, &self, out).run())
|
pub fn segment(
|
||||||
|
&self,
|
||||||
|
text: &str,
|
||||||
|
out: &mut Vec<String>,
|
||||||
|
search: &mut Search,
|
||||||
|
) -> Result<(), InvalidCharacter> {
|
||||||
|
Ok(SegmentState::new(Ascii::new(text)?, &self, out, search).run())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
||||||
@ -94,21 +100,23 @@ impl Segmenter {
|
|||||||
struct SegmentState<'a> {
|
struct SegmentState<'a> {
|
||||||
data: &'a Segmenter,
|
data: &'a Segmenter,
|
||||||
text: Ascii<'a>,
|
text: Ascii<'a>,
|
||||||
memo: HashMap<MemoKey, (f64, Range<usize>)>,
|
|
||||||
split_cache: Vec<usize>,
|
|
||||||
result: &'a mut Vec<String>,
|
result: &'a mut Vec<String>,
|
||||||
best: Vec<Vec<usize>>,
|
search: &'a mut Search,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> SegmentState<'a> {
|
impl<'a> SegmentState<'a> {
|
||||||
fn new(text: Ascii<'a>, data: &'a Segmenter, result: &'a mut Vec<String>) -> Self {
|
fn new(
|
||||||
|
text: Ascii<'a>,
|
||||||
|
data: &'a Segmenter,
|
||||||
|
result: &'a mut Vec<String>,
|
||||||
|
search: &'a mut Search,
|
||||||
|
) -> Self {
|
||||||
|
search.clear();
|
||||||
Self {
|
Self {
|
||||||
data,
|
data,
|
||||||
text,
|
text,
|
||||||
memo: HashMap::default(),
|
|
||||||
split_cache: Vec::new(),
|
|
||||||
result,
|
result,
|
||||||
best: vec![vec![]; SEGMENT_SIZE],
|
search,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,7 +127,7 @@ impl<'a> SegmentState<'a> {
|
|||||||
end = self.text.len().min(end + SEGMENT_SIZE);
|
end = self.text.len().min(end + SEGMENT_SIZE);
|
||||||
self.search(0, start..end, None);
|
self.search(0, start..end, None);
|
||||||
|
|
||||||
let mut splits = &self.best[0][..];
|
let mut splits = &self.search.best[0][..];
|
||||||
if end < self.text.len() {
|
if end < self.text.len() {
|
||||||
splits = &splits[..splits.len().saturating_sub(5)];
|
splits = &splits[..splits.len().saturating_sub(5)];
|
||||||
}
|
}
|
||||||
@ -134,7 +142,7 @@ impl<'a> SegmentState<'a> {
|
|||||||
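/// Find the best-scoring segmentation of `range` given the `previous` word, storing its splits for `level`
|
/// Find the best-scoring segmentation of `range` given the `previous` word, storing its splits for `level`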
|
||||||
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
|
fn search(&mut self, level: usize, range: Range<usize>, previous: Option<Range<usize>>) -> f64 {
|
||||||
if range.is_empty() {
|
if range.is_empty() {
|
||||||
self.best[level].clear();
|
self.search.best[level].clear();
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,24 +153,28 @@ impl<'a> SegmentState<'a> {
|
|||||||
let prefix_score = self.data.score(&self.text[start..split], previous).log10();
|
let prefix_score = self.data.score(&self.text[start..split], previous).log10();
|
||||||
|
|
||||||
let pair = (split..end, start..split);
|
let pair = (split..end, start..split);
|
||||||
let (suffix_score, suffix_splits) = match self.memo.get(&pair) {
|
let (suffix_score, suffix_splits) = match self.search.memo.get(&pair) {
|
||||||
Some((score, splits)) => (*score, &self.split_cache[splits.start..splits.end]),
|
Some((score, splits)) => {
|
||||||
|
(*score, &self.search.split_cache[splits.start..splits.end])
|
||||||
|
}
|
||||||
None => {
|
None => {
|
||||||
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
let suffix_score = self.search(level + 1, split..end, Some(start..split));
|
||||||
|
|
||||||
let start = self.split_cache.len();
|
let start = self.search.split_cache.len();
|
||||||
self.split_cache.extend(&self.best[level + 1][..]);
|
self.search
|
||||||
let end = self.split_cache.len();
|
.split_cache
|
||||||
self.memo.insert(pair, (suffix_score, start..end));
|
.extend(&self.search.best[level + 1][..]);
|
||||||
|
let end = self.search.split_cache.len();
|
||||||
|
self.search.memo.insert(pair, (suffix_score, start..end));
|
||||||
|
|
||||||
(suffix_score, &self.split_cache[start..end])
|
(suffix_score, &self.search.split_cache[start..end])
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let score = prefix_score + suffix_score;
|
let score = prefix_score + suffix_score;
|
||||||
if score > best {
|
if score > best {
|
||||||
best = score;
|
best = score;
|
||||||
let splits = &mut self.best[level];
|
let splits = &mut self.search.best[level];
|
||||||
splits.clear();
|
splits.clear();
|
||||||
splits.push(split);
|
splits.push(split);
|
||||||
splits.extend(suffix_splits);
|
splits.extend(suffix_splits);
|
||||||
@ -173,6 +185,32 @@ impl<'a> SegmentState<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Search {
|
||||||
|
memo: HashMap<MemoKey, (f64, Range<usize>)>,
|
||||||
|
split_cache: Vec<usize>,
|
||||||
|
best: Vec<Vec<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Search {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
memo: HashMap::default(),
|
||||||
|
split_cache: Vec::with_capacity(32),
|
||||||
|
best: vec![vec![]; SEGMENT_SIZE],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Search {
|
||||||
|
fn clear(&mut self) {
|
||||||
|
self.memo.clear();
|
||||||
|
self.split_cache.clear();
|
||||||
|
for inner in self.best.iter_mut() {
|
||||||
|
inner.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type MemoKey = (Range<usize>, Range<usize>);
|
type MemoKey = (Range<usize>, Range<usize>);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
use crate::Segmenter;
|
use crate::{Search, Segmenter};
|
||||||
|
|
||||||
/// Run a segmenter against the built-in test cases
|
/// Run a segmenter against the built-in test cases
|
||||||
pub fn run(segmenter: &Segmenter) {
|
pub fn run(segmenter: &Segmenter) {
|
||||||
|
let mut search = Search::default();
|
||||||
for test in TEST_CASES.iter().copied() {
|
for test in TEST_CASES.iter().copied() {
|
||||||
assert_segments(segmenter, test);
|
assert_segments(test, &mut search, segmenter);
|
||||||
}
|
}
|
||||||
assert_segments(segmenter, FAIL);
|
assert_segments(FAIL, &mut search, segmenter);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assert_segments(segmenter: &Segmenter, s: &[&str]) {
|
pub fn assert_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) {
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::new();
|
||||||
segmenter.segment(&s.join(""), &mut out).unwrap();
|
segmenter.segment(&s.join(""), &mut out, search).unwrap();
|
||||||
let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>();
|
let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>();
|
||||||
assert_eq!(cmp, s);
|
assert_eq!(cmp, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_segments(segmenter: &Segmenter, s: &[&str]) -> bool {
|
pub fn check_segments(s: &[&str], search: &mut Search, segmenter: &Segmenter) -> bool {
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::new();
|
||||||
match segmenter.segment(&s.join(""), &mut out) {
|
match segmenter.segment(&s.join(""), &mut out, search) {
|
||||||
Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(),
|
Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(),
|
||||||
Err(_) => false,
|
Err(_) => false,
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user