diff --git a/src/lib.rs b/src/lib.rs index 9e19894..ee80f20 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,15 @@ pub struct Heuristic { pub keep_pruned: bool, } +impl Default for Heuristic { + fn default() -> Self { + Heuristic { + extend_candidates: false, + keep_pruned: true, + } + } +} + #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Hnsw

{ ef_search: usize, @@ -204,6 +213,7 @@ where let mut pool = SearchPool::default(); let mut batch = Vec::new(); let mut done = Vec::new(); + let mut insertion = Search::default(); let max_batch_len = num_cpus::get() * 4; for (layer, mut range) in ranges { let num = if layer.0 > 0 { M } else { M * 2 }; @@ -245,7 +255,14 @@ where search.push(added, &points[pid], &points); } - insert(pid, &mut search, &mut zero, &points, &builder.heuristic); + insert( + pid, + &mut insertion, + &mut search, + &mut zero, + &points, + &builder.heuristic, + ); done.push(pid); pool.push(search); } @@ -334,23 +351,27 @@ where /// Insert new node in the zero layer /// -/// `new` contains the `PointId` for the new node; `search` contains the result for searching -/// potential neighbors for the new node; `layer` contains all the nodes at the current layer; -/// `points` is a slice of all the points in the index. +/// * `new`: the `PointId` for the new node +/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection) +/// * `search`: the result for searching potential neighbors for the new node +/// * `layer` contains all the nodes at the current layer +/// * `points` is a slice of all the points in the index /// /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors /// for the new node's neighbors if necessary before appending the new node to the layer. fn insert( new: PointId, + insertion: &mut Search, search: &mut Search, layer: &mut Vec, points: &[P], heuristic: &Option, ) { + layer.push(ZeroNode::default()); let found = match heuristic { None => search.select_simple(M * 2), Some(heuristic) => { - search.select_heuristic(&layer, M * 2, &points[new], &points, *heuristic) + search.select_heuristic(&layer, M * 2, &points[new], points, *heuristic) } }; @@ -360,45 +381,57 @@ fn insert( found.iter().map(|c| c.pid).collect::>().len() ); - let mut node = ZeroNode::default(); for (i, candidate) in found.iter().enumerate() { // `candidate` here is the new node's neighbor let &Candidate { distance, pid } = candidate; - node.nearest[i] = pid; // Update the new node's `nearest` + if let Some(heuristic) = heuristic { + insertion.reset(); + let candidate_point = &points[pid]; + insertion.push(new, candidate_point, points); + for hop in layer.nearest_iter(pid) { + insertion.push(hop, candidate_point, points); + } - let old = &points[pid]; - let nearest = &layer[pid.0 as usize].nearest; + let found = + insertion.select_heuristic(&layer, M * 2, candidate_point, points, *heuristic); + for (slot, hop) in layer[pid.0 as usize].nearest.iter_mut().zip(found) { + *slot = hop.pid; + } + layer[new.0 as usize].nearest[i] = pid; + } else { + // Find the correct index to insert at to keep the neighbor's neighbors sorted + let old = &points[pid]; + let nearest = &layer[pid.0 as usize].nearest; + let idx = nearest + .binary_search_by(|third| { + // `third` here is one of the neighbors of the new node's neighbor. + let third = match third { + pid if pid.is_valid() => *pid, + // if `third` is `None`, our new `node` is always "closer" + _ => return Ordering::Greater, + }; - // Find the correct index to insert at to keep the neighbor's neighbors sorted - let idx = nearest - .binary_search_by(|third| { - // `third` here is one of the neighbors of the new node's neighbor. - let third = match third { - pid if pid.is_valid() => *pid, - // if `third` is `None`, our new `node` is always "closer" - _ => return Ordering::Greater, - }; + distance.cmp(&old.distance(&points[third.0 as usize]).into()) + }) + .unwrap_or_else(|e| e); - distance.cmp(&old.distance(&points[third.0 as usize]).into()) - }) - .unwrap_or_else(|e| e); + // It might be possible for all the neighbor's current neighbors to be closer to our + // neighbor than to the new node, in which case we skip insertion of our new node's ID. + if idx >= nearest.len() { + layer[new.0 as usize].nearest[i] = pid; + continue; + } - // It might be possible for all the neighbor's current neighbors to be closer to our - // neighbor than to the new node, in which case we skip insertion of our new node's ID. - if idx >= nearest.len() { - continue; + let nearest = &mut layer[pid.0 as usize].nearest; + if nearest[idx].is_valid() { + let end = (M * 2) - 1; + nearest.copy_within(idx..end, idx + 1); + } + + nearest[idx] = new; + layer[new.0 as usize].nearest[i] = pid; } - - let nearest = &mut layer[pid.0 as usize].nearest; - if nearest[idx].is_valid() { - let end = (M * 2) - 1; - nearest.copy_within(idx..end, idx + 1); - } - - nearest[idx] = new; } - - layer.push(node); } #[derive(Default)] @@ -483,6 +516,7 @@ pub struct Search { nearest: Vec, /// Working set for heuristic selection working: Vec, + discarded: Vec, /// Maximum number of nearest neighbors to retain (`ef` in the paper) ef: usize, } @@ -495,6 +529,7 @@ impl Search { candidates, nearest, working, + discarded, ef: _, } = self; @@ -502,6 +537,7 @@ impl Search { candidates.clear(); nearest.clear(); working.clear(); + discarded.clear(); } /// Selection of neighbors for insertion (algorithm 3 from the paper) @@ -521,24 +557,45 @@ impl Search { self.working.clear(); // Get input candidates from `self.nearest` and store them in `self.working`. // `self.candidates` will represent `W` from the paper's algorithm 4 for now. - for &candidate in &self.nearest { + while let Some(Reverse(candidate)) = self.candidates.pop() { self.working.push(candidate); - if params.extend_candidates { - for pid in layer.nearest_iter(candidate.pid).take(num) { - let other = &points[pid]; - let distance = OrderedFloat::from(point.distance(other)); - self.working.push(Candidate { distance, pid }); + for hop in layer.nearest_iter(candidate.pid) { + if !self.visited.insert(hop) { + continue; } + + let other = &points[hop]; + let distance = OrderedFloat::from(point.distance(other)); + let new = Candidate { distance, pid: hop }; + self.working.push(new); } } self.working.sort_unstable(); self.nearest.clear(); - self.nearest.push(self.working[0]); + self.discarded.clear(); + for candidate in self.working.drain(..) { + if self.nearest.len() >= num { + break; + } + + // Disadvantage candidates which are closer to an existing result point than they + // are to the query point, to facilitate bridging between clustered points. + let candidate_point = &points[candidate.pid]; + let nearest = !self.nearest.iter().any(|result| { + let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid])); + distance < candidate.distance + }); + + match nearest { + true => self.nearest.push(candidate), + false => self.discarded.push(candidate), + } + } if params.keep_pruned { // Add discarded connections from `working` (`Wd`) to `self.nearest` (`R`) - for candidate in self.working.drain(1..) { + for candidate in self.discarded.drain(..) { if self.nearest.len() >= num { break; } @@ -546,6 +603,7 @@ impl Search { } } + self.nearest.sort_unstable(); &self.nearest[..min(self.nearest.len(), num)] } @@ -601,6 +659,7 @@ impl Default for Search { candidates: BinaryHeap::new(), nearest: Vec::new(), working: Vec::new(), + discarded: Vec::new(), ef: 1, } }