Re-order some code

This commit is contained in:
Dirkjan Ochtman 2021-01-07 20:58:35 +01:00
parent e6d200954e
commit f01ed4a4a0
1 changed files with 105 additions and 105 deletions

View File

@ -13,6 +13,65 @@ use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Parameters for building the `Hnsw`
#[derive(Default)]
pub struct Builder {
ef_search: Option<usize>,
ef_construction: Option<usize>,
ml: Option<f32>,
seed: Option<u64>,
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
}
impl Builder {
/// Set the `efConstruction` parameter from the paper
pub fn ef_construction(mut self, ef_construction: usize) -> Self {
self.ef_construction = Some(ef_construction);
self
}
/// Set the `ef` parameter from the paper
///
/// If the `efConstruction` parameter is not already set, it will be set
/// to the same value as `ef` by default.
pub fn ef(mut self, ef: usize) -> Self {
self.ef_search = Some(ef);
if self.ef_construction.is_none() {
self.ef_construction = Some(ef);
}
self
}
/// Set the `mL` parameter from the paper
///
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
pub fn ml(mut self, ml: f32) -> Self {
self.ml = Some(ml);
self
}
/// Set the seed value for the random number generator used to generate a layer for each point
///
/// If this value is left unset, a seed is generated from entropy (via `getrandom()`).
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
/// A `ProgressBar` to track `Hnsw` construction progress
#[cfg(feature = "indicatif")]
pub fn progress(mut self, bar: ProgressBar) -> Self {
self.progress = Some(bar);
self
}
/// Build the `Hnsw` with the given set of points
pub fn build<P: Point>(self, points: &[P]) -> (Hnsw<P>, Vec<PointId>) {
Hnsw::new(points, self)
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> {
ef_search: usize,
@ -341,6 +400,52 @@ impl SearchPool {
}
}
impl Layer for Vec<ZeroNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
impl Layer for Vec<UpperNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
trait Layer {
/// Search this layer for nodes near the given `point`
///
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
/// points, which is required to calculate distances between two points.
///
/// The `links` argument represents the number of links from each candidate to consider. This
/// function may be called for a higher layer (with M links per node) or the zero layer (with
/// M * 2 links per node), but for performance reasons we often call this function on the data
/// representation matching the zero layer even when we're referring to a higher layer. In that
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
while let Some(Reverse(candidate)) = search.candidates.pop() {
if candidate.distance > search.furthest {
break;
}
for pid in self.nearest_iter(candidate.pid).take(links) {
search.push(pid, point, points);
}
}
search.nearest.sort_unstable();
search.nearest.truncate(search.ef);
}
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
}
/// Keeps mutable state for searching a point's nearest neighbors
///
/// In particular, this contains most of the state used in algorithm 2. The structure is
@ -435,111 +540,6 @@ impl Default for Search {
}
}
/// Parameters for building the `Hnsw`
#[derive(Default)]
pub struct Builder {
ef_search: Option<usize>,
ef_construction: Option<usize>,
ml: Option<f32>,
seed: Option<u64>,
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
}
impl Builder {
/// Set the `efConstruction` parameter from the paper
pub fn ef_construction(mut self, ef_construction: usize) -> Self {
self.ef_construction = Some(ef_construction);
self
}
/// Set the `ef` parameter from the paper
///
/// If the `efConstruction` parameter is not already set, it will be set
/// to the same value as `ef` by default.
pub fn ef(mut self, ef: usize) -> Self {
self.ef_search = Some(ef);
if self.ef_construction.is_none() {
self.ef_construction = Some(ef);
}
self
}
/// Set the `mL` parameter from the paper
///
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
pub fn ml(mut self, ml: f32) -> Self {
self.ml = Some(ml);
self
}
/// Set the seed value for the random number generator used to generate a layer for each point
///
/// If this value is left unset, a seed is generated from entropy (via `getrandom()`).
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
/// A `ProgressBar` to track `Hnsw` construction progress
#[cfg(feature = "indicatif")]
pub fn progress(mut self, bar: ProgressBar) -> Self {
self.progress = Some(bar);
self
}
/// Build the `Hnsw` with the given set of points
pub fn build<P: Point>(self, points: &[P]) -> (Hnsw<P>, Vec<PointId>) {
Hnsw::new(points, self)
}
}
impl Layer for Vec<ZeroNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
impl Layer for Vec<UpperNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter {
nearest: &self[pid.0 as usize].nearest,
}
}
}
trait Layer {
/// Search this layer for nodes near the given `point`
///
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
/// points, which is required to calculate distances between two points.
///
/// The `links` argument represents the number of links from each candidate to consider. This
/// function may be called for a higher layer (with M links per node) or the zero layer (with
/// M * 2 links per node), but for performance reasons we often call this function on the data
/// representation matching the zero layer even when we're referring to a higher layer. In that
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
while let Some(Reverse(candidate)) = search.candidates.pop() {
if candidate.distance > search.furthest {
break;
}
for pid in self.nearest_iter(candidate.pid).take(links) {
search.push(pid, point, points);
}
}
search.nearest.sort_unstable();
search.nearest.truncate(search.ef);
}
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Default)]
struct UpperNode {