Keep track of candidate order using BinaryHeap

This commit is contained in:
Dirkjan Ochtman 2021-01-07 15:59:41 +01:00
parent 459f5e4d65
commit d7c33e9e8f
1 changed file with 11 additions and 6 deletions

View File

@@ -1,4 +1,5 @@
use std::cmp::{max, min, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::hash::Hash;
use std::ops::Index;
@@ -316,7 +317,7 @@ pub struct Search {
/// Nodes visited so far (`v` in the paper)
visited: HashSet<PointId>,
/// Candidates for further inspection (`C` in the paper)
candidates: Vec<Candidate>,
candidates: BinaryHeap<Reverse<Candidate>>,
/// Nearest neighbors found so far (`W` in the paper)
nearest: Vec<Candidate>,
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
@@ -364,7 +365,7 @@ impl Search {
}
let new = Candidate { distance, pid };
self.candidates.push(new);
self.candidates.push(Reverse(new));
self.nearest.push(new);
self.furthest = max(self.furthest, distance);
}
@@ -379,8 +380,12 @@ impl Search {
fn cull(&mut self) {
self.nearest.truncate(self.ef); // Limit size of the set of nearest neighbors
self.furthest = self.nearest.last().unwrap().distance;
self.candidates.clear();
self.candidates.extend(&self.nearest);
for &candidate in self.nearest.iter() {
self.candidates.push(Reverse(candidate));
}
self.visited.clear();
self.visited.extend(self.nearest.iter().map(|c| c.pid));
}
@@ -390,7 +395,7 @@ impl Default for Search {
fn default() -> Self {
Self {
visited: HashSet::new(),
candidates: Vec::new(),
candidates: BinaryHeap::new(),
nearest: Vec::new(),
ef: 1,
furthest: OrderedFloat::from(f32::INFINITY),
@@ -486,9 +491,9 @@ trait Layer {
/// representation matching the zero layer even when we're referring to a higher layer. In that
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
while let Some(candidate) = search.candidates.pop() {
while let Some(Reverse(candidate)) = search.candidates.pop() {
if candidate.distance > search.furthest {
continue;
break;
}
for pid in self.nearest_iter(candidate.pid).take(links) {