diff --git a/src/lib.rs b/src/lib.rs
index 9e19894..ee80f20 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -83,6 +83,15 @@ pub struct Heuristic {
pub keep_pruned: bool,
}
+impl Default for Heuristic {
+ fn default() -> Self {
+ Heuristic {
+ extend_candidates: false,
+ keep_pruned: true,
+ }
+ }
+}
+
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw
{
ef_search: usize,
@@ -204,6 +213,7 @@ where
let mut pool = SearchPool::default();
let mut batch = Vec::new();
let mut done = Vec::new();
+ let mut insertion = Search::default();
let max_batch_len = num_cpus::get() * 4;
for (layer, mut range) in ranges {
let num = if layer.0 > 0 { M } else { M * 2 };
@@ -245,7 +255,14 @@ where
search.push(added, &points[pid], &points);
}
- insert(pid, &mut search, &mut zero, &points, &builder.heuristic);
+ insert(
+ pid,
+ &mut insertion,
+ &mut search,
+ &mut zero,
+ &points,
+ &builder.heuristic,
+ );
done.push(pid);
pool.push(search);
}
@@ -334,23 +351,27 @@ where
/// Insert new node in the zero layer
///
-/// `new` contains the `PointId` for the new node; `search` contains the result for searching
-/// potential neighbors for the new node; `layer` contains all the nodes at the current layer;
-/// `points` is a slice of all the points in the index.
+/// * `new`: the `PointId` for the new node
+/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
+/// * `search`: the result of searching for potential neighbors of the new node
+/// * `layer`: all the nodes at the current layer
+/// * `points`: a slice of all the points in the index
///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer.
fn insert(
new: PointId,
+ insertion: &mut Search,
search: &mut Search,
layer: &mut Vec,
points: &[P],
heuristic: &Option,
) {
+ layer.push(ZeroNode::default());
let found = match heuristic {
None => search.select_simple(M * 2),
Some(heuristic) => {
- search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic)
+ search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic)
}
};
@@ -360,45 +381,57 @@ fn insert(
found.iter().map(|c| c.pid).collect::>().len()
);
- let mut node = ZeroNode::default();
for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
- node.nearest[i] = pid; // Update the new node's `nearest`
+ if let Some(heuristic) = heuristic {
+ insertion.reset();
+ let candidate_point = &points[pid];
+ insertion.push(new, candidate_point, points);
+ for hop in layer.nearest_iter(pid) {
+ insertion.push(hop, candidate_point, points);
+ }
- let old = &points[pid];
- let nearest = &layer[pid.0 as usize].nearest;
+ let found =
+ insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic);
+ for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) {
+ *slot = hop.pid;
+ }
+ layer[new.0 as usize].nearest[i] = pid;
+ } else {
+ // Find the correct index to insert at to keep the neighbor's neighbors sorted
+ let old = &points[pid];
+ let nearest = &layer[pid.0 as usize].nearest;
+ let idx = nearest
+ .binary_search_by(|third| {
+ // `third` here is one of the neighbors of the new node's neighbor.
+ let third = match third {
+ pid if pid.is_valid() => *pid,
+ // if `third` is not a valid `PointId`, our new node is always "closer"
+ _ => return Ordering::Greater,
+ };
- // Find the correct index to insert at to keep the neighbor's neighbors sorted
- let idx = nearest
- .binary_search_by(|third| {
- // `third` here is one of the neighbors of the new node's neighbor.
- let third = match third {
- pid if pid.is_valid() => *pid,
- // if `third` is `None`, our new `node` is always "closer"
- _ => return Ordering::Greater,
- };
+ distance.cmp(&old.distance(&points[third.0 as usize]).into())
+ })
+ .unwrap_or_else(|e| e);
- distance.cmp(&old.distance(&points[third.0 as usize]).into())
- })
- .unwrap_or_else(|e| e);
+ // It might be possible for all the neighbor's current neighbors to be closer to our
+ // neighbor than the new node is, in which case we skip insertion of our new node's ID.
+ if idx >= nearest.len() {
+ layer[new.0 as usize].nearest[i] = pid;
+ continue;
+ }
- // It might be possible for all the neighbor's current neighbors to be closer to our
- // neighbor than to the new node, in which case we skip insertion of our new node's ID.
- if idx >= nearest.len() {
- continue;
+ let nearest = &mut layer[pid.0 as usize].nearest;
+ if nearest[idx].is_valid() {
+ let end = (M * 2) - 1;
+ nearest.copy_within(idx..end, idx + 1);
+ }
+
+ nearest[idx] = new;
+ layer[new.0 as usize].nearest[i] = pid;
}
-
- let nearest = &mut layer[pid.0 as usize].nearest;
- if nearest[idx].is_valid() {
- let end = (M * 2) - 1;
- nearest.copy_within(idx..end, idx + 1);
- }
-
- nearest[idx] = new;
}
-
- layer.push(node);
}
#[derive(Default)]
@@ -483,6 +516,7 @@ pub struct Search {
nearest: Vec,
/// Working set for heuristic selection
working: Vec,
+ discarded: Vec,
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
ef: usize,
}
@@ -495,6 +529,7 @@ impl Search {
candidates,
nearest,
working,
+ discarded,
ef: _,
} = self;
@@ -502,6 +537,7 @@ impl Search {
candidates.clear();
nearest.clear();
working.clear();
+ discarded.clear();
}
/// Selection of neighbors for insertion (algorithm 3 from the paper)
@@ -521,24 +557,45 @@ impl Search {
self.working.clear();
// Get input candidates from `self.nearest` and store them in `self.working`.
// `self.candidates` will represent `W` from the paper's algorithm 4 for now.
- for &candidate in &self.nearest {
+ while let Some(Reverse(candidate)) = self.candidates.pop() {
self.working.push(candidate);
- if params.extend_candidates {
- for pid in layer.nearest_iter(candidate.pid).take(num) {
- let other = &points[pid];
- let distance = OrderedFloat::from(point.distance(other));
- self.working.push(Candidate { distance, pid });
+ for hop in layer.nearest_iter(candidate.pid) {
+ if !self.visited.insert(hop) {
+ continue;
}
+
+ let other = &points[hop];
+ let distance = OrderedFloat::from(point.distance(other));
+ let new = Candidate { distance, pid: hop };
+ self.working.push(new);
}
}
self.working.sort_unstable();
self.nearest.clear();
- self.nearest.push(self.working[0]);
+ self.discarded.clear();
+ for candidate in self.working.drain(..) {
+ if self.nearest.len() >= num {
+ break;
+ }
+
+ // Disadvantage candidates which are closer to an existing result point than they
+ // are to the query point, to facilitate bridging between clustered points.
+ let candidate_point = &points[candidate.pid];
+ let nearest = !self.nearest.iter().any(|result| {
+ let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
+ distance < candidate.distance
+ });
+
+ match nearest {
+ true => self.nearest.push(candidate),
+ false => self.discarded.push(candidate),
+ }
+ }
if params.keep_pruned {
// Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`)
- for candidate in self.working.drain(1..) {
+ for candidate in self.discarded.drain(..) {
if self.nearest.len() >= num {
break;
}
@@ -546,6 +603,7 @@ impl Search {
}
}
+ self.nearest.sort_unstable();
&self.nearest[..min(self.nearest.len(), num)]
}
@@ -601,6 +659,7 @@ impl Default for Search {
candidates: BinaryHeap::new(),
nearest: Vec::new(),
working: Vec::new(),
+ discarded: Vec::new(),
ef: 1,
}
}