diff --git a/src/lib.rs b/src/lib.rs
index 1a45967..84b9dcc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -38,6 +38,7 @@ where
// construction. This allows us to copy higher layers to lower layers as construction
// progresses, while preserving randomness in each point's layer and insertion order.
+ assert!(points.len() < u32::MAX as usize);
let mut rng = SmallRng::from_entropy();
let mut nodes = (0..points.len())
.map(|i| (LayerId::random(&mut rng), i))
@@ -51,7 +52,7 @@ where
let mut new_nodes = Vec::with_capacity(points.len());
let mut out = vec![PointId::invalid(); points.len()];
for (i, &(layer, idx)) in nodes.iter().enumerate() {
- let pid = PointId(i);
+ let pid = PointId(i as u32);
new_points.push(points[idx].clone());
new_nodes.push((layer, pid));
out[idx] = pid;
@@ -136,8 +137,6 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value.
- ///
- /// `PointId` values can be initialized with `PointId::invalid()`.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
if self.points.is_empty() {
return 0;
@@ -249,7 +248,7 @@ impl
Index for Hnsw {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
- &self.points[index.0]
+ &self.points[index.0 as usize]
}
}
@@ -257,7 +256,7 @@ impl Index for Vec {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
- &self[index.0]
+ &self[index.0 as usize]
}
}
@@ -265,7 +264,7 @@ impl Index for [P] {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
- &self[index.0]
+ &self[index.0 as usize]
}
}
@@ -329,7 +328,7 @@ trait Layer {
}
}
- for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) {
+ for pid in self.nodes()[candidate.pid.0 as usize].nearest_iter().take(num) {
search.push(pid, point, points);
}
}
@@ -356,22 +355,22 @@ trait Layer {
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
- new_nearest[i] = Some(pid); // Update the new node's `nearest`
+ new_nearest[i] = pid; // Update the new node's `nearest`
let old = &points[pid];
- let nearest = self.nodes()[pid.0].nearest();
+ let nearest = self.nodes()[pid.0 as usize].nearest();
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let idx = nearest
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
- Some(nid) => *nid,
+ pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
- None => return Ordering::Greater,
+ _ => return Ordering::Greater,
};
- let third_distance = OrderedFloat::from(old.distance(&points[third.0]));
+ let third_distance = OrderedFloat::from(old.distance(&points[third.0 as usize]));
distance.cmp(&third_distance)
})
.unwrap_or_else(|e| e);
@@ -382,15 +381,15 @@ trait Layer {
continue;
}
- let nearest = self.nodes_mut()[pid.0].nearest_mut();
- if nearest[idx].is_none() {
- nearest[idx] = Some(new);
+ let nearest = self.nodes_mut()[pid.0 as usize].nearest_mut();
+ if !nearest[idx].is_valid() {
+ nearest[idx] = new;
continue;
}
let end = Self::LINKS - 1;
nearest.copy_within(idx..end, idx + 1);
- nearest[idx] = Some(new);
+ nearest[idx] = new;
}
self.push(node);
@@ -409,15 +408,15 @@ struct UpperNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
- nearest: [Option; M],
+ nearest: [PointId; M],
}
impl Node for UpperNode {
- fn nearest(&self) -> &[Option] {
+ fn nearest(&self) -> &[PointId] {
&self.nearest
}
- fn nearest_mut(&mut self) -> &mut [Option] {
+ fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest
}
@@ -434,15 +433,15 @@ struct ZeroNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
- nearest: [Option; M * 2],
+ nearest: [PointId; M * 2],
}
impl Node for ZeroNode {
- fn nearest(&self) -> &[Option] {
+ fn nearest(&self) -> &[PointId] {
&self.nearest
}
- fn nearest_mut(&mut self) -> &mut [Option] {
+ fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest
}
@@ -454,13 +453,13 @@ impl Node for ZeroNode {
}
trait Node: Default {
- fn nearest(&self) -> &[Option];
- fn nearest_mut(&mut self) -> &mut [Option];
+ fn nearest(&self) -> &[PointId];
+ fn nearest_mut(&mut self) -> &mut [PointId];
fn nearest_iter(&self) -> NearestIter<'_>;
}
struct NearestIter<'a> {
- nearest: &'a [Option],
+ nearest: &'a [PointId],
}
impl<'a> Iterator for NearestIter<'a> {
@@ -468,11 +467,11 @@ impl<'a> Iterator for NearestIter<'a> {
fn next(&mut self) -> Option {
let (&first, rest) = self.nearest.split_first()?;
- self.nearest = rest;
- if first.is_none() {
- self.nearest = &[];
+ if !first.is_valid() {
+ return None;
}
- first
+ self.nearest = rest;
+ Some(first)
}
}
@@ -530,11 +529,22 @@ struct Candidate {
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
-pub struct PointId(usize);
+pub struct PointId(u32);
impl PointId {
- pub fn invalid() -> Self {
- PointId(usize::MAX)
+ fn invalid() -> Self {
+ PointId(u32::MAX)
+ }
+
+ /// Whether this value represents a valid point
+ pub fn is_valid(self) -> bool {
+ self.0 != u32::MAX
+ }
+}
+
+impl Default for PointId {
+ fn default() -> Self {
+ PointId::invalid()
}
}