diff --git a/src/lib.rs b/src/lib.rs index 8d9bd2f..2397095 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,20 @@ use std::cmp::{max, min, Ordering, Reverse}; use std::collections::BinaryHeap; -use std::hash::Hash; -use std::ops::{Index, IndexMut}; use ahash::AHashSet as HashSet; #[cfg(feature = "indicatif")] use indicatif::ProgressBar; use ordered_float::OrderedFloat; use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; +use rand::SeedableRng; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +mod types; +pub use types::PointId; +use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode}; + /// Parameters for building the `Hnsw` pub struct Builder { ef_search: Option, @@ -475,22 +477,6 @@ impl SearchPool { } } -impl Layer for [ZeroNode] { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].nearest, - } - } -} - -impl Layer for [UpperNode] { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].nearest, - } - } -} - trait Layer { /// Search this layer for nodes near the given `point` /// @@ -687,160 +673,10 @@ impl Default for Search { } } -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -#[derive(Clone, Copy, Debug, Default)] -struct UpperNode { - /// The nearest neighbors on this layer - /// - /// This is always kept in sorted order (near to far). - nearest: [PointId; M], -} - -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -#[derive(Clone, Copy, Debug)] -struct ZeroNode { - /// The nearest neighbors on this layer - /// - /// This is always kept in sorted order (near to far). - nearest: [PointId; M * 2], -} - -impl Default for ZeroNode { - fn default() -> ZeroNode { - ZeroNode { - nearest: [PointId::invalid(); M * 2], - } - } -} - -struct NearestIter<'a> { - nearest: &'a [PointId], -} - -impl<'a> Iterator for NearestIter<'a> { - type Item = PointId; - - fn next(&mut self) -> Option { - let (&first, rest) = self.nearest.split_first()?; - if !first.is_valid() { - return None; - } - self.nearest = rest; - Some(first) - } -} - -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -struct LayerId(usize); - -impl LayerId { - fn random(ml: f32, rng: &mut SmallRng) -> Self { - let layer = rng.gen::(); - LayerId((-(layer.ln() * ml)).floor() as usize) - } - - fn descend(&self) -> DescendingLayerIter { - DescendingLayerIter { next: Some(self.0) } - } - - fn is_zero(&self) -> bool { - self.0 == 0 - } -} - -struct DescendingLayerIter { - next: Option, -} - -impl Iterator for DescendingLayerIter { - type Item = LayerId; - - fn next(&mut self) -> Option { - Some(LayerId(match self.next? { - 0 => { - self.next = None; - 0 - } - next => { - self.next = Some(next - 1); - next - } - })) - } -} - pub trait Point: Clone + Sync { fn distance(&self, other: &Self) -> f32; } -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] -struct Candidate { - distance: OrderedFloat, - pid: PointId, -} - -/// References a `Point` in the `Hnsw` -/// -/// This can be used to index into the `Hnsw` to refer to the `Point` data. -#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct PointId(u32); - -impl PointId { - fn invalid() -> Self { - PointId(u32::MAX) - } - - /// Whether this value represents a valid point - pub fn is_valid(self) -> bool { - self.0 != u32::MAX - } -} - -impl Default for PointId { - fn default() -> Self { - PointId::invalid() - } -} - -impl

Index for Hnsw

{ - type Output = P; - - fn index(&self, index: PointId) -> &Self::Output { - &self.points[index.0 as usize] - } -} - -impl Index for Vec

{ - type Output = P; - - fn index(&self, index: PointId) -> &Self::Output { - &self[index.0 as usize] - } -} - -impl Index for [P] { - type Output = P; - - fn index(&self, index: PointId) -> &Self::Output { - &self[index.0 as usize] - } -} - -impl Index for Vec { - type Output = ZeroNode; - - fn index(&self, index: PointId) -> &Self::Output { - &self[index.0 as usize] - } -} - -impl IndexMut for Vec { - fn index_mut(&mut self, index: PointId) -> &mut Self::Output { - &mut self[index.0 as usize] - } -} - /// The parameter `M` from the paper /// /// This should become a generic argument to `Hnsw` when possible. diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..ab42581 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,176 @@ +use std::hash::Hash; +use std::ops::{Index, IndexMut}; + +use ordered_float::OrderedFloat; +use rand::rngs::SmallRng; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::{Hnsw, Layer, Point, M}; + +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +#[derive(Clone, Copy, Debug, Default)] +pub(crate) struct UpperNode { + /// The nearest neighbors on this layer + /// + /// This is always kept in sorted order (near to far). + pub(crate) nearest: [PointId; M], +} + +impl Layer for [UpperNode] { + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } + } +} + +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +#[derive(Clone, Copy, Debug)] +pub(crate) struct ZeroNode { + /// The nearest neighbors on this layer + /// + /// This is always kept in sorted order (near to far). + pub(crate) nearest: [PointId; M * 2], +} + +impl Default for ZeroNode { + fn default() -> ZeroNode { + ZeroNode { + nearest: [PointId::invalid(); M * 2], + } + } +} + +impl Layer for [ZeroNode] { + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } + } +} + +pub(crate) struct NearestIter<'a> { + nearest: &'a [PointId], +} + +impl<'a> Iterator for NearestIter<'a> { + type Item = PointId; + + fn next(&mut self) -> Option { + let (&first, rest) = self.nearest.split_first()?; + if !first.is_valid() { + return None; + } + self.nearest = rest; + Some(first) + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub(crate) struct LayerId(pub usize); + +impl LayerId { + pub(crate) fn random(ml: f32, rng: &mut SmallRng) -> Self { + let layer = rng.gen::(); + LayerId((-(layer.ln() * ml)).floor() as usize) + } + + pub(crate) fn descend(&self) -> impl Iterator { + DescendingLayerIter { next: Some(self.0) } + } + + pub(crate) fn is_zero(&self) -> bool { + self.0 == 0 + } +} + +struct DescendingLayerIter { + next: Option, +} + +impl Iterator for DescendingLayerIter { + type Item = LayerId; + + fn next(&mut self) -> Option { + Some(LayerId(match self.next? { + 0 => { + self.next = None; + 0 + } + next => { + self.next = Some(next - 1); + next + } + })) + } +} + +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub(crate) struct Candidate { + pub(crate) distance: OrderedFloat, + pub(crate) pid: PointId, +} + +/// References a `Point` in the `Hnsw` +/// +/// This can be used to index into the `Hnsw` to refer to the `Point` data. +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct PointId(pub(crate) u32); + +impl PointId { + pub(crate) fn invalid() -> Self { + PointId(u32::MAX) + } + + /// Whether this value represents a valid point + pub fn is_valid(self) -> bool { + self.0 != u32::MAX + } +} + +impl Default for PointId { + fn default() -> Self { + PointId::invalid() + } +} + +impl

Index for Hnsw

{ + type Output = P; + + fn index(&self, index: PointId) -> &Self::Output { + &self.points[index.0 as usize] + } +} + +impl Index for Vec

{ + type Output = P; + + fn index(&self, index: PointId) -> &Self::Output { + &self[index.0 as usize] + } +} + +impl Index for [P] { + type Output = P; + + fn index(&self, index: PointId) -> &Self::Output { + &self[index.0 as usize] + } +} + +impl Index for Vec { + type Output = ZeroNode; + + fn index(&self, index: PointId) -> &Self::Output { + &self[index.0 as usize] + } +} + +impl IndexMut for Vec { + fn index_mut(&mut self, index: PointId) -> &mut Self::Output { + &mut self[index.0 as usize] + } +}