Add support for selection heuristic
This commit is contained in:
parent
09dd8d886b
commit
7bfca28d85
88
src/lib.rs
88
src/lib.rs
|
@ -1,5 +1,5 @@
|
|||
use std::cmp::{max, min, Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::collections::{BinaryHeap, VecDeque};
|
||||
use std::hash::Hash;
|
||||
use std::ops::Index;
|
||||
|
||||
|
@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
|
|||
pub struct Builder {
|
||||
ef_search: Option<usize>,
|
||||
ef_construction: Option<usize>,
|
||||
heuristic: Option<Heuristic>,
|
||||
ml: Option<f32>,
|
||||
seed: Option<u64>,
|
||||
#[cfg(feature = "indicatif")]
|
||||
|
@ -43,6 +44,11 @@ impl Builder {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn select_heuristic(mut self, params: Heuristic) -> Self {
|
||||
self.heuristic = Some(params);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `mL` parameter from the paper
|
||||
///
|
||||
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
|
||||
|
@ -72,6 +78,12 @@ impl Builder {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Heuristic {
|
||||
pub extend_candidates: bool,
|
||||
pub keep_pruned: bool,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct Hnsw<P> {
|
||||
ef_search: usize,
|
||||
|
@ -234,7 +246,14 @@ where
|
|||
search.push(added, &points[pid], &points);
|
||||
}
|
||||
|
||||
insert(&mut zero, pid, search.select_simple(num), &points);
|
||||
let candidates = match builder.heuristic {
|
||||
None => search.select_simple(num),
|
||||
Some(heuristic) => {
|
||||
search.select_heuristic(&zero, num, &points[pid], &points, heuristic)
|
||||
}
|
||||
};
|
||||
|
||||
insert(&mut zero, pid, candidates, &points);
|
||||
done.push(pid);
|
||||
pool.push(search);
|
||||
}
|
||||
|
@ -482,6 +501,71 @@ impl Search {
|
|||
&self.nearest[..min(self.nearest.len(), num)]
|
||||
}
|
||||
|
||||
fn select_heuristic<P: Point>(
|
||||
&mut self,
|
||||
layer: &Vec<ZeroNode>,
|
||||
num: usize,
|
||||
point: &P,
|
||||
points: &[P],
|
||||
params: Heuristic,
|
||||
) -> &[Candidate] {
|
||||
// Get input candidates from `self.nearest` and store them in `self.candidates`.
|
||||
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
|
||||
self.candidates.clear();
|
||||
self.nearest.sort_unstable();
|
||||
for &candidate in self.nearest.iter() {
|
||||
self.candidates.push(Reverse(candidate));
|
||||
}
|
||||
// Clear `self.nearest`. This now represents the result set (`R`).
|
||||
self.nearest.clear();
|
||||
|
||||
let mut working = VecDeque::new();
|
||||
if params.extend_candidates {
|
||||
// Because we can't extend `self.candidates` while we iterate over it, we use
|
||||
// `working` to accumulate candidates' neighbors in.
|
||||
for Reverse(candidate) in &self.candidates {
|
||||
for pid in layer.nearest_iter(candidate.pid).take(num) {
|
||||
let other = &points[pid];
|
||||
let distance = OrderedFloat::from(point.distance(other));
|
||||
working.push_back(Candidate { distance, pid });
|
||||
}
|
||||
}
|
||||
|
||||
// Once we have all the extended candidates, push them onto `self.candidates`.
|
||||
// Because `self.candidates` is a `BinaryHeap`, it remains in sorted order.
|
||||
// After this loop, `working` is empty again, so we can reuse it.
|
||||
for candidate in working.drain(..) {
|
||||
self.candidates.push(Reverse(candidate));
|
||||
}
|
||||
}
|
||||
|
||||
// Take candidates from `self.candidates` (`W`) and compare them to the re
|
||||
while let Some(Reverse(candidate)) = self.candidates.pop() {
|
||||
if self.nearest.len() >= num {
|
||||
break;
|
||||
}
|
||||
|
||||
match self.nearest.binary_search(&candidate) {
|
||||
Err(0) => self.nearest.insert(0, candidate),
|
||||
Err(_) => working.push_back(candidate),
|
||||
Ok(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
// `working` was filled by pushing back from `candidates`, so it must be in sorted order.
|
||||
if params.keep_pruned {
|
||||
while let Some(candidate) = working.pop_front() {
|
||||
if self.nearest.len() >= num {
|
||||
break;
|
||||
}
|
||||
|
||||
self.nearest.push(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
self.select_simple(num)
|
||||
}
|
||||
|
||||
/// Track node `pid` as a potential new neighbor for the given `point`
|
||||
///
|
||||
/// Will immediately return if the node has been considered before. This implements
|
||||
|
|
11
tests/all.rs
11
tests/all.rs
|
@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
|
|||
use rand::rngs::{StdRng, ThreadRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::{Hnsw, Point as _, PointId, Search};
|
||||
use instant_distance::{Heuristic, Hnsw, Point as _, PointId, Search};
|
||||
|
||||
#[test]
|
||||
fn basic() {
|
||||
|
@ -46,7 +46,14 @@ fn randomized() {
|
|||
}
|
||||
}
|
||||
|
||||
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
|
||||
let (hnsw, pids) = Hnsw::<Point>::builder()
|
||||
.seed(seed)
|
||||
.select_heuristic(Heuristic {
|
||||
extend_candidates: false,
|
||||
keep_pruned: true,
|
||||
})
|
||||
.build(&points);
|
||||
|
||||
let mut search = Search::default();
|
||||
let mut results = vec![PointId::default(); 100];
|
||||
let found = hnsw.search(&query, &mut results, &mut search);
|
||||
|
|
Loading…
Reference in New Issue