Use nested `HashMap` for storing both unigram and bigram scores

This commit is contained in:
Michael Partheil 2023-10-03 20:12:20 +02:00 committed by Dirkjan Ochtman
parent 5810373471
commit 3b3627422b
6 changed files with 55 additions and 47 deletions

View File

@ -70,7 +70,7 @@ print([word for word in search])
use instant_segment::{Search, Segmenter}; use instant_segment::{Search, Segmenter};
use std::collections::HashMap; use std::collections::HashMap;
let segmenter = Segmenter::from_maps(unigrams, bigrams); let segmenter = Segmenter::new(unigrams, bigrams);
let mut search = Search::default(); let mut search = Search::default();
let words = segmenter let words = segmenter
.segment("instantdomainsearch", &mut search) .segment("instantdomainsearch", &mut search)

View File

@ -16,7 +16,6 @@ name = "instant_segment"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
ahash = "0.8"
bincode = "1.3.2" bincode = "1.3.2"
instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] } instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] }
pyo3 = { version = "0.20", features = ["extension-module"] } pyo3 = { version = "0.20", features = ["extension-module"] }

View File

@ -39,7 +39,7 @@ impl Segmenter {
let val = item.get_item(1)?.extract::<f64>()?; let val = item.get_item(1)?.extract::<f64>()?;
Ok((SmartString::from(key), val)) Ok((SmartString::from(key), val))
}) })
.collect::<Result<HashMap<_, _>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let bigrams = bigrams let bigrams = bigrams
.map(|item| { .map(|item| {
@ -52,10 +52,10 @@ impl Segmenter {
let val = item.get_item(1)?.extract::<f64>()?; let val = item.get_item(1)?.extract::<f64>()?;
Ok(((SmartString::from(first), SmartString::from(second)), val)) Ok(((SmartString::from(first), SmartString::from(second)), val))
}) })
.collect::<Result<HashMap<_, _>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
Ok(Self { Ok(Self {
inner: instant_segment::Segmenter::from_maps(unigrams, bigrams), inner: instant_segment::Segmenter::new(unigrams, bigrams),
}) })
} }
@ -148,5 +148,3 @@ impl Search {
Some(word) Some(word)
} }
} }
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;

View File

@ -2,7 +2,7 @@ use instant_segment::{Search, Segmenter};
use std::collections::HashMap; use std::collections::HashMap;
fn main() { fn main() {
let mut unigrams = HashMap::default(); let mut unigrams = HashMap::new();
unigrams.insert("choose".into(), 80_000.0); unigrams.insert("choose".into(), 80_000.0);
unigrams.insert("chooses".into(), 7_000.0); unigrams.insert("chooses".into(), 7_000.0);
@ -10,12 +10,12 @@ fn main() {
unigrams.insert("spain".into(), 20_000.0); unigrams.insert("spain".into(), 20_000.0);
unigrams.insert("pain".into(), 90_000.0); unigrams.insert("pain".into(), 90_000.0);
let mut bigrams = HashMap::default(); let mut bigrams = HashMap::new();
bigrams.insert(("choose".into(), "spain".into()), 7.0); bigrams.insert(("choose".into(), "spain".into()), 7.0);
bigrams.insert(("chooses".into(), "pain".into()), 0.0); bigrams.insert(("chooses".into(), "pain".into()), 0.0);
let segmenter = Segmenter::from_maps(unigrams, bigrams); let segmenter = Segmenter::new(unigrams, bigrams);
let mut search = Search::default(); let mut search = Search::default();
let words = segmenter.segment("choosespain", &mut search).unwrap(); let words = segmenter.segment("choosespain", &mut search).unwrap();

View File

@ -13,47 +13,57 @@ pub mod test_data;
/// Central data structure used to calculate word probabilities /// Central data structure used to calculate word probabilities
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
pub struct Segmenter { pub struct Segmenter {
unigrams: HashMap<String, f64>, // Maps a word to both its unigram score, as well has a nested HashMap in
bigrams: HashMap<(String, String), f64>, // which the bigram score can be looked up using the previous word. Scores
// are base-10 logarithms of relative word frequencies
scores: HashMap<String, (f64, HashMap<String, f64>)>,
// Base-10 logarithm of the total count of unigrams
uni_total_log10: f64, uni_total_log10: f64,
limit: usize, limit: usize,
} }
impl Segmenter { impl Segmenter {
/// Create `Segmenter` from the given iterators /// Create `Segmenter` from the given unigram and bigram counts.
/// ///
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any /// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it. /// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
pub fn from_iters<U, B>(unigrams: U, bigrams: B) -> Self pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
where where
U: Iterator<Item = (String, f64)>, U: IntoIterator<Item = (String, f64)>,
B: Iterator<Item = ((String, String), f64)>, B: IntoIterator<Item = ((String, String), f64)>,
{ {
Self::from_maps(unigrams.collect(), bigrams.collect()) // Initially, `scores` contains the original unigram and bigram counts
} let mut scores = HashMap::default();
let mut uni_total = 0.0;
for (word, uni) in unigrams {
scores.insert(word, (uni, HashMap::default()));
uni_total += uni;
}
let mut bi_total = 0.0;
for ((word1, word2), bi) in bigrams {
let Some((_, bi_scores)) = scores.get_mut(&word2) else {
// We throw away bigrams for which we do not have a unigram for
// the second word. This case shouldn't ever happen on
// real-world data, and in fact, it never happens on the word
// count lists shipped with this crate.
continue;
};
bi_scores.insert(word1, bi);
bi_total += bi;
}
/// Create `Segmenter` from the given hashmaps (using ahash) // Now convert the counts in `scores` to the values we actually want,
/// // namely logarithms of relative frequencies
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any for (uni, bi_scores) in scores.values_mut() {
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
/// The `HashMap` type here refers to `std::collections::HashMap` parametrized with the
/// `ahash::RandomState`.
pub fn from_maps(
mut unigrams: HashMap<String, f64>,
mut bigrams: HashMap<(String, String), f64>,
) -> Self {
let uni_total = unigrams.values().sum::<f64>();
let bi_total = bigrams.values().sum::<f64>();
for uni in unigrams.values_mut() {
*uni = (*uni / uni_total).log10(); *uni = (*uni / uni_total).log10();
for bi in bi_scores.values_mut() {
*bi = (*bi / bi_total).log10();
}
} }
for bi in bigrams.values_mut() {
*bi = (*bi / bi_total).log10();
}
Self { Self {
uni_total_log10: uni_total.log10(), uni_total_log10: uni_total.log10(),
unigrams, scores,
bigrams,
limit: DEFAULT_LIMIT, limit: DEFAULT_LIMIT,
} }
} }
@ -91,24 +101,25 @@ impl Segmenter {
} }
fn score(&self, word: &str, previous: Option<&str>) -> f64 { fn score(&self, word: &str, previous: Option<&str>) -> f64 {
let (uni, bi_scores) = match self.scores.get(word) {
Some((uni, bi_scores)) => (uni, bi_scores),
// Penalize words not found in the unigrams according
// to their length, a crucial heuristic.
None => return 1.0 - self.uni_total_log10 - word.len() as f64,
};
if let Some(prev) = previous { if let Some(prev) = previous {
if let Some(bi) = self.bigrams.get(&(prev.into(), word.into())) { if let Some(bi) = bi_scores.get(prev) {
if let Some(uni) = self.unigrams.get(prev) { if let Some((uni_prev, _)) = self.scores.get(prev) {
// Conditional probability of the word given the previous // Conditional probability of the word given the previous
// word. The technical name is "stupid backoff" and it's // word. The technical name is "stupid backoff" and it's
// not a probability distribution but it works well in practice. // not a probability distribution but it works well in practice.
return bi - uni; return bi - uni_prev;
} }
} }
} }
match self.unigrams.get(word) { *uni
// Probability of the given word
Some(uni) => *uni,
// Penalize words not found in the unigrams according
// to their length, a crucial heuristic.
None => 1.0 - self.uni_total_log10 - word.len() as f64,
}
} }
/// Customize the word length `limit` /// Customize the word length `limit`

View File

@ -56,7 +56,7 @@ pub fn segmenter(dir: PathBuf) -> Segmenter {
ln.clear(); ln.clear();
} }
Segmenter::from_maps(unigrams, bigrams) Segmenter::new(unigrams, bigrams)
} }
pub fn crate_data_dir() -> PathBuf { pub fn crate_data_dir() -> PathBuf {