Initial version

Dirkjan Ochtman 2020-05-26 19:59:05 +02:00
commit 38f9747c92
9 changed files with 798830 additions and 0 deletions

.github/workflows/rust.yml (new file, 62 lines)

@@ -0,0 +1,62 @@
name: CI
on:
push:
branches: ['master']
pull_request:
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
rust: [stable, beta]
exclude:
- os: macos-latest
rust: beta
- os: windows-latest
rust: beta
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
override: true
- uses: actions-rs/cargo@v1
with:
command: build
args: --workspace --all-targets
- uses: actions-rs/cargo@v1
with:
command: test
args: --workspace
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
components: rustfmt, clippy
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- uses: actions-rs/cargo@v1
if: always()
with:
command: clippy
args: --workspace --all-targets -- -D warnings
audit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: EmbarkStudios/cargo-deny-action@v0

.gitignore (new file, 2 lines)

@@ -0,0 +1,2 @@
/target
Cargo.lock

Cargo.toml (new file, 12 lines)

@@ -0,0 +1,12 @@
[package]
name = "word-segmenters"
version = "0.1.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
err-derive = "0.2.4"
[dev-dependencies]
once_cell = "1.4"

data/bigrams.txt (new file, 286358 lines)

File diff suppressed because it is too large.

data/unigrams.txt (new file, 333213 lines)

File diff suppressed because it is too large.

data/words.txt (new file, 178758 lines)

File diff suppressed because it is too large.

deny.toml (new file, 4 lines)

@@ -0,0 +1,4 @@
[licenses]
allow-osi-fsf-free = "either"
copyleft = "deny"
private = { ignore = true }

src/lib.rs (new file, 245 lines)

@@ -0,0 +1,245 @@
use std::{
collections::HashMap,
fs::File,
io::{self, BufRead, BufReader},
num::ParseIntError,
ops::Range,
path::Path,
str::FromStr,
};
use err_derive::Error;
pub struct Segmenter {
unigrams: HashMap<String, f64>,
bigrams: HashMap<(String, String), f64>,
total: f64,
limit: usize,
}
impl Segmenter {
/// Create `Segmenter` from files in the given directory
///
/// Reads from `unigrams.txt` and `bigrams.txt` in `dir`.
pub fn from_dir(dir: &Path) -> Result<Self, ParseError> {
let uni_file = dir.join("unigrams.txt");
let bi_file = dir.join("bigrams.txt");
Ok(Self {
unigrams: parse_unigrams(BufReader::new(File::open(&uni_file)?), uni_file.to_str())?,
bigrams: parse_bigrams(BufReader::new(File::open(&bi_file)?), bi_file.to_str())?,
limit: DEFAULT_LIMIT,
total: DEFAULT_TOTAL,
})
}
/// Returns a list of words that is the best segmentation of `text`
pub fn segment(&self, text: &str) -> Vec<String> {
let clean = clean(text);
let mut words = vec![];
let mut memo = HashMap::new();
let (mut start, mut end) = (0, 0);
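        // Slide a SEGMENT_SIZE-byte window over the cleaned text. All but the
        // last five words of each window are committed; the remainder is
        // re-searched in the next window so that boundaries near the window
        // edge can still be revised.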
loop {
end = clean.len().min(end + SEGMENT_SIZE);
let prefix = &clean[start..end];
            let window_words = self.search(prefix, "<s>", &mut memo).1;
for word in &window_words[..window_words.len().saturating_sub(5)] {
start += word.len();
words.push(word.into());
}
if end == clean.len() {
break;
}
}
let mut window_words = self.search(&clean[start..], "<s>", &mut memo).1;
words.append(&mut window_words);
words
}
    /// Find the highest-scoring segmentation of `text` given the `previous` word,
    /// returning the score and the segmented words
fn search(&self, text: &str, previous: &str, memo: &mut MemoMap) -> (f64, Vec<String>) {
if text.is_empty() {
return (0.0, vec![]);
}
let mut best = (f64::MIN, vec![]);
for (prefix, suffix) in TextDivider::new(text, self.limit) {
let prefix_score = self.score(prefix, Some(previous)).log10();
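            // Memoize the best result for each (remaining text, previous word)
            // pair so that overlapping subproblems are only searched once.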
let pair = (suffix.to_owned(), prefix.to_owned());
let (suffix_score, suffix_words) = match memo.get(&pair) {
Some((score, words)) => (*score, words.clone()),
None => {
                    let (suffix_score, suffix_words) = self.search(suffix, prefix, memo);
memo.insert(pair, (suffix_score, suffix_words.clone()));
(suffix_score, suffix_words)
}
};
let score = prefix_score + suffix_score;
if score > best.0 {
best.0 = score;
best.1.clear();
best.1.push(prefix.to_owned());
best.1.extend(suffix_words);
}
}
best
}
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
match previous {
None => match self.unigrams.get(word) {
                // Probability of the given word
Some(p) => p / self.total,
// Penalize words not found in the unigrams according
// to their length, a crucial heuristic.
None => 10.0 / (self.total * 10.0f64.powf(word.len() as f64)),
},
Some(prev) => match (
self.bigrams.get(&(prev.into(), word.into())),
self.unigrams.get(prev),
) {
// Conditional probability of the word given the previous
// word. The technical name is "stupid backoff" and it's
// not a probability distribution but it works well in practice.
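                // Effectively count(prev, word) / count(prev): the two
                // divisions by `total` cancel out.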
(Some(pb), Some(_)) => pb / self.total / self.score(prev, None),
// Fall back to using the unigram probability
_ => self.score(word, None),
},
}
}
    /// Customize the word length `limit`
pub fn set_limit(&mut self, limit: usize) {
self.limit = limit;
}
/// Customize the relative score by setting the `total`
pub fn set_total(&mut self, total: f64) {
self.total = total;
}
}
/// Parse unigrams from the `reader` (format: `<word>\t<int>\n`)
///
/// The optional `name` argument may be used to provide a source name for error messages.
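///
/// For example, the line `foo\t42` parses to the entry `("foo", 42.0)`.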
pub fn parse_unigrams<R: BufRead>(
reader: R,
name: Option<&str>,
) -> Result<HashMap<String, f64>, ParseError> {
let name = name.unwrap_or("(unnamed)");
reader
.lines()
.enumerate()
.map(|(i, ln)| {
let ln = ln?;
let split = ln
.find('\t')
.ok_or_else(|| ParseError::String(format!("no tab found in {:?}:{}", name, i)))?;
let word = ln[..split].to_owned();
let p = usize::from_str(&ln[split + 1..])
.map_err(|e| ParseError::String(format!("error at {:?}:{}: {}", name, i, e)))?;
Ok((word, p as f64))
})
.collect()
}
/// Parse bigrams from the `reader` (format: `<word-1> <word-2>\t<int>\n`)
///
/// The optional `name` argument may be used to provide a source name for error messages.
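///
/// For example, the line `foo bar\t7` parses to the entry `(("foo", "bar"), 7.0)`.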
pub fn parse_bigrams<R: BufRead>(
reader: R,
name: Option<&str>,
) -> Result<HashMap<(String, String), f64>, ParseError> {
let name = name.unwrap_or("(unnamed)");
reader
.lines()
.enumerate()
.map(|(i, ln)| {
let ln = ln?;
let word_split = ln
.find(' ')
.ok_or_else(|| ParseError::String(format!("no space found in {:?}:{}", name, i)))?;
let score_split = ln[word_split + 1..]
.find('\t')
.ok_or_else(|| ParseError::String(format!("no tab found in {:?}:{}", name, i)))?
+ word_split
+ 1;
let word1 = ln[..word_split].to_owned();
let word2 = ln[word_split + 1..score_split].to_owned();
let p = usize::from_str(&ln[score_split + 1..])
.map_err(|e| ParseError::String(format!("error at {:?}:{}: {}", name, i, e)))?;
Ok(((word1, word2), p as f64))
})
.collect()
}
/// Iterator that yields `(prefix, suffix)` pairs from `text`
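/// at every split position up to `limit`; e.g. `"abc"` with a limit of 2
/// yields `("a", "bc")` and then `("ab", "c")`.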
struct TextDivider<'a> {
text: &'a str,
split: Range<usize>,
}
impl<'a> TextDivider<'a> {
fn new(text: &'a str, limit: usize) -> Self {
TextDivider {
text,
split: 1..(text.len().min(limit) + 1),
}
}
}
impl<'a> Iterator for TextDivider<'a> {
type Item = (&'a str, &'a str);
fn next(&mut self) -> Option<Self::Item> {
self.split
.next()
.map(|split| (&self.text[..split], &self.text[split..]))
}
}
/// Return `s` lower-cased, keeping only ASCII alphanumeric characters
fn clean(s: &str) -> String {
s.chars()
.filter_map(|c| {
if c.is_ascii_alphanumeric() {
Some(c.to_ascii_lowercase())
} else {
None
}
})
.collect()
}
#[derive(Debug, Error)]
pub enum ParseError {
#[error(display = "I/O error: {}", _0)]
Io(#[source] io::Error),
#[error(display = "integer parsing error: {}", _0)]
ParseInt(#[source] ParseIntError),
#[error(display = "{}", _0)]
String(String),
}
type MemoMap = HashMap<(String, String), (f64, Vec<String>)>;
const DEFAULT_LIMIT: usize = 24;
const DEFAULT_TOTAL: f64 = 1_024_908_267_229.0;
const SEGMENT_SIZE: usize = 250;
#[cfg(test)]
mod tests {
#[test]
fn test_clean() {
assert_eq!(super::clean("Can't buy me love!"), "cantbuymelove");
}
}
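
Taken together, the public API is small: construct a `Segmenter` from the data
files and call `segment()`. A minimal usage sketch, assuming the crate is built
with the `data/` directory from this commit next to `Cargo.toml` (mirroring
what the tests below do):

use std::path::Path;

use word_segmenters::Segmenter;

fn main() -> Result<(), word_segmenters::ParseError> {
    // Load unigrams.txt and bigrams.txt from the bundled data directory.
    let segmenter = Segmenter::from_dir(Path::new("data"))?;
    // Input is lower-cased and stripped of non-alphanumerics before the search.
    let words = segmenter.segment("thisisatest");
    assert_eq!(words, vec!["this", "is", "a", "test"]);
    Ok(())
}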

tests/basic.rs (new file, 176 lines)

@@ -0,0 +1,176 @@
use std::path::PathBuf;
use once_cell::sync::Lazy;
use word_segmenters::Segmenter;
macro_rules! assert_segments {
($list:expr) => {
assert_eq!(SEGMENTER.segment(&$list.join("")), $list);
};
}
#[test]
fn test_segment_0() {
assert_segments!(&["choose", "spain"]);
}
#[test]
fn test_segment_1() {
assert_segments!(&["this", "is", "a", "test"]);
}
#[test]
fn test_segment_2() {
assert_segments!(&[
"when",
"in",
"the",
"course",
"of",
"human",
"events",
"it",
"becomes",
"necessary",
]);
}
#[test]
fn test_segment_3() {
assert_segments!(&["who", "represents"]);
}
#[test]
fn test_segment_4() {
assert_segments!(&["experts", "exchange"]);
}
#[test]
fn test_segment_5() {
assert_segments!(&["speed", "of", "art"]);
}
#[test]
fn test_segment_6() {
assert_segments!(&["now", "is", "the", "time", "for", "all", "good"]);
}
#[test]
fn test_segment_7() {
assert_segments!(&["it", "is", "a", "truth", "universally", "acknowledged"]);
}
#[test]
fn test_segment_8() {
assert_segments!(&[
"it", "was", "a", "bright", "cold", "day", "in", "april", "and", "the", "clocks", "were",
"striking", "thirteen",
]);
}
#[test]
fn test_segment_9() {
assert_segments!(&[
"it",
"was",
"the",
"best",
"of",
"times",
"it",
"was",
"the",
"worst",
"of",
"times",
"it",
"was",
"the",
"age",
"of",
"wisdom",
"it",
"was",
"the",
"age",
"of",
"foolishness",
]);
}
#[test]
fn test_segment_10() {
assert_segments!(&[
"as",
"gregor",
"samsa",
"awoke",
"one",
"morning",
"from",
"uneasy",
"dreams",
"he",
"found",
"himself",
"transformed",
"in",
"his",
"bed",
"into",
"a",
"gigantic",
"insect",
]);
}
#[test]
fn test_segment_11() {
assert_segments!(vec![
"in", "a", "hole", "in", "the", "ground", "there", "lived", "a", "hobbit", "not", "a",
"nasty", "dirty", "wet", "hole", "filled", "with", "the", "ends", "of", "worms", "and",
"an", "oozy", "smell", "nor", "yet", "a", "dry", "bare", "sandy", "hole", "with",
"nothing", "in", "it", "to", "sit", "down", "on", "or", "to", "eat", "it", "was", "a",
"hobbit", "hole", "and", "that", "means", "comfort"
]);
}
#[test]
fn test_segment_12() {
assert_segments!(&[
"far",
"out",
"in",
"the",
"uncharted",
"backwaters",
"of",
"the",
"unfashionable",
"end",
"of",
"the",
"western",
"spiral",
"arm",
"of",
"the",
"galaxy",
"lies",
"a",
"small",
"un",
"regarded",
"yellow",
"sun",
]);
}
static SEGMENTER: Lazy<Segmenter> = Lazy::new(|| {
Segmenter::from_dir(&PathBuf::from(format!(
"{}/data",
env!("CARGO_MANIFEST_DIR")
)))
.unwrap()
});