diff --git a/src/lib.rs b/src/lib.rs index 06a6134..9d1ce04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ use std::error::Error; use std::io; use std::num::ParseIntError; -use std::ops::Range; +use std::ops::{Index, Range}; +use std::str; use ahash::AHashMap as HashMap; use smartstring::alias::String; @@ -37,8 +38,7 @@ impl Segmenter { /// Appends list of words that is the best segmentation of `text` to `out` pub fn segment(&self, text: &str, out: &mut Vec) { - let clean = clean(text); - SegmentState::new(&clean, &self, out).run() + SegmentState::new(&Ascii::new(text), &self, out).run() } fn score(&self, word: &str, previous: Option<&str>) -> f64 { @@ -75,15 +75,15 @@ impl Segmenter { struct SegmentState<'a> { data: &'a Segmenter, - text: &'a str, - memo: HashMap<(&'a str, &'a str), (f64, Range)>, + text: &'a Ascii, + memo: HashMap<(Range, Range), (f64, Range)>, split_cache: Vec, result: &'a mut Vec, best: Vec>, } impl<'a> SegmentState<'a> { - fn new(text: &'a str, data: &'a Segmenter, result: &'a mut Vec) -> Self { + fn new(text: &'a Ascii, data: &'a Segmenter, result: &'a mut Vec) -> Self { Self { data, text, @@ -123,13 +123,11 @@ impl<'a> SegmentState<'a> { let mut best = f64::MIN; for split in 1..(range.len().min(self.data.limit) + 1) { - let (start, end) = (range.start, range.end); - let (prefix, suffix) = self.text[start..end].split_at(split); - let split = start + split; - + let (start, split, end) = (range.start, range.start + split, range.end); + let prefix = &self.text[start..split]; let prefix_score = self.data.score(prefix, previous).log10(); - let pair = (suffix, prefix); + let pair = (split..end, start..split); let (suffix_score, suffix_splits) = match self.memo.get(&pair) { Some((score, splits)) => (*score, &self.split_cache[splits.start..splits.end]), None => { @@ -161,17 +159,34 @@ impl<'a> SegmentState<'a> { } } -/// Return `text` lower-cased with non-alphanumeric characters removed -fn clean(s: &str) -> String { - s.chars() - .filter_map(|c| { - if c.is_ascii_alphanumeric() { - Some(c.to_ascii_lowercase()) - } else { - None - } - }) - .collect() +struct Ascii(Vec); + +impl Ascii { + fn new(s: &str) -> Self { + Self( + s.chars() + .filter_map(|c| match c.is_ascii_alphanumeric() { + true => Some(c.to_ascii_lowercase()), + false => None, + }) + .collect::() + .into_bytes(), + ) + } + + fn len(&self) -> usize { + self.0.len() + } +} + +impl Index> for Ascii { + type Output = str; + + fn index(&self, index: Range) -> &Self::Output { + let bytes = self.0.index(index); + // Since `Ascii` can only be instantiated with ASCII characters, this should be safe + unsafe { str::from_utf8_unchecked(bytes) } + } } #[derive(Debug, Error)] @@ -198,6 +213,7 @@ const SEGMENT_SIZE: usize = 250; pub mod tests { #[test] fn test_clean() { - assert_eq!(&super::clean("Can't buy me love!"), "cantbuymelove"); + let text = super::Ascii::new("Can't buy me love!"); + assert_eq!(&text[0..text.len()], "cantbuymelove"); } }