Improve heuristic neighbor selection

This commit is contained in:
Dirkjan Ochtman 2021-01-14 17:01:31 +01:00
parent 47f26978f5
commit 6dc83caabe
1 changed files with 102 additions and 43 deletions

View File

@ -83,6 +83,15 @@ pub struct Heuristic {
pub keep_pruned: bool, pub keep_pruned: bool,
} }
impl Default for Heuristic {
fn default() -> Self {
Heuristic {
extend_candidates: false,
keep_pruned: true,
}
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> { pub struct Hnsw<P> {
ef_search: usize, ef_search: usize,
@ -204,6 +213,7 @@ where
let mut pool = SearchPool::default(); let mut pool = SearchPool::default();
let mut batch = Vec::new(); let mut batch = Vec::new();
let mut done = Vec::new(); let mut done = Vec::new();
let mut insertion = Search::default();
let max_batch_len = num_cpus::get() * 4; let max_batch_len = num_cpus::get() * 4;
for (layer, mut range) in ranges { for (layer, mut range) in ranges {
let num = if layer.0 > 0 { M } else { M * 2 }; let num = if layer.0 > 0 { M } else { M * 2 };
@ -245,7 +255,14 @@ where
search.push(added, &points[pid], &points); search.push(added, &points[pid], &points);
} }
insert(pid, &mut search, &mut zero, &points, &builder.heuristic); insert(
pid,
&mut insertion,
&mut search,
&mut zero,
&points,
&builder.heuristic,
);
done.push(pid); done.push(pid);
pool.push(search); pool.push(search);
} }
@ -334,23 +351,27 @@ where
/// Insert new node in the zero layer /// Insert new node in the zero layer
/// ///
/// `new` contains the `PointId` for the new node; `search` contains the result for searching /// * `new`: the `PointId` for the new node
/// potential neighbors for the new node; `layer` contains all the nodes at the current layer; /// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
/// `points` is a slice of all the points in the index. /// * `search`: the result for searching potential neighbors for the new node
/// * `layer` contains all the nodes at the current layer
/// * `points` is a slice of all the points in the index
/// ///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer. /// for the new node's neighbors if necessary before appending the new node to the layer.
fn insert<P: Point>( fn insert<P: Point>(
new: PointId, new: PointId,
insertion: &mut Search,
search: &mut Search, search: &mut Search,
layer: &mut Vec<ZeroNode>, layer: &mut Vec<ZeroNode>,
points: &[P], points: &[P],
heuristic: &Option<Heuristic>, heuristic: &Option<Heuristic>,
) { ) {
layer.push(ZeroNode::default());
let found = match heuristic { let found = match heuristic {
None => search.select_simple(M * 2), None => search.select_simple(M * 2),
Some(heuristic) => { Some(heuristic) => {
search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic) search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic)
} }
}; };
@ -360,45 +381,57 @@ fn insert<P: Point>(
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len() found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
); );
let mut node = ZeroNode::default();
for (i, candidate) in found.iter().enumerate() { for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor // `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate; let &Candidate { distance, pid } = candidate;
node.nearest[i] = pid; // Update the new node's `nearest` if let Some(heuristic) = heuristic {
insertion.reset();
let candidate_point = &points[pid];
insertion.push(new, candidate_point, points);
for hop in layer.nearest_iter(pid) {
insertion.push(hop, candidate_point, points);
}
let old = &points[pid]; let found =
let nearest = &layer[pid.0 as usize].nearest; insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic);
for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) {
*slot = hop.pid;
}
layer[new.0 as usize].nearest[i] = pid;
} else {
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let old = &points[pid];
let nearest = &layer[pid.0 as usize].nearest;
let idx = nearest
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
// Find the correct index to insert at to keep the neighbor's neighbors sorted distance.cmp(&old.distance(&points[third.0 as usize]).into())
let idx = nearest })
.binary_search_by(|third| { .unwrap_or_else(|e| e);
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&points[third.0 as usize]).into()) // It might be possible for all the neighbor's current neighbors to be closer to our
}) // neighbor than to the new node, in which case we skip insertion of our new node's ID.
.unwrap_or_else(|e| e); if idx >= nearest.len() {
layer[new.0 as usize].nearest[i] = pid;
continue;
}
// It might be possible for all the neighbor's current neighbors to be closer to our let nearest = &mut layer[pid.0 as usize].nearest;
// neighbor than to the new node, in which case we skip insertion of our new node's ID. if nearest[idx].is_valid() {
if idx >= nearest.len() { let end = (M * 2) - 1;
continue; nearest.copy_within(idx..end, idx + 1);
}
nearest[idx] = new;
layer[new.0 as usize].nearest[i] = pid;
} }
let nearest = &mut layer[pid.0 as usize].nearest;
if nearest[idx].is_valid() {
let end = (M * 2) - 1;
nearest.copy_within(idx..end, idx + 1);
}
nearest[idx] = new;
} }
layer.push(node);
} }
#[derive(Default)] #[derive(Default)]
@ -483,6 +516,7 @@ pub struct Search {
nearest: Vec<Candidate>, nearest: Vec<Candidate>,
/// Working set for heuristic selection /// Working set for heuristic selection
working: Vec<Candidate>, working: Vec<Candidate>,
discarded: Vec<Candidate>,
/// Maximum number of nearest neighbors to retain (`ef` in the paper) /// Maximum number of nearest neighbors to retain (`ef` in the paper)
ef: usize, ef: usize,
} }
@ -495,6 +529,7 @@ impl Search {
candidates, candidates,
nearest, nearest,
working, working,
discarded,
ef: _, ef: _,
} = self; } = self;
@ -502,6 +537,7 @@ impl Search {
candidates.clear(); candidates.clear();
nearest.clear(); nearest.clear();
working.clear(); working.clear();
discarded.clear();
} }
/// Selection of neighbors for insertion (algorithm 3 from the paper) /// Selection of neighbors for insertion (algorithm 3 from the paper)
@ -521,24 +557,45 @@ impl Search {
self.working.clear(); self.working.clear();
// Get input candidates from `self.nearest` and store them in `self.working`. // Get input candidates from `self.nearest` and store them in `self.working`.
// `self.candidates` will represent `W` from the paper's algorithm 4 for now. // `self.candidates` will represent `W` from the paper's algorithm 4 for now.
for &candidate in &self.nearest { while let Some(Reverse(candidate)) = self.candidates.pop() {
self.working.push(candidate); self.working.push(candidate);
if params.extend_candidates { for hop in layer.nearest_iter(candidate.pid) {
for pid in layer.nearest_iter(candidate.pid).take(num) { if !self.visited.insert(hop) {
let other = &points[pid]; continue;
let distance = OrderedFloat::from(point.distance(other));
self.working.push(Candidate { distance, pid });
} }
let other = &points[hop];
let distance = OrderedFloat::from(point.distance(other));
let new = Candidate { distance, pid: hop };
self.working.push(new);
} }
} }
self.working.sort_unstable(); self.working.sort_unstable();
self.nearest.clear(); self.nearest.clear();
self.nearest.push(self.working[0]); self.discarded.clear();
for candidate in self.working.drain(..) {
if self.nearest.len() >= num {
break;
}
// Disadvantage candidates which are closer to an existing result point than they
// are to the query point, to facilitate bridging between clustered points.
let candidate_point = &points[candidate.pid];
let nearest = !self.nearest.iter().any(|result| {
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
distance < candidate.distance
});
match nearest {
true => self.nearest.push(candidate),
false => self.discarded.push(candidate),
}
}
if params.keep_pruned { if params.keep_pruned {
// Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`) // Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
for candidate in self.working.drain(1..) { for candidate in self.discarded.drain(..) {
if self.nearest.len() >= num { if self.nearest.len() >= num {
break; break;
} }
@ -546,6 +603,7 @@ impl Search {
} }
} }
self.nearest.sort_unstable();
&self.nearest[..min(self.nearest.len(), num)] &self.nearest[..min(self.nearest.len(), num)]
} }
@ -601,6 +659,7 @@ impl Default for Search {
candidates: BinaryHeap::new(), candidates: BinaryHeap::new(),
nearest: Vec::new(), nearest: Vec::new(),
working: Vec::new(), working: Vec::new(),
discarded: Vec::new(),
ef: 1, ef: 1,
} }
} }