Move support types into a separate module

This commit is contained in:
Dirkjan Ochtman 2021-01-20 13:43:53 +01:00
parent 6250d198ad
commit 7b84aa8d45
2 changed files with 181 additions and 169 deletions

View File

@ -1,18 +1,20 @@
use std::cmp::{max, min, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::hash::Hash;
use std::ops::{Index, IndexMut};
use ahash::AHashSet as HashSet;
#[cfg(feature = "indicatif")]
use indicatif::ProgressBar;
use ordered_float::OrderedFloat;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand::SeedableRng;
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod types;
pub use types::PointId;
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode};
/// Parameters for building the `Hnsw`
pub struct Builder {
ef_search: Option<usize>,
@ -475,22 +477,6 @@ impl SearchPool {
}
}
impl Layer for [ZeroNode] {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
impl Layer for [UpperNode] {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
trait Layer {
/// Search this layer for nodes near the given `point`
///
@ -687,160 +673,10 @@ impl Default for Search {
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Default)]
struct UpperNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
nearest: [PointId; M],
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug)]
struct ZeroNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
nearest: [PointId; M * 2],
}
impl Default for ZeroNode {
fn default() -> ZeroNode {
ZeroNode {
nearest: [PointId::invalid(); M * 2],
}
}
}
struct NearestIter<'a> {
nearest: &'a [PointId],
}
impl<'a> Iterator for NearestIter<'a> {
type Item = PointId;
fn next(&mut self) -> Option<Self::Item> {
let (&first, rest) = self.nearest.split_first()?;
if !first.is_valid() {
return None;
}
self.nearest = rest;
Some(first)
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct LayerId(usize);
impl LayerId {
fn random(ml: f32, rng: &mut SmallRng) -> Self {
let layer = rng.gen::<f32>();
LayerId((-(layer.ln() * ml)).floor() as usize)
}
fn descend(&self) -> DescendingLayerIter {
DescendingLayerIter { next: Some(self.0) }
}
fn is_zero(&self) -> bool {
self.0 == 0
}
}
struct DescendingLayerIter {
next: Option<usize>,
}
impl Iterator for DescendingLayerIter {
type Item = LayerId;
fn next(&mut self) -> Option<Self::Item> {
Some(LayerId(match self.next? {
0 => {
self.next = None;
0
}
next => {
self.next = Some(next - 1);
next
}
}))
}
}
pub trait Point: Clone + Sync {
fn distance(&self, other: &Self) -> f32;
}
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct Candidate {
distance: OrderedFloat<f32>,
pid: PointId,
}
/// References a `Point` in the `Hnsw`
///
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PointId(u32);
impl PointId {
fn invalid() -> Self {
PointId(u32::MAX)
}
/// Whether this value represents a valid point
pub fn is_valid(self) -> bool {
self.0 != u32::MAX
}
}
impl Default for PointId {
fn default() -> Self {
PointId::invalid()
}
}
impl<P> Index<PointId> for Hnsw<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self.points[index.0 as usize]
}
}
impl<P: Point> Index<PointId> for Vec<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl<P: Point> Index<PointId> for [P] {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl Index<PointId> for Vec<ZeroNode> {
type Output = ZeroNode;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl IndexMut<PointId> for Vec<ZeroNode> {
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
&mut self[index.0 as usize]
}
}
/// The parameter `M` from the paper
///
/// This should become a generic argument to `Hnsw` when possible.

176
src/types.rs Normal file
View File

@ -0,0 +1,176 @@
use std::hash::Hash;
use std::ops::{Index, IndexMut};
use ordered_float::OrderedFloat;
use rand::rngs::SmallRng;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{Hnsw, Layer, Point, M};
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct UpperNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
pub(crate) nearest: [PointId; M],
}
impl Layer for [UpperNode] {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug)]
pub(crate) struct ZeroNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
pub(crate) nearest: [PointId; M * 2],
}
impl Default for ZeroNode {
fn default() -> ZeroNode {
ZeroNode {
nearest: [PointId::invalid(); M * 2],
}
}
}
impl Layer for [ZeroNode] {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
pub(crate) struct NearestIter<'a> {
nearest: &'a [PointId],
}
impl<'a> Iterator for NearestIter<'a> {
type Item = PointId;
fn next(&mut self) -> Option<Self::Item> {
let (&first, rest) = self.nearest.split_first()?;
if !first.is_valid() {
return None;
}
self.nearest = rest;
Some(first)
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub(crate) struct LayerId(pub usize);
impl LayerId {
pub(crate) fn random(ml: f32, rng: &mut SmallRng) -> Self {
let layer = rng.gen::<f32>();
LayerId((-(layer.ln() * ml)).floor() as usize)
}
pub(crate) fn descend(&self) -> impl Iterator<Item = LayerId> {
DescendingLayerIter { next: Some(self.0) }
}
pub(crate) fn is_zero(&self) -> bool {
self.0 == 0
}
}
struct DescendingLayerIter {
next: Option<usize>,
}
impl Iterator for DescendingLayerIter {
type Item = LayerId;
fn next(&mut self) -> Option<Self::Item> {
Some(LayerId(match self.next? {
0 => {
self.next = None;
0
}
next => {
self.next = Some(next - 1);
next
}
}))
}
}
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) struct Candidate {
pub(crate) distance: OrderedFloat<f32>,
pub(crate) pid: PointId,
}
/// References a `Point` in the `Hnsw`
///
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PointId(pub(crate) u32);
impl PointId {
pub(crate) fn invalid() -> Self {
PointId(u32::MAX)
}
/// Whether this value represents a valid point
pub fn is_valid(self) -> bool {
self.0 != u32::MAX
}
}
impl Default for PointId {
fn default() -> Self {
PointId::invalid()
}
}
impl<P> Index<PointId> for Hnsw<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self.points[index.0 as usize]
}
}
impl<P: Point> Index<PointId> for Vec<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl<P: Point> Index<PointId> for [P] {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl Index<PointId> for Vec<ZeroNode> {
type Output = ZeroNode;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0 as usize]
}
}
impl IndexMut<PointId> for Vec<ZeroNode> {
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
&mut self[index.0 as usize]
}
}