Improve heuristic neighbor selection

This commit is contained in:
Dirkjan Ochtman 2021-01-14 17:01:31 +01:00
parent 47f26978f5
commit 6dc83caabe
1 changed file with 102 additions and 43 deletions

View File

@ -83,6 +83,15 @@ pub struct Heuristic {
pub keep_pruned: bool,
}
impl Default for Heuristic {
fn default() -> Self {
Heuristic {
extend_candidates: false,
keep_pruned: true,
}
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> {
ef_search: usize,
@ -204,6 +213,7 @@ where
let mut pool = SearchPool::default();
let mut batch = Vec::new();
let mut done = Vec::new();
let mut insertion = Search::default();
let max_batch_len = num_cpus::get() * 4;
for (layer, mut range) in ranges {
let num = if layer.0 > 0 { M } else { M * 2 };
@ -245,7 +255,14 @@ where
search.push(added, &points[pid], &points);
}
insert(pid, &mut search, &mut zero, &points, &builder.heuristic);
insert(
pid,
&mut insertion,
&mut search,
&mut zero,
&points,
&builder.heuristic,
);
done.push(pid);
pool.push(search);
}
@ -334,23 +351,27 @@ where
/// Insert new node in the zero layer
///
/// `new` contains the `PointId` for the new node; `search` contains the result for searching
/// potential neighbors for the new node; `layer` contains all the nodes at the current layer;
/// `points` is a slice of all the points in the index.
/// * `new`: the `PointId` for the new node
/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
/// * `search`: the result for searching potential neighbors for the new node
/// * `layer`: all the nodes at the current layer
/// * `points`: a slice of all the points in the index
///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer.
fn insert<P: Point>(
new: PointId,
insertion: &mut Search,
search: &mut Search,
layer: &mut Vec<ZeroNode>,
points: &[P],
heuristic: &Option<Heuristic>,
) {
layer.push(ZeroNode::default());
let found = match heuristic {
None => search.select_simple(M * 2),
Some(heuristic) => {
search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic)
search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic)
}
};
@ -360,45 +381,57 @@ fn insert<P: Point>(
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
);
let mut node = ZeroNode::default();
for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
node.nearest[i] = pid; // Update the new node's `nearest`
if let Some(heuristic) = heuristic {
insertion.reset();
let candidate_point = &points[pid];
insertion.push(new, candidate_point, points);
for hop in layer.nearest_iter(pid) {
insertion.push(hop, candidate_point, points);
}
let old = &points[pid];
let nearest = &layer[pid.0 as usize].nearest;
let found =
insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic);
for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) {
*slot = hop.pid;
}
layer[new.0 as usize].nearest[i] = pid;
} else {
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let old = &points[pid];
let nearest = &layer[pid.0 as usize].nearest;
let idx = nearest
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let idx = nearest
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&points[third.0 as usize]).into())
})
.unwrap_or_else(|e| e);
distance.cmp(&old.distance(&points[third.0 as usize]).into())
})
.unwrap_or_else(|e| e);
// It might be possible for all the neighbor's current neighbors to be closer to our
// neighbor than to the new node, in which case we skip insertion of our new node's ID.
if idx >= nearest.len() {
layer[new.0 as usize].nearest[i] = pid;
continue;
}
// It might be possible for all the neighbor's current neighbors to be closer to our
// neighbor than to the new node, in which case we skip insertion of our new node's ID.
if idx >= nearest.len() {
continue;
let nearest = &mut layer[pid.0 as usize].nearest;
if nearest[idx].is_valid() {
let end = (M * 2) - 1;
nearest.copy_within(idx..end, idx + 1);
}
nearest[idx] = new;
layer[new.0 as usize].nearest[i] = pid;
}
let nearest = &mut layer[pid.0 as usize].nearest;
if nearest[idx].is_valid() {
let end = (M * 2) - 1;
nearest.copy_within(idx..end, idx + 1);
}
nearest[idx] = new;
}
layer.push(node);
}
#[derive(Default)]
@ -483,6 +516,7 @@ pub struct Search {
nearest: Vec<Candidate>,
/// Working set for heuristic selection
working: Vec<Candidate>,
discarded: Vec<Candidate>,
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
ef: usize,
}
@ -495,6 +529,7 @@ impl Search {
candidates,
nearest,
working,
discarded,
ef: _,
} = self;
@ -502,6 +537,7 @@ impl Search {
candidates.clear();
nearest.clear();
working.clear();
discarded.clear();
}
/// Selection of neighbors for insertion (algorithm 3 from the paper)
@ -521,24 +557,45 @@ impl Search {
self.working.clear();
// Get input candidates from `self.nearest` and store them in `self.working`.
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
for &candidate in &self.nearest {
while let Some(Reverse(candidate)) = self.candidates.pop() {
self.working.push(candidate);
if params.extend_candidates {
for pid in layer.nearest_iter(candidate.pid).take(num) {
let other = &points[pid];
let distance = OrderedFloat::from(point.distance(other));
self.working.push(Candidate { distance, pid });
for hop in layer.nearest_iter(candidate.pid) {
if !self.visited.insert(hop) {
continue;
}
let other = &points[hop];
let distance = OrderedFloat::from(point.distance(other));
let new = Candidate { distance, pid: hop };
self.working.push(new);
}
}
self.working.sort_unstable();
self.nearest.clear();
self.nearest.push(self.working[0]);
self.discarded.clear();
for candidate in self.working.drain(..) {
if self.nearest.len() >= num {
break;
}
// Disadvantage candidates which are closer to an existing result point than they
// are to the query point, to facilitate bridging between clustered points.
let candidate_point = &points[candidate.pid];
let nearest = !self.nearest.iter().any(|result| {
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
distance < candidate.distance
});
match nearest {
true => self.nearest.push(candidate),
false => self.discarded.push(candidate),
}
}
if params.keep_pruned {
// Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
for candidate in self.working.drain(1..) {
for candidate in self.discarded.drain(..) {
if self.nearest.len() >= num {
break;
}
@ -546,6 +603,7 @@ impl Search {
}
}
self.nearest.sort_unstable();
&self.nearest[..min(self.nearest.len(), num)]
}
@ -601,6 +659,7 @@ impl Default for Search {
candidates: BinaryHeap::new(),
nearest: Vec::new(),
working: Vec::new(),
discarded: Vec::new(),
ef: 1,
}
}