Move support types into a separate module
This commit is contained in:
parent
6250d198ad
commit
7b84aa8d45
174
src/lib.rs
174
src/lib.rs
|
@ -1,18 +1,20 @@
|
|||
use std::cmp::{max, min, Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::hash::Hash;
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
use ahash::AHashSet as HashSet;
|
||||
#[cfg(feature = "indicatif")]
|
||||
use indicatif::ProgressBar;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::SeedableRng;
|
||||
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod types;
|
||||
pub use types::PointId;
|
||||
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode};
|
||||
|
||||
/// Parameters for building the `Hnsw`
|
||||
pub struct Builder {
|
||||
ef_search: Option<usize>,
|
||||
|
@ -475,22 +477,6 @@ impl SearchPool {
|
|||
}
|
||||
}
|
||||
|
||||
impl Layer for [ZeroNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for [UpperNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Layer {
|
||||
/// Search this layer for nodes near the given `point`
|
||||
///
|
||||
|
@ -687,160 +673,10 @@ impl Default for Search {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
struct UpperNode {
|
||||
/// The nearest neighbors on this layer
|
||||
///
|
||||
/// This is always kept in sorted order (near to far).
|
||||
nearest: [PointId; M],
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct ZeroNode {
|
||||
/// The nearest neighbors on this layer
|
||||
///
|
||||
/// This is always kept in sorted order (near to far).
|
||||
nearest: [PointId; M * 2],
|
||||
}
|
||||
|
||||
impl Default for ZeroNode {
|
||||
fn default() -> ZeroNode {
|
||||
ZeroNode {
|
||||
nearest: [PointId::invalid(); M * 2],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NearestIter<'a> {
|
||||
nearest: &'a [PointId],
|
||||
}
|
||||
|
||||
impl<'a> Iterator for NearestIter<'a> {
|
||||
type Item = PointId;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (&first, rest) = self.nearest.split_first()?;
|
||||
if !first.is_valid() {
|
||||
return None;
|
||||
}
|
||||
self.nearest = rest;
|
||||
Some(first)
|
||||
}
|
||||
}
|
||||
|
||||
/// Index of a layer, wrapped in a newtype.
///
/// Ordering and equality are derived from the wrapped index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct LayerId(usize);
|
||||
|
||||
impl LayerId {
|
||||
fn random(ml: f32, rng: &mut SmallRng) -> Self {
|
||||
let layer = rng.gen::<f32>();
|
||||
LayerId((-(layer.ln() * ml)).floor() as usize)
|
||||
}
|
||||
|
||||
fn descend(&self) -> DescendingLayerIter {
|
||||
DescendingLayerIter { next: Some(self.0) }
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// State for counting layer indices downwards.
struct DescendingLayerIter {
    /// Index to yield on the following call, or `None` once finished.
    next: Option<usize>,
}
|
||||
|
||||
impl Iterator for DescendingLayerIter {
|
||||
type Item = LayerId;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
Some(LayerId(match self.next? {
|
||||
0 => {
|
||||
self.next = None;
|
||||
0
|
||||
}
|
||||
next => {
|
||||
self.next = Some(next - 1);
|
||||
next
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// A value that can report its distance to another value of the same type.
///
/// Implementors must also be `Clone` and `Sync`.
pub trait Point: Clone + Sync {
    /// Returns the distance between `self` and `other`.
    fn distance(&self, other: &Self) -> f32;
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||
struct Candidate {
|
||||
distance: OrderedFloat<f32>,
|
||||
pid: PointId,
|
||||
}
|
||||
|
||||
/// Handle that refers to one `Point` stored in the `Hnsw`
///
/// Index the `Hnsw` with a `PointId` to get at the `Point` data behind it.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PointId(u32);
|
||||
|
||||
impl PointId {
|
||||
fn invalid() -> Self {
|
||||
PointId(u32::MAX)
|
||||
}
|
||||
|
||||
/// Whether this value represents a valid point
|
||||
pub fn is_valid(self) -> bool {
|
||||
self.0 != u32::MAX
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PointId {
|
||||
fn default() -> Self {
|
||||
PointId::invalid()
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Index<PointId> for Hnsw<P> {
|
||||
type Output = P;
|
||||
|
||||
fn index(&self, index: PointId) -> &Self::Output {
|
||||
&self.points[index.0 as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Index<PointId> for Vec<P>
where
    P: Point,
{
    type Output = P;

    /// Looks up the element at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &P {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl<P> Index<PointId> for [P]
where
    P: Point,
{
    type Output = P;

    /// Looks up the element at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &P {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl Index<PointId> for Vec<ZeroNode> {
    type Output = ZeroNode;

    /// Looks up the node at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &ZeroNode {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl IndexMut<PointId> for Vec<ZeroNode> {
    /// Mutable lookup of the node at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index_mut(&mut self, index: PointId) -> &mut ZeroNode {
        let slot = index.0 as usize;
        &mut self[slot]
    }
}
|
||||
|
||||
/// The parameter `M` from the paper
|
||||
///
|
||||
/// This should become a generic argument to `Hnsw` when possible.
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
use std::hash::Hash;
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
use ordered_float::OrderedFloat;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Hnsw, Layer, Point, M};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) struct UpperNode {
|
||||
/// The nearest neighbors on this layer
|
||||
///
|
||||
/// This is always kept in sorted order (near to far).
|
||||
pub(crate) nearest: [PointId; M],
|
||||
}
|
||||
|
||||
impl Layer for [UpperNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) struct ZeroNode {
|
||||
/// The nearest neighbors on this layer
|
||||
///
|
||||
/// This is always kept in sorted order (near to far).
|
||||
pub(crate) nearest: [PointId; M * 2],
|
||||
}
|
||||
|
||||
impl Default for ZeroNode {
|
||||
fn default() -> ZeroNode {
|
||||
ZeroNode {
|
||||
nearest: [PointId::invalid(); M * 2],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for [ZeroNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct NearestIter<'a> {
|
||||
nearest: &'a [PointId],
|
||||
}
|
||||
|
||||
impl<'a> Iterator for NearestIter<'a> {
|
||||
type Item = PointId;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (&first, rest) = self.nearest.split_first()?;
|
||||
if !first.is_valid() {
|
||||
return None;
|
||||
}
|
||||
self.nearest = rest;
|
||||
Some(first)
|
||||
}
|
||||
}
|
||||
|
||||
/// Index of a layer, wrapped in a newtype.
///
/// Ordering and equality are derived from the wrapped index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct LayerId(pub usize);
|
||||
|
||||
impl LayerId {
|
||||
pub(crate) fn random(ml: f32, rng: &mut SmallRng) -> Self {
|
||||
let layer = rng.gen::<f32>();
|
||||
LayerId((-(layer.ln() * ml)).floor() as usize)
|
||||
}
|
||||
|
||||
pub(crate) fn descend(&self) -> impl Iterator<Item = LayerId> {
|
||||
DescendingLayerIter { next: Some(self.0) }
|
||||
}
|
||||
|
||||
pub(crate) fn is_zero(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// State for counting layer indices downwards.
struct DescendingLayerIter {
    /// Index to yield on the following call, or `None` once finished.
    next: Option<usize>,
}
|
||||
|
||||
impl Iterator for DescendingLayerIter {
|
||||
type Item = LayerId;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
Some(LayerId(match self.next? {
|
||||
0 => {
|
||||
self.next = None;
|
||||
0
|
||||
}
|
||||
next => {
|
||||
self.next = Some(next - 1);
|
||||
next
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub(crate) struct Candidate {
|
||||
pub(crate) distance: OrderedFloat<f32>,
|
||||
pub(crate) pid: PointId,
|
||||
}
|
||||
|
||||
/// Handle that refers to one `Point` stored in the `Hnsw`
///
/// Index the `Hnsw` with a `PointId` to get at the `Point` data behind it.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PointId(pub(crate) u32);
|
||||
|
||||
impl PointId {
|
||||
pub(crate) fn invalid() -> Self {
|
||||
PointId(u32::MAX)
|
||||
}
|
||||
|
||||
/// Whether this value represents a valid point
|
||||
pub fn is_valid(self) -> bool {
|
||||
self.0 != u32::MAX
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PointId {
|
||||
fn default() -> Self {
|
||||
PointId::invalid()
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Index<PointId> for Hnsw<P> {
|
||||
type Output = P;
|
||||
|
||||
fn index(&self, index: PointId) -> &Self::Output {
|
||||
&self.points[index.0 as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Index<PointId> for Vec<P>
where
    P: Point,
{
    type Output = P;

    /// Looks up the element at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &P {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl<P> Index<PointId> for [P]
where
    P: Point,
{
    type Output = P;

    /// Looks up the element at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &P {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl Index<PointId> for Vec<ZeroNode> {
    type Output = ZeroNode;

    /// Looks up the node at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index(&self, index: PointId) -> &ZeroNode {
        let slot = index.0 as usize;
        &self[slot]
    }
}
|
||||
|
||||
impl IndexMut<PointId> for Vec<ZeroNode> {
    /// Mutable lookup of the node at the position encoded by `index`.
    ///
    /// Panics if that position is out of bounds.
    fn index_mut(&mut self, index: PointId) -> &mut ZeroNode {
        let slot = index.0 as usize;
        &mut self[slot]
    }
}
|
Loading…
Reference in New Issue