Improve heuristic neighbor selection
This commit is contained in:
parent
47f26978f5
commit
6dc83caabe
145
src/lib.rs
145
src/lib.rs
|
@ -83,6 +83,15 @@ pub struct Heuristic {
|
||||||
pub keep_pruned: bool,
|
pub keep_pruned: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for Heuristic {
|
||||||
|
fn default() -> Self {
|
||||||
|
Heuristic {
|
||||||
|
extend_candidates: false,
|
||||||
|
keep_pruned: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
pub struct Hnsw<P> {
|
pub struct Hnsw<P> {
|
||||||
ef_search: usize,
|
ef_search: usize,
|
||||||
|
@ -204,6 +213,7 @@ where
|
||||||
let mut pool = SearchPool::default();
|
let mut pool = SearchPool::default();
|
||||||
let mut batch = Vec::new();
|
let mut batch = Vec::new();
|
||||||
let mut done = Vec::new();
|
let mut done = Vec::new();
|
||||||
|
let mut insertion = Search::default();
|
||||||
let max_batch_len = num_cpus::get() * 4;
|
let max_batch_len = num_cpus::get() * 4;
|
||||||
for (layer, mut range) in ranges {
|
for (layer, mut range) in ranges {
|
||||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||||
|
@ -245,7 +255,14 @@ where
|
||||||
search.push(added, &points[pid], &points);
|
search.push(added, &points[pid], &points);
|
||||||
}
|
}
|
||||||
|
|
||||||
insert(pid, &mut search, &mut zero, &points, &builder.heuristic);
|
insert(
|
||||||
|
pid,
|
||||||
|
&mut insertion,
|
||||||
|
&mut search,
|
||||||
|
&mut zero,
|
||||||
|
&points,
|
||||||
|
&builder.heuristic,
|
||||||
|
);
|
||||||
done.push(pid);
|
done.push(pid);
|
||||||
pool.push(search);
|
pool.push(search);
|
||||||
}
|
}
|
||||||
|
@ -334,23 +351,27 @@ where
|
||||||
|
|
||||||
/// Insert new node in the zero layer
|
/// Insert new node in the zero layer
|
||||||
///
|
///
|
||||||
/// `new` contains the `PointId` for the new node; `search` contains the result for searching
|
/// * `new`: the `PointId` for the new node
|
||||||
/// potential neighbors for the new node; `layer` contains all the nodes at the current layer;
|
/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
|
||||||
/// `points` is a slice of all the points in the index.
|
/// * `search`: the result for searching potential neighbors for the new node
|
||||||
|
/// * `layer`: all the nodes at the current layer
|
||||||
|
/// * `points`: a slice of all the points in the index
|
||||||
///
|
///
|
||||||
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
|
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
|
||||||
/// for the new node's neighbors if necessary before appending the new node to the layer.
|
/// for the new node's neighbors if necessary before appending the new node to the layer.
|
||||||
fn insert<P: Point>(
|
fn insert<P: Point>(
|
||||||
new: PointId,
|
new: PointId,
|
||||||
|
insertion: &mut Search,
|
||||||
search: &mut Search,
|
search: &mut Search,
|
||||||
layer: &mut Vec<ZeroNode>,
|
layer: &mut Vec<ZeroNode>,
|
||||||
points: &[P],
|
points: &[P],
|
||||||
heuristic: &Option<Heuristic>,
|
heuristic: &Option<Heuristic>,
|
||||||
) {
|
) {
|
||||||
|
layer.push(ZeroNode::default());
|
||||||
let found = match heuristic {
|
let found = match heuristic {
|
||||||
None => search.select_simple(M * 2),
|
None => search.select_simple(M * 2),
|
||||||
Some(heuristic) => {
|
Some(heuristic) => {
|
||||||
search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic)
|
search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -360,45 +381,57 @@ fn insert<P: Point>(
|
||||||
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
|
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut node = ZeroNode::default();
|
|
||||||
for (i, candidate) in found.iter().enumerate() {
|
for (i, candidate) in found.iter().enumerate() {
|
||||||
// `candidate` here is the new node's neighbor
|
// `candidate` here is the new node's neighbor
|
||||||
let &Candidate { distance, pid } = candidate;
|
let &Candidate { distance, pid } = candidate;
|
||||||
node.nearest[i] = pid; // Update the new node's `nearest`
|
if let Some(heuristic) = heuristic {
|
||||||
|
insertion.reset();
|
||||||
|
let candidate_point = &points[pid];
|
||||||
|
insertion.push(new, candidate_point, points);
|
||||||
|
for hop in layer.nearest_iter(pid) {
|
||||||
|
insertion.push(hop, candidate_point, points);
|
||||||
|
}
|
||||||
|
|
||||||
let old = &points[pid];
|
let found =
|
||||||
let nearest = &layer[pid.0 as usize].nearest;
|
insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic);
|
||||||
|
for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) {
|
||||||
|
*slot = hop.pid;
|
||||||
|
}
|
||||||
|
layer[new.0 as usize].nearest[i] = pid;
|
||||||
|
} else {
|
||||||
|
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||||
|
let old = &points[pid];
|
||||||
|
let nearest = &layer[pid.0 as usize].nearest;
|
||||||
|
let idx = nearest
|
||||||
|
.binary_search_by(|third| {
|
||||||
|
// `third` here is one of the neighbors of the new node's neighbor.
|
||||||
|
let third = match third {
|
||||||
|
pid if pid.is_valid() => *pid,
|
||||||
|
// if `third` is not a valid `PointId`, the new node is always "closer"
|
||||||
|
_ => return Ordering::Greater,
|
||||||
|
};
|
||||||
|
|
||||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
distance.cmp(&old.distance(&points[third.0 as usize]).into())
|
||||||
let idx = nearest
|
})
|
||||||
.binary_search_by(|third| {
|
.unwrap_or_else(|e| e);
|
||||||
// `third` here is one of the neighbors of the new node's neighbor.
|
|
||||||
let third = match third {
|
|
||||||
pid if pid.is_valid() => *pid,
|
|
||||||
// if `third` is not a valid `PointId`, our new `node` is always "closer"
|
|
||||||
_ => return Ordering::Greater,
|
|
||||||
};
|
|
||||||
|
|
||||||
distance.cmp(&old.distance(&points[third.0 as usize]).into())
|
// It might be possible for all the neighbor's current neighbors to be closer to our
|
||||||
})
|
// neighbor than to the new node, in which case we skip insertion of our new node's ID.
|
||||||
.unwrap_or_else(|e| e);
|
if idx >= nearest.len() {
|
||||||
|
layer[new.0 as usize].nearest[i] = pid;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// It might be possible for all the neighbor's current neighbors to be closer to our
|
let nearest = &mut layer[pid.0 as usize].nearest;
|
||||||
// neighbor than to the new node, in which case we skip insertion of our new node's ID.
|
if nearest[idx].is_valid() {
|
||||||
if idx >= nearest.len() {
|
let end = (M * 2) - 1;
|
||||||
continue;
|
nearest.copy_within(idx..end, idx + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
nearest[idx] = new;
|
||||||
|
layer[new.0 as usize].nearest[i] = pid;
|
||||||
}
|
}
|
||||||
|
|
||||||
let nearest = &mut layer[pid.0 as usize].nearest;
|
|
||||||
if nearest[idx].is_valid() {
|
|
||||||
let end = (M * 2) - 1;
|
|
||||||
nearest.copy_within(idx..end, idx + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
nearest[idx] = new;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
layer.push(node);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
@ -483,6 +516,7 @@ pub struct Search {
|
||||||
nearest: Vec<Candidate>,
|
nearest: Vec<Candidate>,
|
||||||
/// Working set for heuristic selection
|
/// Working set for heuristic selection
|
||||||
working: Vec<Candidate>,
|
working: Vec<Candidate>,
|
||||||
|
discarded: Vec<Candidate>,
|
||||||
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
||||||
ef: usize,
|
ef: usize,
|
||||||
}
|
}
|
||||||
|
@ -495,6 +529,7 @@ impl Search {
|
||||||
candidates,
|
candidates,
|
||||||
nearest,
|
nearest,
|
||||||
working,
|
working,
|
||||||
|
discarded,
|
||||||
ef: _,
|
ef: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
|
@ -502,6 +537,7 @@ impl Search {
|
||||||
candidates.clear();
|
candidates.clear();
|
||||||
nearest.clear();
|
nearest.clear();
|
||||||
working.clear();
|
working.clear();
|
||||||
|
discarded.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Selection of neighbors for insertion (algorithm 3 from the paper)
|
/// Selection of neighbors for insertion (algorithm 3 from the paper)
|
||||||
|
@ -521,24 +557,45 @@ impl Search {
|
||||||
self.working.clear();
|
self.working.clear();
|
||||||
// Get input candidates from `self.nearest` and store them in `self.working`.
|
// Get input candidates from `self.candidates` and store them in `self.working`.
|
||||||
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
|
// `self.working` will represent `W` from the paper's algorithm 4 for now.
|
||||||
for &candidate in &self.nearest {
|
while let Some(Reverse(candidate)) = self.candidates.pop() {
|
||||||
self.working.push(candidate);
|
self.working.push(candidate);
|
||||||
if params.extend_candidates {
|
for hop in layer.nearest_iter(candidate.pid) {
|
||||||
for pid in layer.nearest_iter(candidate.pid).take(num) {
|
if !self.visited.insert(hop) {
|
||||||
let other = &points[pid];
|
continue;
|
||||||
let distance = OrderedFloat::from(point.distance(other));
|
|
||||||
self.working.push(Candidate { distance, pid });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let other = &points[hop];
|
||||||
|
let distance = OrderedFloat::from(point.distance(other));
|
||||||
|
let new = Candidate { distance, pid: hop };
|
||||||
|
self.working.push(new);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.working.sort_unstable();
|
self.working.sort_unstable();
|
||||||
self.nearest.clear();
|
self.nearest.clear();
|
||||||
self.nearest.push(self.working[0]);
|
self.discarded.clear();
|
||||||
|
for candidate in self.working.drain(..) {
|
||||||
|
if self.nearest.len() >= num {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disadvantage candidates which are closer to an existing result point than they
|
||||||
|
// are to the query point, to facilitate bridging between clustered points.
|
||||||
|
let candidate_point = &points[candidate.pid];
|
||||||
|
let nearest = !self.nearest.iter().any(|result| {
|
||||||
|
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
|
||||||
|
distance < candidate.distance
|
||||||
|
});
|
||||||
|
|
||||||
|
match nearest {
|
||||||
|
true => self.nearest.push(candidate),
|
||||||
|
false => self.discarded.push(candidate),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if params.keep_pruned {
|
if params.keep_pruned {
|
||||||
// Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
|
// Add discarded connections from `self.discarded` (`Wd`) to `self.nearest` (`R`)
|
||||||
for candidate in self.working.drain(1..) {
|
for candidate in self.discarded.drain(..) {
|
||||||
if self.nearest.len() >= num {
|
if self.nearest.len() >= num {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -546,6 +603,7 @@ impl Search {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.nearest.sort_unstable();
|
||||||
&self.nearest[..min(self.nearest.len(), num)]
|
&self.nearest[..min(self.nearest.len(), num)]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -601,6 +659,7 @@ impl Default for Search {
|
||||||
candidates: BinaryHeap::new(),
|
candidates: BinaryHeap::new(),
|
||||||
nearest: Vec::new(),
|
nearest: Vec::new(),
|
||||||
working: Vec::new(),
|
working: Vec::new(),
|
||||||
|
discarded: Vec::new(),
|
||||||
ef: 1,
|
ef: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue