Separate state from Segmenter

This commit is contained in:
Dirkjan Ochtman 2020-05-28 19:56:13 +02:00
parent 98a8368be6
commit 76bdcf1ca5
1 changed file with 78 additions and 61 deletions

View File

@ -35,65 +35,7 @@ impl Segmenter {
/// Returns a list of words that is the best segmentation of `text` /// Returns a list of words that is the best segmentation of `text`
pub fn segment(&self, text: &str) -> Vec<String> { pub fn segment(&self, text: &str) -> Vec<String> {
let clean = clean(text); let clean = clean(text);
// NOTE(review): this span appears to be a side-by-side diff rendering, so lines present
// on both sides show up twice. On the next line the left half starts the pre-change body
// (windowed loop below); the right half is the post-change body, a single call into
// `SegmentState`.
let mut words = vec![]; SegmentState::new(&clean, &self).run()
let mut memo = HashMap::new();
let (mut start, mut end) = (0, 0);
loop {
// Grow the window end by `SEGMENT_SIZE` bytes, capped at the length of the cleaned text.
end = clean.len().min(end + SEGMENT_SIZE);
let prefix = &clean[start..end];
let window_words = self.search(&prefix, "<s>", &mut memo).1;
// Commit all but the last 5 words of this window. `start` only advances past committed
// words, so the held-back tail is searched again as part of the next window.
for word in &window_words[..window_words.len().saturating_sub(5)] {
start += word.len();
words.push((*word).into());
}
if end == clean.len() {
break;
}
}
// Segment whatever remains after the last committed word.
let window_words = self.search(&clean[start..], "<s>", &mut memo).1;
words.extend(window_words.into_iter().map(|s| s.to_owned()));
words
}
/// Returns `(score, words)` for the highest-scoring split of `text` into words, where
/// `previous` is the word that precedes `text`; sub-results are cached in `memo`
fn search<'a, 'b: 'a>(
&self,
text: &'b str,
previous: &str,
memo: &'a mut MemoMap<'b>,
) -> (f64, Vec<&'b str>) {
if text.is_empty() {
return (0.0, vec![]);
}
let mut best = (f64::MIN, vec![]);
for (prefix, suffix) in TextDivider::new(text, self.limit) {
let prefix_score = self.score(prefix, Some(previous)).log10();
// Memo key: the remaining text together with the word that precedes it.
let pair = (suffix, prefix);
let (suffix_score, suffix_words) = match memo.get(&pair) {
Some((score, words)) => (*score, words.as_slice()),
None => {
let (suffix_score, suffix_words) = self.search(&suffix, prefix, memo);
let value = memo.entry(pair).or_insert((suffix_score, suffix_words));
(suffix_score, value.1.as_slice())
}
};
// Scores are base-10 logarithms, so the score of a split is the sum of its parts.
let score = prefix_score + suffix_score;
if score > best.0 {
best.0 = score;
best.1.clear();
best.1.push(prefix);
best.1.extend(suffix_words);
}
}
best
} }
fn score(&self, word: &str, previous: Option<&str>) -> f64 { fn score(&self, word: &str, previous: Option<&str>) -> f64 {
@ -130,6 +72,83 @@ impl Segmenter {
} }
} }
/// Working state for one segmentation run over a single piece of text
///
/// Built by `SegmentState::new` and consumed by `SegmentState::run`.
struct SegmentState<'a> {
/// The `Segmenter` whose `score` method and `limit` field drive the search
data: &'a Segmenter,
/// The text being segmented
text: &'a str,
/// Cache of `search` results: maps `(text, previous)` to the `(score, words)` found for them
memo: HashMap<(&'a str, &'a str), (f64, Vec<&'a str>)>,
/// Words committed so far; this is what `run` returns
result: Vec<String>,
}
impl<'a> SegmentState<'a> {
    /// Number of trailing words of every window that are not committed right away
    ///
    /// These words are searched again together with the following window.
    const HELD_BACK_WORDS: usize = 5;

    /// Creates the state for segmenting `text` with `data`, starting with an empty memo
    /// and an empty result list
    fn new(text: &'a str, data: &'a Segmenter) -> Self {
        Self {
            data,
            text,
            memo: HashMap::new(),
            result: Vec::new(),
        }
    }

    /// Segments the text window by window and returns the resulting list of words
    ///
    /// The window end grows by `SEGMENT_SIZE` bytes per iteration. For each window all but
    /// the last `HELD_BACK_WORDS` words are committed to the result; the window start only
    /// moves past committed words, so the held-back tail is searched again with the next
    /// window. Whatever is left after the final window is segmented in one last search.
    pub fn run(mut self) -> Vec<String> {
        let (mut start, mut end) = (0, 0);
        loop {
            end = self.text.len().min(end + SEGMENT_SIZE);
            // NOTE(review): slicing by byte offset assumes the window boundary falls on a
            // `char` boundary (e.g. ASCII-only text) — confirm the caller guarantees this.
            let prefix = &self.text[start..end];
            // NOTE(review): "<s>" is presumably a start-of-text marker understood by
            // `Segmenter::score`; every window restarts with it as the previous word.
            let window_words = self.search(prefix, "<s>").1;
            let committed = window_words.len().saturating_sub(Self::HELD_BACK_WORDS);
            for &word in &window_words[..committed] {
                start += word.len();
                self.result.push(word.to_owned());
            }
            if end == self.text.len() {
                break;
            }
        }

        // Segment whatever remains after the last committed word.
        let window_words = self.search(&self.text[start..], "<s>").1;
        self.result
            .extend(window_words.into_iter().map(str::to_owned));
        self.result
    }

    /// Returns `(score, words)` for the highest-scoring split of `text` into words, where
    /// `previous` is the word that precedes `text`
    ///
    /// The score of a split is the sum of the base-10 logarithms of the per-word scores
    /// reported by `Segmenter::score`. Empty `text` yields `(0.0, vec![])`. Results for
    /// each `(text, previous)` pair are cached in `self.memo`.
    fn search(&mut self, text: &'a str, previous: &str) -> (f64, Vec<&'a str>) {
        if text.is_empty() {
            return (0.0, vec![]);
        }

        let mut best = (f64::MIN, vec![]);
        for (prefix, suffix) in TextDivider::new(text, self.data.limit) {
            let prefix_score = self.data.score(prefix, Some(previous)).log10();
            // Memo key: the remaining text together with the word that precedes it.
            let pair = (suffix, prefix);
            let (suffix_score, suffix_words) = match self.memo.get(&pair) {
                Some((score, words)) => (*score, words.as_slice()),
                None => {
                    let (suffix_score, suffix_words) = self.search(suffix, prefix);
                    // Store the fresh result and borrow the words back from the memo so
                    // both match arms hand out a slice.
                    let value = self
                        .memo
                        .entry(pair)
                        .or_insert((suffix_score, suffix_words));
                    (suffix_score, value.1.as_slice())
                }
            };

            let score = prefix_score + suffix_score;
            if score > best.0 {
                best.0 = score;
                best.1.clear();
                best.1.push(prefix);
                best.1.extend(suffix_words);
            }
        }

        best
    }
}
/// Parse unigrams from the `reader` (format: `<word>\t<int>\n`) /// Parse unigrams from the `reader` (format: `<word>\t<int>\n`)
/// ///
/// The optional `name` argument may be used to provide a source name for error messages. /// The optional `name` argument may be used to provide a source name for error messages.
@ -235,8 +254,6 @@ pub enum ParseError {
String(String), String(String),
} }
type MemoMap<'a> = HashMap<(&'a str, &'a str), (f64, Vec<&'a str>)>;
const DEFAULT_LIMIT: usize = 24; const DEFAULT_LIMIT: usize = 24;
const DEFAULT_TOTAL: f64 = 1_024_908_267_229.0; const DEFAULT_TOTAL: f64 = 1_024_908_267_229.0;
const SEGMENT_SIZE: usize = 250; const SEGMENT_SIZE: usize = 250;