Simplify selection sizes

This commit is contained in:
Dirkjan Ochtman 2021-01-21 10:20:52 +01:00
parent f627769667
commit 96b69e5d4b
1 changed files with 11 additions and 14 deletions

View File

@ -352,12 +352,11 @@ where
} }
} }
let nearest = search.select_simple(); let nearest = search.select_simple(out.len());
let found = min(nearest.len(), out.len()); for (i, candidate) in nearest.iter().enumerate() {
for (i, candidate) in nearest.iter().take(found).enumerate() {
out[i] = candidate.pid; out[i] = candidate.pid;
} }
found nearest.len()
} }
/// Iterate over the keys and values in this index /// Iterate over the keys and values in this index
@ -389,8 +388,8 @@ fn insert<P: Point>(
) { ) {
layer.push(ZeroNode::default()); layer.push(ZeroNode::default());
let found = match heuristic { let found = match heuristic {
None => search.select_simple(), None => search.select_simple(M * 2),
Some(heuristic) => search.select_heuristic(&points[new], &layer, M * 2, points, *heuristic), Some(heuristic) => search.select_heuristic(&points[new], &layer, points, *heuristic),
}; };
// Just make sure the candidates are all unique // Just make sure the candidates are all unique
@ -399,7 +398,7 @@ fn insert<P: Point>(
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len() found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
); );
for (i, candidate) in found.iter().take(M * 2).enumerate() { for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor // `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate; let &Candidate { distance, pid } = candidate;
if let Some(heuristic) = heuristic { if let Some(heuristic) = heuristic {
@ -407,7 +406,6 @@ fn insert<P: Point>(
new, new,
layer.as_slice().nearest_iter(pid), layer.as_slice().nearest_iter(pid),
layer, layer,
M * 2,
&points[pid], &points[pid],
points, points,
*heuristic, *heuristic,
@ -523,7 +521,6 @@ impl Search {
new: PointId, new: PointId,
current: impl Iterator<Item = PointId>, current: impl Iterator<Item = PointId>,
layer: &[ZeroNode], layer: &[ZeroNode],
num: usize,
point: &P, point: &P,
points: &[P], points: &[P],
params: Heuristic, params: Heuristic,
@ -533,7 +530,7 @@ impl Search {
for pid in current { for pid in current {
self.push(pid, point, points); self.push(pid, point, points);
} }
self.select_heuristic(point, &layer, num, points, params) self.select_heuristic(point, &layer, points, params)
} }
/// Heuristically sort and truncate neighbors in `self.nearest` /// Heuristically sort and truncate neighbors in `self.nearest`
@ -543,7 +540,6 @@ impl Search {
&mut self, &mut self,
point: &P, point: &P,
layer: &[ZeroNode], layer: &[ZeroNode],
num: usize,
points: &[P], points: &[P],
params: Heuristic, params: Heuristic,
) -> &[Candidate] { ) -> &[Candidate] {
@ -573,7 +569,7 @@ impl Search {
self.nearest.clear(); self.nearest.clear();
self.discarded.clear(); self.discarded.clear();
for candidate in self.working.drain(..) { for candidate in self.working.drain(..) {
if self.nearest.len() >= num { if self.nearest.len() >= M * 2 {
break; break;
} }
@ -594,7 +590,7 @@ impl Search {
if params.keep_pruned { if params.keep_pruned {
// Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`) // Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
for candidate in self.discarded.drain(..) { for candidate in self.discarded.drain(..) {
if self.nearest.len() >= num { if self.nearest.len() >= M * 2 {
break; break;
} }
self.nearest.push(candidate); self.nearest.push(candidate);
@ -662,7 +658,8 @@ impl Search {
} }
/// Selection of neighbors for insertion (algorithm 3 from the paper) /// Selection of neighbors for insertion (algorithm 3 from the paper)
fn select_simple(&self) -> &[Candidate] { fn select_simple(&mut self, num: usize) -> &[Candidate] {
self.nearest.truncate(num);
&self.nearest &self.nearest
} }
} }