diff --git a/src/lib.rs b/src/lib.rs index 49e37bf..357ca07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; mod types; pub use types::PointId; -use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode}; +use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode, INVALID}; /// Parameters for building the `Hnsw` pub struct Builder { @@ -178,7 +178,7 @@ where let mut prev_layer = nodes[0].0; let mut new_points = Vec::with_capacity(points.len()); let mut new_nodes = Vec::with_capacity(points.len()); - let mut out = vec![PointId::invalid(); points.len()]; + let mut out = vec![INVALID; points.len()]; for (i, &(layer, idx)) in nodes.iter().enumerate() { if prev_layer != layer { cur_layer = LayerId(cur_layer.0 - 1); diff --git a/src/types.rs b/src/types.rs index bd46af1..0c3af9f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -63,7 +63,7 @@ pub(crate) struct UpperNode([PointId; M]); impl UpperNode { pub(crate) fn from_zero(node: &ZeroNode) -> Self { - let mut nearest = [PointId::invalid(); M]; + let mut nearest = [INVALID; M]; nearest.copy_from_slice(&node.0[..M]); Self(nearest) } @@ -86,8 +86,8 @@ impl ZeroNode { for slot in self.0.iter_mut() { if let Some(pid) = iter.next() { *slot = pid; - } else if *slot != PointId::invalid() { - *slot = PointId::invalid(); + } else if *slot != INVALID { + *slot = INVALID; } else { break; } @@ -116,7 +116,7 @@ impl ZeroNode { impl Default for ZeroNode { fn default() -> ZeroNode { - ZeroNode([PointId::invalid(); M * 2]) + ZeroNode([INVALID; M * 2]) } } @@ -210,10 +210,6 @@ pub(crate) struct Candidate { pub struct PointId(pub(crate) u32); impl PointId { - pub(crate) fn invalid() -> Self { - PointId(u32::MAX) - } - /// Whether this value represents a valid point pub fn is_valid(self) -> bool { self.0 != u32::MAX @@ -222,7 +218,7 @@ impl PointId { impl Default for PointId { fn default() -> Self { - PointId::invalid() + INVALID } } @@ -263,3 +259,5 @@ impl IndexMut for Vec { &mut self[index.0 as usize] } } + +pub(crate) const INVALID: PointId = PointId(u32::MAX);