Add support for selection heuristic

This commit is contained in:
Dirkjan Ochtman 2021-01-11 21:05:17 +01:00
parent 09dd8d886b
commit 7bfca28d85
2 changed files with 95 additions and 4 deletions

View File

@ -1,5 +1,5 @@
use std::cmp::{max, min, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::collections::{BinaryHeap, VecDeque};
use std::hash::Hash;
use std::ops::Index;
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
pub struct Builder {
ef_search: Option<usize>,
ef_construction: Option<usize>,
heuristic: Option<Heuristic>,
ml: Option<f32>,
seed: Option<u64>,
#[cfg(feature = "indicatif")]
@ -43,6 +44,11 @@ impl Builder {
self
}
pub fn select_heuristic(mut self, params: Heuristic) -> Self {
self.heuristic = Some(params);
self
}
/// Set the `mL` parameter from the paper
///
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
@ -72,6 +78,12 @@ impl Builder {
}
}
#[derive(Copy, Clone, Debug)]
pub struct Heuristic {
pub extend_candidates: bool,
pub keep_pruned: bool,
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> {
ef_search: usize,
@ -234,7 +246,14 @@ where
search.push(added, &points[pid], &points);
}
insert(&mut zero, pid, search.select_simple(num), &points);
let candidates = match builder.heuristic {
None => search.select_simple(num),
Some(heuristic) => {
search.select_heuristic(&zero, num, &points[pid], &points, heuristic)
}
};
insert(&mut zero, pid, candidates, &points);
done.push(pid);
pool.push(search);
}
@ -482,6 +501,71 @@ impl Search {
&self.nearest[..min(self.nearest.len(), num)]
}
fn select_heuristic<P: Point>(
&mut self,
layer: &Vec<ZeroNode>,
num: usize,
point: &P,
points: &[P],
params: Heuristic,
) -> &[Candidate] {
// Get input candidates from `self.nearest` and store them in `self.candidates`.
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
self.candidates.clear();
self.nearest.sort_unstable();
for &candidate in self.nearest.iter() {
self.candidates.push(Reverse(candidate));
}
// Clear `self.nearest`. This now represents the result set (`R`).
self.nearest.clear();
let mut working = VecDeque::new();
if params.extend_candidates {
// Because we can't extend `self.candidates` while we iterate over it, we use
// `working` to accumulate candidates' neighbors in.
for Reverse(candidate) in &self.candidates {
for pid in layer.nearest_iter(candidate.pid).take(num) {
let other = &points[pid];
let distance = OrderedFloat::from(point.distance(other));
working.push_back(Candidate { distance, pid });
}
}
// Once we have all the extended candidates, push them onto `self.candidates`.
// Because `self.candidates` is a `BinaryHeap`, it remains in sorted order.
// After this loop, `working` is empty again, so we can reuse it.
for candidate in working.drain(..) {
self.candidates.push(Reverse(candidate));
}
}
// Take candidates from `self.candidates` (`W`) and compare them to the re
while let Some(Reverse(candidate)) = self.candidates.pop() {
if self.nearest.len() >= num {
break;
}
match self.nearest.binary_search(&candidate) {
Err(0) => self.nearest.insert(0, candidate),
Err(_) => working.push_back(candidate),
Ok(_) => unreachable!(),
}
}
// `working` was filled by pushing back from `candidates`, so it must be in sorted order.
if params.keep_pruned {
while let Some(candidate) = working.pop_front() {
if self.nearest.len() >= num {
break;
}
self.nearest.push(candidate);
}
}
self.select_simple(num)
}
/// Track node `pid` as a potential new neighbor for the given `point`
///
/// Will immediately return if the node has been considered before. This implements

View File

@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng};
use instant_distance::{Hnsw, Point as _, PointId, Search};
use instant_distance::{Heuristic, Hnsw, Point as _, PointId, Search};
#[test]
fn basic() {
@ -46,7 +46,14 @@ fn randomized() {
}
}
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
let (hnsw, pids) = Hnsw::<Point>::builder()
.seed(seed)
.select_heuristic(Heuristic {
extend_candidates: false,
keep_pruned: true,
})
.build(&points);
let mut search = Search::default();
let mut results = vec![PointId::default(); 100];
let found = hnsw.search(&query, &mut results, &mut search);