Improve heuristic neighbor selection
This commit is contained in:
parent
47f26978f5
commit
6dc83caabe
97
src/lib.rs
97
src/lib.rs
|
@ -83,6 +83,15 @@ pub struct Heuristic {
|
|||
pub keep_pruned: bool,
|
||||
}
|
||||
|
||||
// Default settings for heuristic neighbor selection.
impl Default for Heuristic {
|
||||
/// Returns a `Heuristic` with `extend_candidates` disabled and `keep_pruned` enabled.
fn default() -> Self {
|
||||
Heuristic {
|
||||
// By default, do not add the candidates' own neighbors to the working set.
extend_candidates: false,
|
||||
// By default, refill the result with discarded candidates while there is room.
keep_pruned: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct Hnsw<P> {
|
||||
ef_search: usize,
|
||||
|
@ -204,6 +213,7 @@ where
|
|||
let mut pool = SearchPool::default();
|
||||
let mut batch = Vec::new();
|
||||
let mut done = Vec::new();
|
||||
let mut insertion = Search::default();
|
||||
let max_batch_len = num_cpus::get() * 4;
|
||||
for (layer, mut range) in ranges {
|
||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||
|
@ -245,7 +255,14 @@ where
|
|||
search.push(added, &points[pid], &points);
|
||||
}
|
||||
|
||||
insert(pid, &mut search, &mut zero, &points, &builder.heuristic);
|
||||
insert(
|
||||
pid,
|
||||
&mut insertion,
|
||||
&mut search,
|
||||
&mut zero,
|
||||
&points,
|
||||
&builder.heuristic,
|
||||
);
|
||||
done.push(pid);
|
||||
pool.push(search);
|
||||
}
|
||||
|
@ -334,23 +351,27 @@ where
|
|||
|
||||
/// Insert new node in the zero layer
|
||||
///
|
||||
/// `new` contains the `PointId` for the new node; `search` contains the result for searching
|
||||
/// potential neighbors for the new node; `layer` contains all the nodes at the current layer;
|
||||
/// `points` is a slice of all the points in the index.
|
||||
/// * `new`: the `PointId` for the new node
|
||||
/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
|
||||
/// * `search`: the result for searching potential neighbors for the new node
|
||||
/// * `layer`: all the nodes at the current layer
|
||||
/// * `points`: a slice of all the points in the index
|
||||
///
|
||||
/// Creates the new node, initializes its `nearest` array, and updates the nearest neighbors
|
||||
/// for the new node's neighbors if necessary before appending the new node to the layer.
|
||||
fn insert<P: Point>(
|
||||
new: PointId,
|
||||
insertion: &mut Search,
|
||||
search: &mut Search,
|
||||
layer: &mut Vec<ZeroNode>,
|
||||
points: &[P],
|
||||
heuristic: &Option<Heuristic>,
|
||||
) {
|
||||
layer.push(ZeroNode::default());
|
||||
let found = match heuristic {
|
||||
None => search.select_simple(M * 2),
|
||||
Some(heuristic) => {
|
||||
search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic)
|
||||
search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -360,16 +381,27 @@ fn insert<P: Point>(
|
|||
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
|
||||
);
|
||||
|
||||
let mut node = ZeroNode::default();
|
||||
for (i, candidate) in found.iter().enumerate() {
|
||||
// `candidate` here is the new node's neighbor
|
||||
let &Candidate { distance, pid } = candidate;
|
||||
node.nearest[i] = pid; // Update the new node's `nearest`
|
||||
if let Some(heuristic) = heuristic {
|
||||
insertion.reset();
|
||||
let candidate_point = &points[pid];
|
||||
insertion.push(new, candidate_point, points);
|
||||
for hop in layer.nearest_iter(pid) {
|
||||
insertion.push(hop, candidate_point, points);
|
||||
}
|
||||
|
||||
let found =
|
||||
insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic);
|
||||
for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) {
|
||||
*slot = hop.pid;
|
||||
}
|
||||
layer[new.0 as usize].nearest[i] = pid;
|
||||
} else {
|
||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||
let old = &points[pid];
|
||||
let nearest = &layer[pid.0 as usize].nearest;
|
||||
|
||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||
let idx = nearest
|
||||
.binary_search_by(|third| {
|
||||
// `third` here is one of the neighbors of the new node's neighbor.
|
||||
|
@ -386,6 +418,7 @@ fn insert<P: Point>(
|
|||
// It might be possible for all the neighbor's current neighbors to be closer to our
|
||||
// neighbor than to the new node, in which case we skip insertion of our new node's ID.
|
||||
if idx >= nearest.len() {
|
||||
layer[new.0 as usize].nearest[i] = pid;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -396,9 +429,9 @@ fn insert<P: Point>(
|
|||
}
|
||||
|
||||
nearest[idx] = new;
|
||||
layer[new.0 as usize].nearest[i] = pid;
|
||||
}
|
||||
}
|
||||
|
||||
layer.push(node);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
@ -483,6 +516,7 @@ pub struct Search {
|
|||
nearest: Vec<Candidate>,
|
||||
/// Working set for heuristic selection
|
||||
working: Vec<Candidate>,
|
||||
discarded: Vec<Candidate>,
|
||||
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
||||
ef: usize,
|
||||
}
|
||||
|
@ -495,6 +529,7 @@ impl Search {
|
|||
candidates,
|
||||
nearest,
|
||||
working,
|
||||
discarded,
|
||||
ef: _,
|
||||
} = self;
|
||||
|
||||
|
@ -502,6 +537,7 @@ impl Search {
|
|||
candidates.clear();
|
||||
nearest.clear();
|
||||
working.clear();
|
||||
discarded.clear();
|
||||
}
|
||||
|
||||
/// Selection of neighbors for insertion (algorithm 3 from the paper)
|
||||
|
@ -521,24 +557,45 @@ impl Search {
|
|||
self.working.clear();
|
||||
// Get input candidates from `self.nearest` and store them in `self.working`.
|
||||
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
|
||||
for &candidate in &self.nearest {
|
||||
while let Some(Reverse(candidate)) = self.candidates.pop() {
|
||||
self.working.push(candidate);
|
||||
if params.extend_candidates {
|
||||
for pid in layer.nearest_iter(candidate.pid).take(num) {
|
||||
let other = &points[pid];
|
||||
let distance = OrderedFloat::from(point.distance(other));
|
||||
self.working.push(Candidate { distance, pid });
|
||||
for hop in layer.nearest_iter(candidate.pid) {
|
||||
if !self.visited.insert(hop) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let other = &points[hop];
|
||||
let distance = OrderedFloat::from(point.distance(other));
|
||||
let new = Candidate { distance, pid: hop };
|
||||
self.working.push(new);
|
||||
}
|
||||
}
|
||||
|
||||
self.working.sort_unstable();
|
||||
self.nearest.clear();
|
||||
self.nearest.push(self.working[0]);
|
||||
self.discarded.clear();
|
||||
for candidate in self.working.drain(..) {
|
||||
if self.nearest.len() >= num {
|
||||
break;
|
||||
}
|
||||
|
||||
// Disadvantage candidates which are closer to an existing result point than they
|
||||
// are to the query point, to facilitate bridging between clustered points.
|
||||
let candidate_point = &points[candidate.pid];
|
||||
let nearest = !self.nearest.iter().any(|result| {
|
||||
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
|
||||
distance < candidate.distance
|
||||
});
|
||||
|
||||
match nearest {
|
||||
true => self.nearest.push(candidate),
|
||||
false => self.discarded.push(candidate),
|
||||
}
|
||||
}
|
||||
|
||||
if params.keep_pruned {
|
||||
// Add discarded connections from `self.discarded` (`Wd`) to `self.nearest` (`R`)
|
||||
for candidate in self.working.drain(1..) {
|
||||
for candidate in self.discarded.drain(..) {
|
||||
if self.nearest.len() >= num {
|
||||
break;
|
||||
}
|
||||
|
@ -546,6 +603,7 @@ impl Search {
|
|||
}
|
||||
}
|
||||
|
||||
self.nearest.sort_unstable();
|
||||
&self.nearest[..min(self.nearest.len(), num)]
|
||||
}
|
||||
|
||||
|
@ -601,6 +659,7 @@ impl Default for Search {
|
|||
candidates: BinaryHeap::new(),
|
||||
nearest: Vec::new(),
|
||||
working: Vec::new(),
|
||||
discarded: Vec::new(),
|
||||
ef: 1,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue