diff --git a/src/lib.rs b/src/lib.rs index fe9b9fe..3f0afdd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,7 +139,7 @@ where for added in done.iter().copied() { search.push(added, &points[pid], &points); } - zero.insert_node(pid, &search.nearest, &points); + insert(&mut zero, pid, &search.nearest, &points); done.push(pid); pool.push(search); } @@ -226,6 +226,65 @@ where } } +/// Insert new node in the zero layer +/// +/// `new` contains the `PointId` for the new node; `found` is a slice containing all +/// `Candidate`s found during searching (ordered from near to far). +/// +/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors +/// for the new node's neighbors if necessary before appending the new node to the layer. +fn insert(layer: &mut Vec, new: PointId, found: &[Candidate], points: &[P]) { + let mut node = ZeroNode::default(); + + // Just make sure the candidates are all unique + debug_assert_eq!( + found.len(), + found.iter().map(|c| c.pid).collect::>().len() + ); + + // Only use the `Self::LINKS` nearest candidates found + for (i, candidate) in found.iter().take(M * 2).enumerate() { + // `candidate` here is the new node's neighbor + let &Candidate { distance, pid } = candidate; + node.nearest[i] = pid; // Update the new node's `nearest` + + let old = &points[pid]; + let nearest = &layer[pid.0 as usize].nearest; + + // Find the correct index to insert at to keep the neighbor's neighbors sorted + let idx = nearest + .binary_search_by(|third| { + // `third` here is one of the neighbors of the new node's neighbor. + let third = match third { + pid if pid.is_valid() => *pid, + // if `third` is `None`, our new `node` is always "closer" + _ => return Ordering::Greater, + }; + + distance.cmp(&old.distance(&points[third.0 as usize]).into()) + }) + .unwrap_or_else(|e| e); + + // It might be possible for all the neighbor's current neighbors to be closer to our + // neighbor than to the new node, in which case we skip insertion of our new node's ID. + if idx >= nearest.len() { + continue; + } + + let nearest = &mut layer[pid.0 as usize].nearest; + if !nearest[idx].is_valid() { + nearest[idx] = new; + continue; + } + + let end = (M * 2) - 1; + nearest.copy_within(idx..end, idx + 1); + nearest[idx] = new; + } + + layer.push(node); +} + #[derive(Default)] struct SearchPool { pool: Vec, @@ -371,46 +430,22 @@ impl Builder { } impl Layer for Vec { - const LINKS: usize = M * 2; - - type Node = ZeroNode; - - fn push(&mut self, new: ZeroNode) { - self.push(new); - } - - fn nearest_mut(&mut self, pid: PointId) -> &mut [PointId] { - &mut self[pid.0 as usize].nearest - } - - fn nearest(&self, pid: PointId) -> &[PointId] { - &self[pid.0 as usize].nearest + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } } } impl Layer for Vec { - const LINKS: usize = M; - - type Node = UpperNode; - - fn push(&mut self, new: UpperNode) { - self.push(new); - } - - fn nearest_mut(&mut self, pid: PointId) -> &mut [PointId] { - &mut self[pid.0 as usize].nearest - } - - fn nearest(&self, pid: PointId) -> &[PointId] { - &self[pid.0 as usize].nearest + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } } } trait Layer { - const LINKS: usize; - - type Node: Node; - /// Search this layer for nodes near the given `point` /// /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query @@ -439,77 +474,7 @@ trait Layer { search.nearest.truncate(search.ef); } - /// Insert new node in this layer - /// - /// `new` contains the `PointId` for the new node; `found` is a slice containing all - /// `Candidate`s found during searching (ordered from near to far). - /// - /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors - /// for the new node's neighbors if necessary. - fn insert_node(&mut self, new: PointId, found: &[Candidate], points: &[P]) { - let mut node = Self::Node::default(); - let new_nearest = node.nearest_mut(); - - // Just make sure the candidates are all unique - debug_assert_eq!( - found.len(), - found.iter().map(|c| c.pid).collect::>().len() - ); - - // Only use the `Self::LINKS` nearest candidates found - for (i, candidate) in found.iter().take(Self::LINKS).enumerate() { - // `candidate` here is the new node's neighbor - let &Candidate { distance, pid } = candidate; - new_nearest[i] = pid; // Update the new node's `nearest` - - let old = &points[pid]; - let nearest = self.nearest(pid); - - // Find the correct index to insert at to keep the neighbor's neighbors sorted - let idx = nearest - .binary_search_by(|third| { - // `third` here is one of the neighbors of the new node's neighbor. - let third = match third { - pid if pid.is_valid() => *pid, - // if `third` is `None`, our new `node` is always "closer" - _ => return Ordering::Greater, - }; - - distance.cmp(&old.distance(&points[third.0 as usize]).into()) - }) - .unwrap_or_else(|e| e); - - // It might be possible for all the neighbor's current neighbors to be closer to our - // neighbor than to the new node, in which case we skip insertion of our new node's ID. - if idx >= nearest.len() { - continue; - } - - let nearest = self.nearest_mut(pid); - if !nearest[idx].is_valid() { - nearest[idx] = new; - continue; - } - - let end = Self::LINKS - 1; - nearest.copy_within(idx..end, idx + 1); - nearest[idx] = new; - } - - self.push(node); - } - - fn push(&mut self, new: Self::Node); - - fn nearest_mut(&mut self, pid: PointId) -> &mut [PointId]; - - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: self.nearest(pid), - } - } - - fn nearest(&self, pid: PointId) -> &[PointId]; + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; } #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] @@ -660,6 +625,14 @@ impl Index for [P] { } } +impl Index for [ZeroNode] { + type Output = ZeroNode; + + fn index(&self, index: PointId) -> &Self::Output { + &self[index.0 as usize] + } +} + /// The parameter `M` from the paper /// /// This should become a generic argument to `Hnsw` when possible.