diff --git a/src/lib.rs b/src/lib.rs index 1a45967..84b9dcc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,7 @@ where // construction. This allows us to copy higher layers to lower layers as construction // progresses, while preserving randomness in each point's layer and insertion order. + assert!(points.len() < u32::MAX as usize); let mut rng = SmallRng::from_entropy(); let mut nodes = (0..points.len()) .map(|i| (LayerId::random(&mut rng), i)) @@ -51,7 +52,7 @@ where let mut new_nodes = Vec::with_capacity(points.len()); let mut out = vec![PointId::invalid(); points.len()]; for (i, &(layer, idx)) in nodes.iter().enumerate() { - let pid = PointId(i); + let pid = PointId(i as u32); new_points.push(points[idx].clone()); new_nodes.push((layer, pid)); out[idx] = pid; @@ -136,8 +137,6 @@ where /// The results are returned in the `out` parameter; the number of neighbors to search for /// is limited by the size of the `out` parameter, and the number of results found is returned /// in the return value. - /// - /// `PointId` values can be initialized with `PointId::invalid()`. pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize { if self.points.is_empty() { return 0; @@ -249,7 +248,7 @@ impl

Index for Hnsw

{ type Output = P; fn index(&self, index: PointId) -> &Self::Output { - &self.points[index.0] + &self.points[index.0 as usize] } } @@ -257,7 +256,7 @@ impl Index for Vec

{ type Output = P; fn index(&self, index: PointId) -> &Self::Output { - &self[index.0] + &self[index.0 as usize] } } @@ -265,7 +264,7 @@ impl Index for [P] { type Output = P; fn index(&self, index: PointId) -> &Self::Output { - &self[index.0] + &self[index.0 as usize] } } @@ -329,7 +328,7 @@ trait Layer { } } - for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) { + for pid in self.nodes()[candidate.pid.0 as usize].nearest_iter().take(num) { search.push(pid, point, points); } } @@ -356,22 +355,22 @@ trait Layer { for (i, candidate) in found.iter().take(Self::LINKS).enumerate() { // `candidate` here is the new node's neighbor let &Candidate { distance, pid } = candidate; - new_nearest[i] = Some(pid); // Update the new node's `nearest` + new_nearest[i] = pid; // Update the new node's `nearest` let old = &points[pid]; - let nearest = self.nodes()[pid.0].nearest(); + let nearest = self.nodes()[pid.0 as usize].nearest(); // Find the correct index to insert at to keep the neighbor's neighbors sorted let idx = nearest .binary_search_by(|third| { // `third` here is one of the neighbors of the new node's neighbor. let third = match third { - Some(nid) => *nid, + pid if pid.is_valid() => *pid, // if `third` is `None`, our new `node` is always "closer" - None => return Ordering::Greater, + _ => return Ordering::Greater, }; - let third_distance = OrderedFloat::from(old.distance(&points[third.0])); + let third_distance = OrderedFloat::from(old.distance(&points[third.0 as usize])); distance.cmp(&third_distance) }) .unwrap_or_else(|e| e); @@ -382,15 +381,15 @@ trait Layer { continue; } - let nearest = self.nodes_mut()[pid.0].nearest_mut(); - if nearest[idx].is_none() { - nearest[idx] = Some(new); + let nearest = self.nodes_mut()[pid.0 as usize].nearest_mut(); + if !nearest[idx].is_valid() { + nearest[idx] = new; continue; } let end = Self::LINKS - 1; nearest.copy_within(idx..end, idx + 1); - nearest[idx] = Some(new); + nearest[idx] = new; } self.push(node); @@ -409,15 +408,15 @@ struct UpperNode { /// The nearest neighbors on this layer /// /// This is always kept in sorted order (near to far). - nearest: [Option; M], + nearest: [PointId; M], } impl Node for UpperNode { - fn nearest(&self) -> &[Option] { + fn nearest(&self) -> &[PointId] { &self.nearest } - fn nearest_mut(&mut self) -> &mut [Option] { + fn nearest_mut(&mut self) -> &mut [PointId] { &mut self.nearest } @@ -434,15 +433,15 @@ struct ZeroNode { /// The nearest neighbors on this layer /// /// This is always kept in sorted order (near to far). - nearest: [Option; M * 2], + nearest: [PointId; M * 2], } impl Node for ZeroNode { - fn nearest(&self) -> &[Option] { + fn nearest(&self) -> &[PointId] { &self.nearest } - fn nearest_mut(&mut self) -> &mut [Option] { + fn nearest_mut(&mut self) -> &mut [PointId] { &mut self.nearest } @@ -454,13 +453,13 @@ impl Node for ZeroNode { } trait Node: Default { - fn nearest(&self) -> &[Option]; - fn nearest_mut(&mut self) -> &mut [Option]; + fn nearest(&self) -> &[PointId]; + fn nearest_mut(&mut self) -> &mut [PointId]; fn nearest_iter(&self) -> NearestIter<'_>; } struct NearestIter<'a> { - nearest: &'a [Option], + nearest: &'a [PointId], } impl<'a> Iterator for NearestIter<'a> { @@ -468,11 +467,11 @@ impl<'a> Iterator for NearestIter<'a> { fn next(&mut self) -> Option { let (&first, rest) = self.nearest.split_first()?; - self.nearest = rest; - if first.is_none() { - self.nearest = &[]; + if !first.is_valid() { + return None; } - first + self.nearest = rest; + Some(first) } } @@ -530,11 +529,22 @@ struct Candidate { /// This can be used to index into the `Hnsw` to refer to the `Point` data. #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct PointId(usize); +pub struct PointId(u32); impl PointId { - pub fn invalid() -> Self { - PointId(usize::MAX) + fn invalid() -> Self { + PointId(u32::MAX) + } + + /// Whether this value represents a valid point + pub fn is_valid(self) -> bool { + self.0 != u32::MAX + } +} + +impl Default for PointId { + fn default() -> Self { + PointId::invalid() } }