Don't normalize input strings implicitly

This commit is contained in:
Dirkjan Ochtman 2021-02-08 15:53:24 +01:00
parent 8c08bb9e14
commit be0f8c0ed7
2 changed files with 35 additions and 21 deletions

View File

@ -52,8 +52,11 @@ impl Segmenter {
}
/// Appends list of words that is the best segmentation of `text` to `out`
pub fn segment(&self, text: &str, out: &mut Vec<String>) {
SegmentState::new(&Ascii::new(text), &self, out).run()
///
/// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
/// returns `Err(InvalidCharacter)`.
pub fn segment(&self, text: &str, out: &mut Vec<String>) -> Result<(), InvalidCharacter> {
Ok(SegmentState::new(Ascii::new(text)?, &self, out).run())
}
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
@ -90,7 +93,7 @@ impl Segmenter {
struct SegmentState<'a> {
data: &'a Segmenter,
text: &'a Ascii,
text: Ascii<'a>,
memo: HashMap<MemoKey, (f64, Range<usize>)>,
split_cache: Vec<usize>,
result: &'a mut Vec<String>,
@ -98,7 +101,7 @@ struct SegmentState<'a> {
}
impl<'a> SegmentState<'a> {
fn new(text: &'a Ascii, data: &'a Segmenter, result: &'a mut Vec<String>) -> Self {
fn new(text: Ascii<'a>, data: &'a Segmenter, result: &'a mut Vec<String>) -> Self {
Self {
data,
text,
@ -172,19 +175,16 @@ impl<'a> SegmentState<'a> {
type MemoKey = (Range<usize>, Range<usize>);
struct Ascii(Vec<u8>);
#[derive(Debug)]
struct Ascii<'a>(&'a [u8]);
impl Ascii {
fn new(s: &str) -> Self {
Self(
s.chars()
.filter_map(|c| match c.is_ascii_alphanumeric() {
true => Some(c.to_ascii_lowercase()),
false => None,
})
.collect::<std::string::String>()
.into_bytes(),
)
impl<'a> Ascii<'a> {
fn new(s: &'a str) -> Result<Self, InvalidCharacter> {
let bytes = s.as_bytes();
match bytes.iter().all(|b| b.is_ascii_lowercase()) {
true => Ok(Self(bytes)),
false => Err(InvalidCharacter),
}
}
fn len(&self) -> usize {
@ -192,7 +192,7 @@ impl Ascii {
}
}
impl Index<Range<usize>> for Ascii {
impl<'a> Index<Range<usize>> for Ascii<'a> {
type Output = str;
fn index(&self, index: Range<usize>) -> &Self::Output {
@ -202,6 +202,17 @@ impl Index<Range<usize>> for Ascii {
}
}
#[derive(Debug)]
pub struct InvalidCharacter;
impl std::error::Error for InvalidCharacter {}
impl std::fmt::Display for InvalidCharacter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("invalid character")
}
}
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
const DEFAULT_LIMIT: usize = 24;
@ -211,7 +222,8 @@ const SEGMENT_SIZE: usize = 250;
pub mod tests {
#[test]
fn test_clean() {
let text = super::Ascii::new("Can't buy me love!");
super::Ascii::new("Can't buy me love!").unwrap_err();
let text = super::Ascii::new("cantbuymelove").unwrap();
assert_eq!(&text[0..text.len()], "cantbuymelove");
}
}

View File

@ -10,15 +10,17 @@ pub fn run(segmenter: &Segmenter) {
pub fn assert_segments(segmenter: &Segmenter, s: &[&str]) {
let mut out = Vec::new();
segmenter.segment(&s.join(""), &mut out);
segmenter.segment(&s.join(""), &mut out).unwrap();
let cmp = out.iter().map(|s| &*s).collect::<Vec<_>>();
assert_eq!(cmp, s);
}
pub fn check_segments(segmenter: &Segmenter, s: &[&str]) -> bool {
let mut out = Vec::new();
segmenter.segment(&s.join(""), &mut out);
s == out.iter().map(|s| &*s).collect::<Vec<_>>()
match segmenter.segment(&s.join(""), &mut out) {
Ok(()) => s == out.iter().map(|s| &*s).collect::<Vec<_>>(),
Err(_) => false,
}
}
/// Built-in test cases