diff --git a/src/lib.rs b/src/lib.rs index df22809..b509465 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,8 +52,11 @@ impl Segmenter { } /// Appends list of words that is the best segmentation of `text` to `out` - pub fn segment(&self, text: &str, out: &mut Vec) { - SegmentState::new(&Ascii::new(text), &self, out).run() + /// + /// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise, + /// returns `Err(InvalidCharacter)`. + pub fn segment(&self, text: &str, out: &mut Vec) -> Result<(), InvalidCharacter> { + Ok(SegmentState::new(Ascii::new(text)?, &self, out).run()) } fn score(&self, word: &str, previous: Option<&str>) -> f64 { @@ -90,7 +93,7 @@ impl Segmenter { struct SegmentState<'a> { data: &'a Segmenter, - text: &'a Ascii, + text: Ascii<'a>, memo: HashMap)>, split_cache: Vec, result: &'a mut Vec, @@ -98,7 +101,7 @@ struct SegmentState<'a> { } impl<'a> SegmentState<'a> { - fn new(text: &'a Ascii, data: &'a Segmenter, result: &'a mut Vec) -> Self { + fn new(text: Ascii<'a>, data: &'a Segmenter, result: &'a mut Vec) -> Self { Self { data, text, @@ -172,19 +175,16 @@ impl<'a> SegmentState<'a> { type MemoKey = (Range, Range); -struct Ascii(Vec); +#[derive(Debug)] +struct Ascii<'a>(&'a [u8]); -impl Ascii { - fn new(s: &str) -> Self { - Self( - s.chars() - .filter_map(|c| match c.is_ascii_alphanumeric() { - true => Some(c.to_ascii_lowercase()), - false => None, - }) - .collect::() - .into_bytes(), - ) +impl<'a> Ascii<'a> { + fn new(s: &'a str) -> Result { + let bytes = s.as_bytes(); + match bytes.iter().all(|b| b.is_ascii_lowercase()) { + true => Ok(Self(bytes)), + false => Err(InvalidCharacter), + } } fn len(&self) -> usize { @@ -192,7 +192,7 @@ impl Ascii { } } -impl Index> for Ascii { +impl<'a> Index> for Ascii<'a> { type Output = str; fn index(&self, index: Range) -> &Self::Output { @@ -202,6 +202,17 @@ impl Index> for Ascii { } } +#[derive(Debug)] +pub struct InvalidCharacter; + +impl std::error::Error for InvalidCharacter {} + +impl std::fmt::Display for InvalidCharacter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("invalid character") + } +} + type HashMap = std::collections::HashMap; const DEFAULT_LIMIT: usize = 24; @@ -211,7 +222,8 @@ const SEGMENT_SIZE: usize = 250; pub mod tests { #[test] fn test_clean() { - let text = super::Ascii::new("Can't buy me love!"); + super::Ascii::new("Can't buy me love!").unwrap_err(); + let text = super::Ascii::new("cantbuymelove").unwrap(); assert_eq!(&text[0..text.len()], "cantbuymelove"); } } diff --git a/src/test_cases.rs b/src/test_cases.rs index 0f53c99..8d86eb3 100644 --- a/src/test_cases.rs +++ b/src/test_cases.rs @@ -10,15 +10,17 @@ pub fn run(segmenter: &Segmenter) { pub fn assert_segments(segmenter: &Segmenter, s: &[&str]) { let mut out = Vec::new(); - segmenter.segment(&s.join(""), &mut out); + segmenter.segment(&s.join(""), &mut out).unwrap(); let cmp = out.iter().map(|s| &*s).collect::>(); assert_eq!(cmp, s); } pub fn check_segments(segmenter: &Segmenter, s: &[&str]) -> bool { let mut out = Vec::new(); - segmenter.segment(&s.join(""), &mut out); - s == out.iter().map(|s| &*s).collect::>() + match segmenter.segment(&s.join(""), &mut out) { + Ok(()) => s == out.iter().map(|s| &*s).collect::>(), + Err(_) => false, + } } /// Built-in test cases