From 540348f703b4a2a6e758e3224a8f984575df9832 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 24 Nov 2020 10:38:19 +0100 Subject: [PATCH] Abstract over test data format code and API --- .github/workflows/rust.yml | 2 +- Cargo.toml | 3 ++ benches/bench.rs | 11 +---- src/lib.rs | 91 ++++++++------------------------------ src/test_data.rs | 49 ++++++++++++++++++++ tests/basic.rs | 10 +---- 6 files changed, 75 insertions(+), 91 deletions(-) create mode 100644 src/test_data.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 065bf1b..e57f177 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -33,7 +33,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test - args: --workspace + args: --workspace --all-features lint: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index aba62af..ae2d9f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,9 @@ homepage = "https://github.com/InstantDomainSearch/word-segmenters" repository = "https://github.com/InstantDomainSearch/word-segmenters" documentation = "https://docs.rs/word-segmenters" +[features] +__test_data = [] + [dependencies] ahash = "0.6.1" smartstring = "0.2.5" diff --git a/benches/bench.rs b/benches/bench.rs index 33e8533..15b894d 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,19 +1,12 @@ -use std::path::PathBuf; +#![cfg(feature = "__test_data")] use bencher::{benchmark_group, benchmark_main, Bencher}; -use word_segmenters::Segmenter; - benchmark_group!(benches, short); benchmark_main!(benches); fn short(bench: &mut Bencher) { - let segmenter = Segmenter::from_dir(&PathBuf::from(format!( - "{}/data", - env!("CARGO_MANIFEST_DIR") - ))) - .unwrap(); - + let segmenter = word_segmenters::test_data::segmenter(); let mut out = Vec::new(); bench.iter(|| segmenter.segment("thisisatest", &mut out)); } diff --git a/src/lib.rs b/src/lib.rs index 7dc877d..c187967 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,15 @@ -use std::{ - fs::File, - io::{self, BufRead, BufReader}, - num::ParseIntError, - ops::Range, - path::Path, - str::FromStr, -}; +use std::error::Error; +use std::io; +use std::num::ParseIntError; +use std::ops::Range; use ahash::AHashMap as HashMap; use smartstring::alias::String; use thiserror::Error; +#[cfg(feature = "__test_data")] +pub mod test_data; + pub struct Segmenter { unigrams: HashMap, bigrams: HashMap<(String, String), f64>, @@ -19,15 +18,18 @@ pub struct Segmenter { } impl Segmenter { - /// Create `Segmenter` from files in the given directory + /// Create `Segmenter` from the given iterators /// - /// Reads from `unigrams.txt` and `bigrams.txt` in `dir`. - pub fn from_dir(dir: &Path) -> Result { - let uni_file = dir.join("unigrams.txt"); - let bi_file = dir.join("bigrams.txt"); + /// Note: the `String` types used in this API are defined in the `smartstring` crate. Any + /// `&str` or `String` can be converted into the `String` used here by calling `into()` on it. + pub fn from_iters<'a, U, B>(unigrams: U, bigrams: B) -> Result> + where + U: Iterator>>, + B: Iterator>>, + { Ok(Self { - unigrams: parse_unigrams(BufReader::new(File::open(&uni_file)?), uni_file.to_str())?, - bigrams: parse_bigrams(BufReader::new(File::open(&bi_file)?), bi_file.to_str())?, + unigrams: unigrams.collect::, _>>()?, + bigrams: bigrams.collect::, _>>()?, limit: DEFAULT_LIMIT, total: DEFAULT_TOTAL, }) @@ -149,63 +151,6 @@ impl<'a> SegmentState<'a> { } } -/// Parse unigrams from the `reader` (format: `\t\n`) -/// -/// The optional `name` argument may be used to provide a source name for error messages. -pub fn parse_unigrams( - reader: R, - name: Option<&str>, -) -> Result, ParseError> { - let name = name.unwrap_or("(unnamed)"); - reader - .lines() - .enumerate() - .map(|(i, ln)| { - let ln = ln?; - let split = ln - .find('\t') - .ok_or_else(|| format!("no tab found in {:?}:{}", name, i))?; - - let word = ln[..split].into(); - let p = usize::from_str(&ln[split + 1..]) - .map_err(|e| format!("error at {:?}:{}: {}", name, i, e))?; - Ok((word, p as f64)) - }) - .collect() -} - -/// Parse bigrams from the `reader` (format: ` \t\n`) -/// -/// The optional `name` argument may be used to provide a source name for error messages. -pub fn parse_bigrams( - reader: R, - name: Option<&str>, -) -> Result, ParseError> { - let name = name.unwrap_or("(unnamed)"); - reader - .lines() - .enumerate() - .map(|(i, ln)| { - let ln = ln?; - let word_split = ln - .find(' ') - .ok_or_else(|| format!("no space found in {:?}:{}", name, i))?; - let score_split = ln[word_split + 1..] - .find('\t') - .ok_or_else(|| format!("no tab found in {:?}:{}", name, i))? - + word_split - + 1; - - let word1 = ln[..word_split].into(); - let word2 = ln[word_split + 1..score_split].into(); - let p = usize::from_str(&ln[score_split + 1..]) - .map_err(|e| format!("error at {:?}:{}: {}", name, i, e))?; - - Ok(((word1, word2), p as f64)) - }) - .collect() -} - /// Iterator that yields `(prefix, suffix)` pairs from `text` struct TextDivider<'a> { text: &'a str, @@ -265,7 +210,7 @@ const DEFAULT_TOTAL: f64 = 1_024_908_267_229.0; const SEGMENT_SIZE: usize = 250; #[cfg(test)] -mod tests { +pub mod tests { #[test] fn test_clean() { assert_eq!(&super::clean("Can't buy me love!"), "cantbuymelove"); diff --git a/src/test_data.rs b/src/test_data.rs new file mode 100644 index 0000000..39f37c3 --- /dev/null +++ b/src/test_data.rs @@ -0,0 +1,49 @@ +#![cfg(feature = "__test_data")] + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; +use std::str::FromStr; + +use super::Segmenter; + +pub fn segmenter() -> Segmenter { + let dir = PathBuf::from(format!("{}/data", env!("CARGO_MANIFEST_DIR"))); + + let uni_file = dir.join("unigrams.txt"); + let reader = BufReader::new(File::open(&uni_file).unwrap()); + let unigrams = reader.lines().enumerate().map(move |(i, ln)| { + let ln = ln?; + let split = ln + .find('\t') + .ok_or_else(|| format!("no tab found in {:?}:{}", uni_file, i))?; + + let word = ln[..split].into(); + let p = usize::from_str(&ln[split + 1..]) + .map_err(|e| format!("error at {:?}:{}: {}", uni_file, i, e))?; + Ok((word, p as f64)) + }); + + let bi_file = dir.join("bigrams.txt"); + let reader = BufReader::new(File::open(&bi_file).unwrap()); + let bigrams = reader.lines().enumerate().map(move |(i, ln)| { + let ln = ln?; + let word_split = ln + .find(' ') + .ok_or_else(|| format!("no space found in {:?}:{}", bi_file, i))?; + let score_split = ln[word_split + 1..] + .find('\t') + .ok_or_else(|| format!("no tab found in {:?}:{}", bi_file, i))? + + word_split + + 1; + + let word1 = ln[..word_split].into(); + let word2 = ln[word_split + 1..score_split].into(); + let p = usize::from_str(&ln[score_split + 1..]) + .map_err(|e| format!("error at {:?}:{}: {}", bi_file, i, e))?; + + Ok(((word1, word2), p as f64)) + }); + + Segmenter::from_iters(unigrams, bigrams).unwrap() +} diff --git a/tests/basic.rs b/tests/basic.rs index cbb8bf2..160c616 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +#![cfg(feature = "__test_data")] use once_cell::sync::Lazy; @@ -170,10 +170,4 @@ fn test_segment_12() { ]); } -static SEGMENTER: Lazy = Lazy::new(|| { - Segmenter::from_dir(&PathBuf::from(format!( - "{}/data", - env!("CARGO_MANIFEST_DIR") - ))) - .unwrap() -}); +static SEGMENTER: Lazy = Lazy::new(|| word_segmenters::test_data::segmenter());