Use a constant for invalid PointId

This commit is contained in:
Dirkjan Ochtman 2021-01-21 10:40:36 +01:00
parent 96b69e5d4b
commit f388fd0a46
2 changed files with 9 additions and 11 deletions

View File

@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
mod types;
pub use types::PointId;
use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode};
use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode, INVALID};
/// Parameters for building the `Hnsw`
pub struct Builder {
@ -178,7 +178,7 @@ where
let mut prev_layer = nodes[0].0;
let mut new_points = Vec::with_capacity(points.len());
let mut new_nodes = Vec::with_capacity(points.len());
let mut out = vec![PointId::invalid(); points.len()];
let mut out = vec![INVALID; points.len()];
for (i, &(layer, idx)) in nodes.iter().enumerate() {
if prev_layer != layer {
cur_layer = LayerId(cur_layer.0 - 1);

View File

@ -63,7 +63,7 @@ pub(crate) struct UpperNode([PointId; M]);
impl UpperNode {
pub(crate) fn from_zero(node: &ZeroNode) -> Self {
let mut nearest = [PointId::invalid(); M];
let mut nearest = [INVALID; M];
nearest.copy_from_slice(&node.0[..M]);
Self(nearest)
}
@ -86,8 +86,8 @@ impl ZeroNode {
for slot in self.0.iter_mut() {
if let Some(pid) = iter.next() {
*slot = pid;
} else if *slot != PointId::invalid() {
*slot = PointId::invalid();
} else if *slot != INVALID {
*slot = INVALID;
} else {
break;
}
@ -116,7 +116,7 @@ impl ZeroNode {
impl Default for ZeroNode {
fn default() -> ZeroNode {
ZeroNode([PointId::invalid(); M * 2])
ZeroNode([INVALID; M * 2])
}
}
@ -210,10 +210,6 @@ pub(crate) struct Candidate {
pub struct PointId(pub(crate) u32);
impl PointId {
pub(crate) fn invalid() -> Self {
PointId(u32::MAX)
}
/// Whether this value represents a valid point
pub fn is_valid(self) -> bool {
self.0 != u32::MAX
@ -222,7 +218,7 @@ impl PointId {
impl Default for PointId {
fn default() -> Self {
PointId::invalid()
INVALID
}
}
@ -263,3 +259,5 @@ impl IndexMut<PointId> for Vec<ZeroNode> {
&mut self[index.0 as usize]
}
}
pub(crate) const INVALID: PointId = PointId(u32::MAX);