From d7c33e9e8f894e0ed93a556029620c2719ed1b76 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 7 Jan 2021 15:59:41 +0100 Subject: [PATCH] Keep track of candidate order using BinaryHeap --- src/lib.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a39b7d8..5e26c62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ use std::cmp::{max, min, Ordering, Reverse}; +use std::collections::BinaryHeap; use std::hash::Hash; use std::ops::Index; @@ -316,7 +317,7 @@ pub struct Search { /// Nodes visited so far (`v` in the paper) visited: HashSet, /// Candidates for further inspection (`C` in the paper) - candidates: Vec<Candidate>, + candidates: BinaryHeap<Reverse<Candidate>>, /// Nearest neighbors found so far (`W` in the paper) nearest: Vec<Candidate>, /// Maximum number of nearest neighbors to retain (`ef` in the paper) @@ -364,7 +365,7 @@ impl Search { } let new = Candidate { distance, pid }; - self.candidates.push(new); + self.candidates.push(Reverse(new)); self.nearest.push(new); self.furthest = max(self.furthest, distance); } @@ -379,8 +380,12 @@ impl Search { fn cull(&mut self) { self.nearest.truncate(self.ef); // Limit size of the set of nearest neighbors self.furthest = self.nearest.last().unwrap().distance; + self.candidates.clear(); - self.candidates.extend(&self.nearest); + for &candidate in self.nearest.iter() { + self.candidates.push(Reverse(candidate)); + } + self.visited.clear(); self.visited.extend(self.nearest.iter().map(|c| c.pid)); } @@ -390,7 +395,7 @@ impl Default for Search { fn default() -> Self { Self { visited: HashSet::new(), - candidates: Vec::new(), + candidates: BinaryHeap::new(), nearest: Vec::new(), ef: 1, furthest: OrderedFloat::from(f32::INFINITY), @@ -486,9 +491,9 @@ trait Layer { /// representation matching the zero layer even when we're referring to a higher layer. In that /// case, we use `links` to constrain the number of per-candidate links we consider for search. 
fn search(&self, point: &P, search: &mut Search, points: &[P], links: usize) { - while let Some(candidate) = search.candidates.pop() { + while let Some(Reverse(candidate)) = search.candidates.pop() { if candidate.distance > search.furthest { - continue; + break; } for pid in self.nearest_iter(candidate.pid).take(links) {