Use nested `HashMap` for storing both unigram and bigram scores
This commit is contained in:
parent
5810373471
commit
3b3627422b
|
@ -70,7 +70,7 @@ print([word for word in search])
|
|||
use instant_segment::{Search, Segmenter};
|
||||
use std::collections::HashMap;
|
||||
|
||||
let segmenter = Segmenter::from_maps(unigrams, bigrams);
|
||||
let segmenter = Segmenter::new(unigrams, bigrams);
|
||||
let mut search = Search::default();
|
||||
let words = segmenter
|
||||
.segment("instantdomainsearch", &mut search)
|
||||
|
|
|
@ -16,7 +16,6 @@ name = "instant_segment"
|
|||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
ahash = "0.8"
|
||||
bincode = "1.3.2"
|
||||
instant-segment = { version = "0.10", path = "../instant-segment", features = ["with-serde"] }
|
||||
pyo3 = { version = "0.20", features = ["extension-module"] }
|
||||
|
|
|
@ -39,7 +39,7 @@ impl Segmenter {
|
|||
let val = item.get_item(1)?.extract::<f64>()?;
|
||||
Ok((SmartString::from(key), val))
|
||||
})
|
||||
.collect::<Result<HashMap<_, _>, PyErr>>()?;
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let bigrams = bigrams
|
||||
.map(|item| {
|
||||
|
@ -52,10 +52,10 @@ impl Segmenter {
|
|||
let val = item.get_item(1)?.extract::<f64>()?;
|
||||
Ok(((SmartString::from(first), SmartString::from(second)), val))
|
||||
})
|
||||
.collect::<Result<HashMap<_, _>, PyErr>>()?;
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
Ok(Self {
|
||||
inner: instant_segment::Segmenter::from_maps(unigrams, bigrams),
|
||||
inner: instant_segment::Segmenter::new(unigrams, bigrams),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -148,5 +148,3 @@ impl Search {
|
|||
Some(word)
|
||||
}
|
||||
}
|
||||
|
||||
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
|
||||
|
|
|
@ -2,7 +2,7 @@ use instant_segment::{Search, Segmenter};
|
|||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
let mut unigrams = HashMap::default();
|
||||
let mut unigrams = HashMap::new();
|
||||
|
||||
unigrams.insert("choose".into(), 80_000.0);
|
||||
unigrams.insert("chooses".into(), 7_000.0);
|
||||
|
@ -10,12 +10,12 @@ fn main() {
|
|||
unigrams.insert("spain".into(), 20_000.0);
|
||||
unigrams.insert("pain".into(), 90_000.0);
|
||||
|
||||
let mut bigrams = HashMap::default();
|
||||
let mut bigrams = HashMap::new();
|
||||
|
||||
bigrams.insert(("choose".into(), "spain".into()), 7.0);
|
||||
bigrams.insert(("chooses".into(), "pain".into()), 0.0);
|
||||
|
||||
let segmenter = Segmenter::from_maps(unigrams, bigrams);
|
||||
let segmenter = Segmenter::new(unigrams, bigrams);
|
||||
let mut search = Search::default();
|
||||
|
||||
let words = segmenter.segment("choosespain", &mut search).unwrap();
|
||||
|
|
|
@ -13,47 +13,57 @@ pub mod test_data;
|
|||
/// Central data structure used to calculate word probabilities
|
||||
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
|
||||
pub struct Segmenter {
|
||||
unigrams: HashMap<String, f64>,
|
||||
bigrams: HashMap<(String, String), f64>,
|
||||
// Maps a word to both its unigram score, as well has a nested HashMap in
|
||||
// which the bigram score can be looked up using the previous word. Scores
|
||||
// are base-10 logarithms of relative word frequencies
|
||||
scores: HashMap<String, (f64, HashMap<String, f64>)>,
|
||||
// Base-10 logarithm of the total count of unigrams
|
||||
uni_total_log10: f64,
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
impl Segmenter {
|
||||
/// Create `Segmenter` from the given iterators
|
||||
/// Create `Segmenter` from the given unigram and bigram counts.
|
||||
///
|
||||
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
|
||||
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
|
||||
pub fn from_iters<U, B>(unigrams: U, bigrams: B) -> Self
|
||||
pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
|
||||
where
|
||||
U: Iterator<Item = (String, f64)>,
|
||||
B: Iterator<Item = ((String, String), f64)>,
|
||||
U: IntoIterator<Item = (String, f64)>,
|
||||
B: IntoIterator<Item = ((String, String), f64)>,
|
||||
{
|
||||
Self::from_maps(unigrams.collect(), bigrams.collect())
|
||||
// Initially, `scores` contains the original unigram and bigram counts
|
||||
let mut scores = HashMap::default();
|
||||
let mut uni_total = 0.0;
|
||||
for (word, uni) in unigrams {
|
||||
scores.insert(word, (uni, HashMap::default()));
|
||||
uni_total += uni;
|
||||
}
|
||||
let mut bi_total = 0.0;
|
||||
for ((word1, word2), bi) in bigrams {
|
||||
let Some((_, bi_scores)) = scores.get_mut(&word2) else {
|
||||
// We throw away bigrams for which we do not have a unigram for
|
||||
// the second word. This case shouldn't ever happen on
|
||||
// real-world data, and in fact, it never happens on the word
|
||||
// count lists shipped with this crate.
|
||||
continue;
|
||||
};
|
||||
bi_scores.insert(word1, bi);
|
||||
bi_total += bi;
|
||||
}
|
||||
|
||||
/// Create `Segmenter` from the given hashmaps (using ahash)
|
||||
///
|
||||
/// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
|
||||
/// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
|
||||
/// The `HashMap` type here refers to `std::collections::HashMap` parametrized with the
|
||||
/// `ahash::RandomState`.
|
||||
pub fn from_maps(
|
||||
mut unigrams: HashMap<String, f64>,
|
||||
mut bigrams: HashMap<(String, String), f64>,
|
||||
) -> Self {
|
||||
let uni_total = unigrams.values().sum::<f64>();
|
||||
let bi_total = bigrams.values().sum::<f64>();
|
||||
for uni in unigrams.values_mut() {
|
||||
// Now convert the counts in `scores` to the values we actually want,
|
||||
// namely logarithms of relative frequencies
|
||||
for (uni, bi_scores) in scores.values_mut() {
|
||||
*uni = (*uni / uni_total).log10();
|
||||
}
|
||||
for bi in bigrams.values_mut() {
|
||||
for bi in bi_scores.values_mut() {
|
||||
*bi = (*bi / bi_total).log10();
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
uni_total_log10: uni_total.log10(),
|
||||
unigrams,
|
||||
bigrams,
|
||||
scores,
|
||||
limit: DEFAULT_LIMIT,
|
||||
}
|
||||
}
|
||||
|
@ -91,24 +101,25 @@ impl Segmenter {
|
|||
}
|
||||
|
||||
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
|
||||
let (uni, bi_scores) = match self.scores.get(word) {
|
||||
Some((uni, bi_scores)) => (uni, bi_scores),
|
||||
// Penalize words not found in the unigrams according
|
||||
// to their length, a crucial heuristic.
|
||||
None => return 1.0 - self.uni_total_log10 - word.len() as f64,
|
||||
};
|
||||
|
||||
if let Some(prev) = previous {
|
||||
if let Some(bi) = self.bigrams.get(&(prev.into(), word.into())) {
|
||||
if let Some(uni) = self.unigrams.get(prev) {
|
||||
if let Some(bi) = bi_scores.get(prev) {
|
||||
if let Some((uni_prev, _)) = self.scores.get(prev) {
|
||||
// Conditional probability of the word given the previous
|
||||
// word. The technical name is "stupid backoff" and it's
|
||||
// not a probability distribution but it works well in practice.
|
||||
return bi - uni;
|
||||
return bi - uni_prev;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match self.unigrams.get(word) {
|
||||
// Probability of the given word
|
||||
Some(uni) => *uni,
|
||||
// Penalize words not found in the unigrams according
|
||||
// to their length, a crucial heuristic.
|
||||
None => 1.0 - self.uni_total_log10 - word.len() as f64,
|
||||
}
|
||||
*uni
|
||||
}
|
||||
|
||||
/// Customize the word length `limit`
|
||||
|
|
|
@ -56,7 +56,7 @@ pub fn segmenter(dir: PathBuf) -> Segmenter {
|
|||
ln.clear();
|
||||
}
|
||||
|
||||
Segmenter::from_maps(unigrams, bigrams)
|
||||
Segmenter::new(unigrams, bigrams)
|
||||
}
|
||||
|
||||
pub fn crate_data_dir() -> PathBuf {
|
||||
|
|
Loading…
Reference in New Issue