Simplify heuristic selection code

This commit is contained in:
Dirkjan Ochtman 2021-01-12 13:31:51 +01:00
parent d265ef7c29
commit 581b3a17e2
1 changed file with 18 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
use std::cmp::{max, min, Ordering, Reverse}; use std::cmp::{max, min, Ordering, Reverse};
use std::collections::{BinaryHeap, VecDeque}; use std::collections::BinaryHeap;
use std::hash::Hash; use std::hash::Hash;
use std::ops::Index; use std::ops::Index;
@@ -476,6 +476,8 @@ pub struct Search {
candidates: BinaryHeap<Reverse<Candidate>>, candidates: BinaryHeap<Reverse<Candidate>>,
/// Nearest neighbors found so far (`W` in the paper) /// Nearest neighbors found so far (`W` in the paper)
nearest: Vec<Candidate>, nearest: Vec<Candidate>,
/// Working set for heuristic selection
working: Vec<Candidate>,
/// Maximum number of nearest neighbors to retain (`ef` in the paper) /// Maximum number of nearest neighbors to retain (`ef` in the paper)
ef: usize, ef: usize,
} }
@@ -487,12 +489,14 @@ impl Search {
visited, visited,
candidates, candidates,
nearest, nearest,
working,
ef: _, ef: _,
} = self; } = self;
visited.clear(); visited.clear();
candidates.clear(); candidates.clear();
nearest.clear(); nearest.clear();
working.clear();
} }
/// Selection of neighbors for insertion (algorithm 3 from the paper) /// Selection of neighbors for insertion (algorithm 3 from the paper)
@@ -509,61 +513,35 @@ impl Search {
points: &[P], points: &[P],
params: Heuristic, params: Heuristic,
) -> &[Candidate] { ) -> &[Candidate] {
// Get input candidates from `self.nearest` and store them in `self.candidates`. self.working.clear();
// Get input candidates from `self.nearest` and store them in `self.working`.
// `self.candidates` will represent `W` from the paper's algorithm 4 for now. // `self.candidates` will represent `W` from the paper's algorithm 4 for now.
self.candidates.clear(); for &candidate in &self.nearest {
self.nearest.sort_unstable(); self.working.push(candidate);
for &candidate in self.nearest.iter() { if params.extend_candidates {
self.candidates.push(Reverse(candidate));
}
// Clear `self.nearest`. This now represents the result set (`R`).
self.nearest.clear();
let mut working = VecDeque::new();
if params.extend_candidates {
// Because we can't extend `self.candidates` while we iterate over it, we use
// `working` to accumulate candidates' neighbors in.
for Reverse(candidate) in &self.candidates {
for pid in layer.nearest_iter(candidate.pid).take(num) { for pid in layer.nearest_iter(candidate.pid).take(num) {
let other = &points[pid]; let other = &points[pid];
let distance = OrderedFloat::from(point.distance(other)); let distance = OrderedFloat::from(point.distance(other));
working.push_back(Candidate { distance, pid }); self.working.push(Candidate { distance, pid });
} }
} }
// Once we have all the extended candidates, push them onto `self.candidates`.
// Because `self.candidates` is a `BinaryHeap`, it remains in sorted order.
// After this loop, `working` is empty again, so we can reuse it.
for candidate in working.drain(..) {
self.candidates.push(Reverse(candidate));
}
} }
// Take candidates from `self.candidates` (`W`) and compare them to the result set (`R`) self.working.sort_unstable();
while let Some(Reverse(candidate)) = self.candidates.pop() { self.nearest.clear();
if self.nearest.len() >= num { self.nearest.push(self.working[0]);
break;
}
match self.nearest.binary_search(&candidate) {
Err(0) => self.nearest.insert(0, candidate),
Err(_) => working.push_back(candidate),
Ok(_) => unreachable!(),
}
}
// `working` was filled by pushing back from `candidates`, so it must be in sorted order.
if params.keep_pruned { if params.keep_pruned {
while let Some(candidate) = working.pop_front() { // Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
for candidate in self.working.drain(1..) {
if self.nearest.len() >= num { if self.nearest.len() >= num {
break; break;
} }
self.nearest.push(candidate); self.nearest.push(candidate);
} }
} }
self.select_simple(num) &self.nearest[..min(self.nearest.len(), num)]
} }
/// Track node `pid` as a potential new neighbor for the given `point` /// Track node `pid` as a potential new neighbor for the given `point`
@@ -617,6 +595,7 @@ impl Default for Search {
visited: HashSet::new(), visited: HashSet::new(),
candidates: BinaryHeap::new(), candidates: BinaryHeap::new(),
nearest: Vec::new(), nearest: Vec::new(),
working: Vec::new(),
ef: 1, ef: 1,
} }
} }