Use nested `HashMap` for storing both unigram and bigram scores
This commit is contained in:
parent
5810373471
commit
3b3627422b
|
@ -70,7 +70,7 @@ print([word for word in search])
|
||||||
use instant_segment::{Search, Segmenter};
|
use instant_segment::{Search, Segmenter};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
let segmenter = Segmenter::from_maps(unigrams, bigrams);
|
let segmenter = Segmenter::new(unigrams, bigrams);
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
let words = segmenter
|
let words = segmenter
|
||||||
.segment("instantdomainsearch", &mut search)
|
.segment("instantdomainsearch", &mut search)
|
||||||
|
|
|
@ -16,7 +16,6 @@ name = "instant_segment"
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ahash = "0.8"
|
|
||||||
bincode = "1.3.2"
|
bincode = "1.3.2"
|
||||||
instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] }
|
instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] }
|
||||||
pyo3 = { version = "0.20", features = ["extension-module"] }
|
pyo3 = { version = "0.20", features = ["extension-module"] }
|
||||||
|
|
|
@ -39,7 +39,7 @@ impl Segmenter {
|
||||||
let val = item.get_item(1)?.extract::<f64>()?;
|
let val = item.get_item(1)?.extract::<f64>()?;
|
||||||
Ok((SmartString::from(key), val))
|
Ok((SmartString::from(key), val))
|
||||||
})
|
})
|
||||||
.collect::<Result<HashMap<_, _>, PyErr>>()?;
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
let bigrams = bigrams
|
let bigrams = bigrams
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
|
@ -52,10 +52,10 @@ impl Segmenter {
|
||||||
let val = item.get_item(1)?.extract::<f64>()?;
|
let val = item.get_item(1)?.extract::<f64>()?;
|
||||||
Ok(((SmartString::from(first), SmartString::from(second)), val))
|
Ok(((SmartString::from(first), SmartString::from(second)), val))
|
||||||
})
|
})
|
||||||
.collect::<Result<HashMap<_, _>, PyErr>>()?;
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: instant_segment::Segmenter::from_maps(unigrams, bigrams),
|
inner: instant_segment::Segmenter::new(unigrams, bigrams),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,5 +148,3 @@ impl Search {
|
||||||
Some(word)
|
Some(word)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ use instant_segment::{Search, Segmenter};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut unigrams = HashMap::default();
|
let mut unigrams = HashMap::new();
|
||||||
|
|
||||||
unigrams.insert("choose".into(), 80_000.0);
|
unigrams.insert("choose".into(), 80_000.0);
|
||||||
unigrams.insert("chooses".into(), 7_000.0);
|
unigrams.insert("chooses".into(), 7_000.0);
|
||||||
|
@ -10,12 +10,12 @@ fn main() {
|
||||||
unigrams.insert("spain".into(), 20_000.0);
|
unigrams.insert("spain".into(), 20_000.0);
|
||||||
unigrams.insert("pain".into(), 90_000.0);
|
unigrams.insert("pain".into(), 90_000.0);
|
||||||
|
|
||||||
let mut bigrams = HashMap::default();
|
let mut bigrams = HashMap::new();
|
||||||
|
|
||||||
bigrams.insert(("choose".into(), "spain".into()), 7.0);
|
bigrams.insert(("choose".into(), "spain".into()), 7.0);
|
||||||
bigrams.insert(("chooses".into(), "pain".into()), 0.0);
|
bigrams.insert(("chooses".into(), "pain".into()), 0.0);
|
||||||
|
|
||||||
let segmenter = Segmenter::from_maps(unigrams, bigrams);
|
let segmenter = Segmenter::new(unigrams, bigrams);
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
|
|
||||||
let words = segmenter.segment("choosespain", &mut search).unwrap();
|
let words = segmenter.segment("choosespain", &mut search).unwrap();
|
||||||
|
|
|
@ -13,47 +13,57 @@ pub mod test_data;
|
||||||
/// Central data structure used to calculate word probabilities
|
/// Central data structure used to calculate word probabilities
|
||||||
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
|
||||||
pub struct Segmenter {
|
pub struct Segmenter {
|
||||||
unigrams: HashMap<String, f64>,
|
// Maps a word to both its unigram score and a nested HashMap in
|
||||||
bigrams: HashMap<(String, String), f64>,
|
// which the bigram score can be looked up using the previous word. Scores
|
||||||
|
// are base-10 logarithms of relative word frequencies
|
||||||
|
scores: HashMap<String, (f64, HashMap<String, f64>)>,
|
||||||
|
// Base-10 logarithm of the total count of unigrams
|
||||||
uni_total_log10: f64,
|
uni_total_log10: f64,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Segmenter {
|
impl Segmenter {
|
||||||
/// Create `Segmenter` from the given iterators
|
/// Create `Segmenter` from the given unigram and bigram counts.
|
||||||
///
|
///
|
||||||
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
|
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
|
||||||
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
|
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
|
||||||
pub fn from_iters<U, B>(unigrams: U, bigrams: B) -> Self
|
pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
|
||||||
where
|
where
|
||||||
U: Iterator<Item = (String, f64)>,
|
U: IntoIterator<Item = (String, f64)>,
|
||||||
B: Iterator<Item = ((String, String), f64)>,
|
B: IntoIterator<Item = ((String, String), f64)>,
|
||||||
{
|
{
|
||||||
Self::from_maps(unigrams.collect(), bigrams.collect())
|
// Initially, `scores` contains the original unigram and bigram counts
|
||||||
|
let mut scores = HashMap::default();
|
||||||
|
let mut uni_total = 0.0;
|
||||||
|
for (word, uni) in unigrams {
|
||||||
|
scores.insert(word, (uni, HashMap::default()));
|
||||||
|
uni_total += uni;
|
||||||
|
}
|
||||||
|
let mut bi_total = 0.0;
|
||||||
|
for ((word1, word2), bi) in bigrams {
|
||||||
|
let Some((_, bi_scores)) = scores.get_mut(&word2) else {
|
||||||
|
// We throw away bigrams for which we do not have a unigram for
|
||||||
|
// the second word. This case shouldn't ever happen on
|
||||||
|
// real-world data, and in fact, it never happens on the word
|
||||||
|
// count lists shipped with this crate.
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
bi_scores.insert(word1, bi);
|
||||||
|
bi_total += bi;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create `Segmenter` from the given hashmaps (using ahash)
|
// Now convert the counts in `scores` to the values we actually want,
|
||||||
///
|
// namely logarithms of relative frequencies
|
||||||
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
|
for (uni, bi_scores) in scores.values_mut() {
|
||||||
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
|
|
||||||
/// The `HashMap` type here refers to `std::collections::HashMap` parametrized with the
|
|
||||||
/// `ahash::RandomState`.
|
|
||||||
pub fn from_maps(
|
|
||||||
mut unigrams: HashMap<String, f64>,
|
|
||||||
mut bigrams: HashMap<(String, String), f64>,
|
|
||||||
) -> Self {
|
|
||||||
let uni_total = unigrams.values().sum::<f64>();
|
|
||||||
let bi_total = bigrams.values().sum::<f64>();
|
|
||||||
for uni in unigrams.values_mut() {
|
|
||||||
*uni = (*uni / uni_total).log10();
|
*uni = (*uni / uni_total).log10();
|
||||||
}
|
for bi in bi_scores.values_mut() {
|
||||||
for bi in bigrams.values_mut() {
|
|
||||||
*bi = (*bi / bi_total).log10();
|
*bi = (*bi / bi_total).log10();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
uni_total_log10: uni_total.log10(),
|
uni_total_log10: uni_total.log10(),
|
||||||
unigrams,
|
scores,
|
||||||
bigrams,
|
|
||||||
limit: DEFAULT_LIMIT,
|
limit: DEFAULT_LIMIT,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -91,24 +101,25 @@ impl Segmenter {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
||||||
|
let (uni, bi_scores) = match self.scores.get(word) {
|
||||||
|
Some((uni, bi_scores)) => (uni, bi_scores),
|
||||||
|
// Penalize words not found in the unigrams according
|
||||||
|
// to their length, a crucial heuristic.
|
||||||
|
None => return 1.0 - self.uni_total_log10 - word.len() as f64,
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(prev) = previous {
|
if let Some(prev) = previous {
|
||||||
if let Some(bi) = self.bigrams.get(&(prev.into(), word.into())) {
|
if let Some(bi) = bi_scores.get(prev) {
|
||||||
if let Some(uni) = self.unigrams.get(prev) {
|
if let Some((uni_prev, _)) = self.scores.get(prev) {
|
||||||
// Conditional probability of the word given the previous
|
// Conditional probability of the word given the previous
|
||||||
// word. The technical name is "stupid backoff" and it's
|
// word. The technical name is "stupid backoff" and it's
|
||||||
// not a probability distribution but it works well in practice.
|
// not a probability distribution but it works well in practice.
|
||||||
return bi - uni;
|
return bi - uni_prev;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match self.unigrams.get(word) {
|
*uni
|
||||||
// Probability of the given word
|
|
||||||
Some(uni) => *uni,
|
|
||||||
// Penalize words not found in the unigrams according
|
|
||||||
// to their length, a crucial heuristic.
|
|
||||||
None => 1.0 - self.uni_total_log10 - word.len() as f64,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Customize the word length `limit`
|
/// Customize the word length `limit`
|
||||||
|
|
|
@ -56,7 +56,7 @@ pub fn segmenter(dir: PathBuf) -> Segmenter {
|
||||||
ln.clear();
|
ln.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
Segmenter::from_maps(unigrams, bigrams)
|
Segmenter::new(unigrams, bigrams)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn crate_data_dir() -> PathBuf {
|
pub fn crate_data_dir() -> PathBuf {
|
||||||
|
|
Loading…
Reference in New Issue