Move support types into a separate module
This commit is contained in:
parent
6250d198ad
commit
7b84aa8d45
174
src/lib.rs
174
src/lib.rs
|
@ -1,18 +1,20 @@
|
||||||
use std::cmp::{max, min, Ordering, Reverse};
|
use std::cmp::{max, min, Ordering, Reverse};
|
||||||
use std::collections::BinaryHeap;
|
use std::collections::BinaryHeap;
|
||||||
use std::hash::Hash;
|
|
||||||
use std::ops::{Index, IndexMut};
|
|
||||||
|
|
||||||
use ahash::AHashSet as HashSet;
|
use ahash::AHashSet as HashSet;
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
use indicatif::ProgressBar;
|
use indicatif::ProgressBar;
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::{Rng, SeedableRng};
|
use rand::SeedableRng;
|
||||||
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
mod types;
|
||||||
|
pub use types::PointId;
|
||||||
|
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode};
|
||||||
|
|
||||||
/// Parameters for building the `Hnsw`
|
/// Parameters for building the `Hnsw`
|
||||||
pub struct Builder {
|
pub struct Builder {
|
||||||
ef_search: Option<usize>,
|
ef_search: Option<usize>,
|
||||||
|
@ -475,22 +477,6 @@ impl SearchPool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for [ZeroNode] {
|
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
|
||||||
NearestIter {
|
|
||||||
nearest: &self[pid.0 as usize].nearest,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Layer for [UpperNode] {
|
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
|
||||||
NearestIter {
|
|
||||||
nearest: &self[pid.0 as usize].nearest,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
trait Layer {
|
trait Layer {
|
||||||
/// Search this layer for nodes near the given `point`
|
/// Search this layer for nodes near the given `point`
|
||||||
///
|
///
|
||||||
|
@ -687,160 +673,10 @@ impl Default for Search {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
|
||||||
#[derive(Clone, Copy, Debug, Default)]
|
|
||||||
struct UpperNode {
|
|
||||||
/// The nearest neighbors on this layer
|
|
||||||
///
|
|
||||||
/// This is always kept in sorted order (near to far).
|
|
||||||
nearest: [PointId; M],
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
|
||||||
struct ZeroNode {
|
|
||||||
/// The nearest neighbors on this layer
|
|
||||||
///
|
|
||||||
/// This is always kept in sorted order (near to far).
|
|
||||||
nearest: [PointId; M * 2],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for ZeroNode {
|
|
||||||
fn default() -> ZeroNode {
|
|
||||||
ZeroNode {
|
|
||||||
nearest: [PointId::invalid(); M * 2],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct NearestIter<'a> {
|
|
||||||
nearest: &'a [PointId],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Iterator for NearestIter<'a> {
|
|
||||||
type Item = PointId;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
let (&first, rest) = self.nearest.split_first()?;
|
|
||||||
if !first.is_valid() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
self.nearest = rest;
|
|
||||||
Some(first)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
|
||||||
struct LayerId(usize);
|
|
||||||
|
|
||||||
impl LayerId {
|
|
||||||
fn random(ml: f32, rng: &mut SmallRng) -> Self {
|
|
||||||
let layer = rng.gen::<f32>();
|
|
||||||
LayerId((-(layer.ln() * ml)).floor() as usize)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn descend(&self) -> DescendingLayerIter {
|
|
||||||
DescendingLayerIter { next: Some(self.0) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_zero(&self) -> bool {
|
|
||||||
self.0 == 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct DescendingLayerIter {
|
|
||||||
next: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Iterator for DescendingLayerIter {
|
|
||||||
type Item = LayerId;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
Some(LayerId(match self.next? {
|
|
||||||
0 => {
|
|
||||||
self.next = None;
|
|
||||||
0
|
|
||||||
}
|
|
||||||
next => {
|
|
||||||
self.next = Some(next - 1);
|
|
||||||
next
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Point: Clone + Sync {
|
pub trait Point: Clone + Sync {
|
||||||
fn distance(&self, other: &Self) -> f32;
|
fn distance(&self, other: &Self) -> f32;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
|
||||||
struct Candidate {
|
|
||||||
distance: OrderedFloat<f32>,
|
|
||||||
pid: PointId,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// References a `Point` in the `Hnsw`
|
|
||||||
///
|
|
||||||
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
|
||||||
pub struct PointId(u32);
|
|
||||||
|
|
||||||
impl PointId {
|
|
||||||
fn invalid() -> Self {
|
|
||||||
PointId(u32::MAX)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether this value represents a valid point
|
|
||||||
pub fn is_valid(self) -> bool {
|
|
||||||
self.0 != u32::MAX
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for PointId {
|
|
||||||
fn default() -> Self {
|
|
||||||
PointId::invalid()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<P> Index<PointId> for Hnsw<P> {
|
|
||||||
type Output = P;
|
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
|
||||||
&self.points[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<P: Point> Index<PointId> for Vec<P> {
|
|
||||||
type Output = P;
|
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
|
||||||
&self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<P: Point> Index<PointId> for [P] {
|
|
||||||
type Output = P;
|
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
|
||||||
&self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Index<PointId> for Vec<ZeroNode> {
|
|
||||||
type Output = ZeroNode;
|
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
|
||||||
&self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexMut<PointId> for Vec<ZeroNode> {
|
|
||||||
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
|
|
||||||
&mut self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The parameter `M` from the paper
|
/// The parameter `M` from the paper
|
||||||
///
|
///
|
||||||
/// This should become a generic argument to `Hnsw` when possible.
|
/// This should become a generic argument to `Hnsw` when possible.
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::ops::{Index, IndexMut};
|
||||||
|
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
|
use rand::rngs::SmallRng;
|
||||||
|
use rand::Rng;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::{Hnsw, Layer, Point, M};
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
|
pub(crate) struct UpperNode {
|
||||||
|
/// The nearest neighbors on this layer
|
||||||
|
///
|
||||||
|
/// This is always kept in sorted order (near to far).
|
||||||
|
pub(crate) nearest: [PointId; M],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Layer for [UpperNode] {
|
||||||
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||||
|
NearestIter {
|
||||||
|
nearest: &self[pid.0 as usize].nearest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub(crate) struct ZeroNode {
|
||||||
|
/// The nearest neighbors on this layer
|
||||||
|
///
|
||||||
|
/// This is always kept in sorted order (near to far).
|
||||||
|
pub(crate) nearest: [PointId; M * 2],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ZeroNode {
|
||||||
|
fn default() -> ZeroNode {
|
||||||
|
ZeroNode {
|
||||||
|
nearest: [PointId::invalid(); M * 2],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Layer for [ZeroNode] {
|
||||||
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||||
|
NearestIter {
|
||||||
|
nearest: &self[pid.0 as usize].nearest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct NearestIter<'a> {
|
||||||
|
nearest: &'a [PointId],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for NearestIter<'a> {
|
||||||
|
type Item = PointId;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let (&first, rest) = self.nearest.split_first()?;
|
||||||
|
if !first.is_valid() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
self.nearest = rest;
|
||||||
|
Some(first)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub(crate) struct LayerId(pub usize);
|
||||||
|
|
||||||
|
impl LayerId {
|
||||||
|
pub(crate) fn random(ml: f32, rng: &mut SmallRng) -> Self {
|
||||||
|
let layer = rng.gen::<f32>();
|
||||||
|
LayerId((-(layer.ln() * ml)).floor() as usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn descend(&self) -> impl Iterator<Item = LayerId> {
|
||||||
|
DescendingLayerIter { next: Some(self.0) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_zero(&self) -> bool {
|
||||||
|
self.0 == 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DescendingLayerIter {
|
||||||
|
next: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for DescendingLayerIter {
|
||||||
|
type Item = LayerId;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
Some(LayerId(match self.next? {
|
||||||
|
0 => {
|
||||||
|
self.next = None;
|
||||||
|
0
|
||||||
|
}
|
||||||
|
next => {
|
||||||
|
self.next = Some(next - 1);
|
||||||
|
next
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub(crate) struct Candidate {
|
||||||
|
pub(crate) distance: OrderedFloat<f32>,
|
||||||
|
pub(crate) pid: PointId,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// References a `Point` in the `Hnsw`
|
||||||
|
///
|
||||||
|
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
|
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub struct PointId(pub(crate) u32);
|
||||||
|
|
||||||
|
impl PointId {
|
||||||
|
pub(crate) fn invalid() -> Self {
|
||||||
|
PointId(u32::MAX)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this value represents a valid point
|
||||||
|
pub fn is_valid(self) -> bool {
|
||||||
|
self.0 != u32::MAX
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PointId {
|
||||||
|
fn default() -> Self {
|
||||||
|
PointId::invalid()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P> Index<PointId> for Hnsw<P> {
|
||||||
|
type Output = P;
|
||||||
|
|
||||||
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
|
&self.points[index.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: Point> Index<PointId> for Vec<P> {
|
||||||
|
type Output = P;
|
||||||
|
|
||||||
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
|
&self[index.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: Point> Index<PointId> for [P] {
|
||||||
|
type Output = P;
|
||||||
|
|
||||||
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
|
&self[index.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Index<PointId> for Vec<ZeroNode> {
|
||||||
|
type Output = ZeroNode;
|
||||||
|
|
||||||
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
|
&self[index.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IndexMut<PointId> for Vec<ZeroNode> {
|
||||||
|
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
|
||||||
|
&mut self[index.0 as usize]
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue