From 3b3627422b0e67f703f16d9a8a8662d7f83ea2dd Mon Sep 17 00:00:00 2001 From: Michael Partheil Date: Tue, 3 Oct 2023 20:12:20 +0200 Subject: [PATCH] Use nested `HashMap` for storing both unigram and bigram scores --- README.md | 2 +- instant-segment-py/Cargo.toml | 1 - instant-segment-py/src/lib.rs | 8 +-- instant-segment/examples/contrived.rs | 6 +- instant-segment/src/lib.rs | 83 +++++++++++++++------------ instant-segment/src/test_data.rs | 2 +- 6 files changed, 55 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index cadfd6b..9c89f6f 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ print([word for word in search]) use instant_segment::{Search, Segmenter}; use std::collections::HashMap; -let segmenter = Segmenter::from_maps(unigrams, bigrams); +let segmenter = Segmenter::new(unigrams, bigrams); let mut search = Search::default(); let words = segmenter .segment("instantdomainsearch", &mut search) diff --git a/instant-segment-py/Cargo.toml b/instant-segment-py/Cargo.toml index 8de362e..36064ff 100644 --- a/instant-segment-py/Cargo.toml +++ b/instant-segment-py/Cargo.toml @@ -16,7 +16,6 @@ name = "instant_segment" crate-type = ["cdylib"] [dependencies] -ahash = "0.8" bincode = "1.3.2" instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] } pyo3 = { version = "0.20", features = ["extension-module"] } diff --git a/instant-segment-py/src/lib.rs b/instant-segment-py/src/lib.rs index d29633b..e517736 100644 --- a/instant-segment-py/src/lib.rs +++ b/instant-segment-py/src/lib.rs @@ -39,7 +39,7 @@ impl Segmenter { let val = item.get_item(1)?.extract::()?; Ok((SmartString::from(key), val)) }) - .collect::, PyErr>>()?; + .collect::, PyErr>>()?; let bigrams = bigrams .map(|item| { @@ -52,10 +52,10 @@ impl Segmenter { let val = item.get_item(1)?.extract::()?; Ok(((SmartString::from(first), SmartString::from(second)), val)) }) - .collect::, PyErr>>()?; + .collect::, PyErr>>()?; Ok(Self { - inner: 
instant_segment::Segmenter::from_maps(unigrams, bigrams), + inner: instant_segment::Segmenter::new(unigrams, bigrams), }) } @@ -148,5 +148,3 @@ impl Search { Some(word) } } - -type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>; diff --git a/instant-segment/examples/contrived.rs b/instant-segment/examples/contrived.rs index c3cb2ab..e5850ad 100644 --- a/instant-segment/examples/contrived.rs +++ b/instant-segment/examples/contrived.rs @@ -2,7 +2,7 @@ use instant_segment::{Search, Segmenter}; use std::collections::HashMap; fn main() { - let mut unigrams = HashMap::default(); + let mut unigrams = HashMap::new(); unigrams.insert("choose".into(), 80_000.0); unigrams.insert("chooses".into(), 7_000.0); @@ -10,12 +10,12 @@ fn main() { unigrams.insert("spain".into(), 20_000.0); unigrams.insert("pain".into(), 90_000.0); - let mut bigrams = HashMap::default(); + let mut bigrams = HashMap::new(); bigrams.insert(("choose".into(), "spain".into()), 7.0); bigrams.insert(("chooses".into(), "pain".into()), 0.0); - let segmenter = Segmenter::from_maps(unigrams, bigrams); + let segmenter = Segmenter::new(unigrams, bigrams); let mut search = Search::default(); let words = segmenter.segment("choosespain", &mut search).unwrap(); diff --git a/instant-segment/src/lib.rs b/instant-segment/src/lib.rs index 010b3a5..d5052d3 100644 --- a/instant-segment/src/lib.rs +++ b/instant-segment/src/lib.rs @@ -13,47 +13,57 @@ pub mod test_data; /// Central data structure used to calculate word probabilities #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))] pub struct Segmenter { - unigrams: HashMap<String, f64>, - bigrams: HashMap<(String, String), f64>, + // Maps a word to both its unigram score, as well as a nested HashMap in + // which the bigram score can be looked up using the previous word.
Scores + // are base-10 logarithms of relative word frequencies + scores: HashMap<String, (f64, HashMap<String, f64>)>, + // Base-10 logarithm of the total count of unigrams uni_total_log10: f64, limit: usize, } impl Segmenter { - /// Create `Segmenter` from the given iterators + /// Create `Segmenter` from the given unigram and bigram counts. /// /// Note: the `String` types used in this API are defined in the `smartstring` crate. Any /// `&str` or `String` can be converted into the `String` used here by calling `into()` on it. - pub fn from_iters<U, B>(unigrams: U, bigrams: B) -> Self + pub fn new<U, B>(unigrams: U, bigrams: B) -> Self where - U: Iterator<Item = (String, f64)>, - B: Iterator<Item = ((String, String), f64)>, + U: IntoIterator<Item = (String, f64)>, + B: IntoIterator<Item = ((String, String), f64)>, { - Self::from_maps(unigrams.collect(), bigrams.collect()) - } + // Initially, `scores` contains the original unigram and bigram counts + let mut scores = HashMap::default(); + let mut uni_total = 0.0; + for (word, uni) in unigrams { + scores.insert(word, (uni, HashMap::default())); + uni_total += uni; + } + let mut bi_total = 0.0; + for ((word1, word2), bi) in bigrams { + let Some((_, bi_scores)) = scores.get_mut(&word2) else { + // We throw away bigrams for which we do not have a unigram for + // the second word. This case shouldn't ever happen on + // real-world data, and in fact, it never happens on the word + // count lists shipped with this crate. + continue; + }; + bi_scores.insert(word1, bi); + bi_total += bi; + } - /// Create `Segmenter` from the given hashmaps (using ahash) - /// - /// Note: the `String` types used in this API are defined in the `smartstring` crate. Any - /// `&str` or `String` can be converted into the `String` used here by calling `into()` on it. - /// The `HashMap` type here refers to `std::collections::HashMap` parametrized with the - /// `ahash::RandomState`.
- pub fn from_maps( - mut unigrams: HashMap<String, f64>, - mut bigrams: HashMap<(String, String), f64>, - ) -> Self { - let uni_total = unigrams.values().sum::<f64>(); - let bi_total = bigrams.values().sum::<f64>(); - for uni in unigrams.values_mut() { + // Now convert the counts in `scores` to the values we actually want, + // namely logarithms of relative frequencies + for (uni, bi_scores) in scores.values_mut() { *uni = (*uni / uni_total).log10(); + for bi in bi_scores.values_mut() { + *bi = (*bi / bi_total).log10(); + } } - for bi in bigrams.values_mut() { - *bi = (*bi / bi_total).log10(); - } + Self { uni_total_log10: uni_total.log10(), - unigrams, - bigrams, + scores, limit: DEFAULT_LIMIT, } } @@ -91,24 +101,25 @@ impl Segmenter { } fn score(&self, word: &str, previous: Option<&str>) -> f64 { + let (uni, bi_scores) = match self.scores.get(word) { + Some((uni, bi_scores)) => (uni, bi_scores), + // Penalize words not found in the unigrams according + // to their length, a crucial heuristic. + None => return 1.0 - self.uni_total_log10 - word.len() as f64, + }; + if let Some(prev) = previous { - if let Some(bi) = self.bigrams.get(&(prev.into(), word.into())) { - if let Some(uni) = self.unigrams.get(prev) { + if let Some(bi) = bi_scores.get(prev) { + if let Some((uni_prev, _)) = self.scores.get(prev) { // Conditional probability of the word given the previous // word. The technical name is "stupid backoff" and it's // not a probability distribution but it works well in practice. - return bi - uni; + return bi - uni_prev; } } } - match self.unigrams.get(word) { - // Probability of the given word - Some(uni) => *uni, - // Penalize words not found in the unigrams according - // to their length, a crucial heuristic.
- None => 1.0 - self.uni_total_log10 - word.len() as f64, - } + *uni } /// Customize the word length `limit` diff --git a/instant-segment/src/test_data.rs b/instant-segment/src/test_data.rs index 36fd5dd..74ff57b 100644 --- a/instant-segment/src/test_data.rs +++ b/instant-segment/src/test_data.rs @@ -56,7 +56,7 @@ pub fn segmenter(dir: PathBuf) -> Segmenter { ln.clear(); } - Segmenter::from_maps(unigrams, bigrams) + Segmenter::new(unigrams, bigrams) } pub fn crate_data_dir() -> PathBuf {