From f01ed4a4a0356d15bf02fb5730ea1a2e1962f0e0 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 7 Jan 2021 20:58:35 +0100 Subject: [PATCH] Re-order some code --- src/lib.rs | 210 ++++++++++++++++++++++++++--------------------------- 1 file changed, 105 insertions(+), 105 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 97e1d46..636987a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,65 @@ use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +/// Parameters for building the `Hnsw` +#[derive(Default)] +pub struct Builder { + ef_search: Option, + ef_construction: Option, + ml: Option, + seed: Option, + #[cfg(feature = "indicatif")] + progress: Option, +} + +impl Builder { + /// Set the `efConstruction` parameter from the paper + pub fn ef_construction(mut self, ef_construction: usize) -> Self { + self.ef_construction = Some(ef_construction); + self + } + + /// Set the `ef` parameter from the paper + /// + /// If the `efConstruction` parameter is not already set, it will be set + /// to the same value as `ef` by default. + pub fn ef(mut self, ef: usize) -> Self { + self.ef_search = Some(ef); + if self.ef_construction.is_none() { + self.ef_construction = Some(ef); + } + self + } + + /// Set the `mL` parameter from the paper + /// + /// If the `mL` parameter is not already set, it defaults to `ln(M)`. + pub fn ml(mut self, ml: f32) -> Self { + self.ml = Some(ml); + self + } + + /// Set the seed value for the random number generator used to generate a layer for each point + /// + /// If this value is left unset, a seed is generated from entropy (via `getrandom()`). + pub fn seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } + + /// A `ProgressBar` to track `Hnsw` construction progress + #[cfg(feature = "indicatif")] + pub fn progress(mut self, bar: ProgressBar) -> Self { + self.progress = Some(bar); + self + } + + /// Build the `Hnsw` with the given set of points + pub fn build(self, points: &[P]) -> (Hnsw

, Vec) { + Hnsw::new(points, self) + } +} + #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Hnsw

{ ef_search: usize, @@ -341,6 +400,52 @@ impl SearchPool { } } +impl Layer for Vec { + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } + } +} + +impl Layer for Vec { + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { + NearestIter { + nearest: &self[pid.0 as usize].nearest, + } + } +} + +trait Layer { + /// Search this layer for nodes near the given `point` + /// + /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query + /// element; `search.candidates` contains the enter points `ep`. `points` contains all the + /// points, which is required to calculate distances between two points. + /// + /// The `links` argument represents the number of links from each candidate to consider. This + /// function may be called for a higher layer (with M links per node) or the zero layer (with + /// M * 2 links per node), but for performance reasons we often call this function on the data + /// representation matching the zero layer even when we're referring to a higher layer. In that + /// case, we use `links` to constrain the number of per-candidate links we consider for search. + fn search(&self, point: &P, search: &mut Search, points: &[P], links: usize) { + while let Some(Reverse(candidate)) = search.candidates.pop() { + if candidate.distance > search.furthest { + break; + } + + for pid in self.nearest_iter(candidate.pid).take(links) { + search.push(pid, point, points); + } + } + + search.nearest.sort_unstable(); + search.nearest.truncate(search.ef); + } + + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; +} + /// Keeps mutable state for searching a point's nearest neighbors /// /// In particular, this contains most of the state used in algorithm 2. The structure is @@ -435,111 +540,6 @@ impl Default for Search { } } -/// Parameters for building the `Hnsw` -#[derive(Default)] -pub struct Builder { - ef_search: Option, - ef_construction: Option, - ml: Option, - seed: Option, - #[cfg(feature = "indicatif")] - progress: Option, -} - -impl Builder { - /// Set the `efConstruction` parameter from the paper - pub fn ef_construction(mut self, ef_construction: usize) -> Self { - self.ef_construction = Some(ef_construction); - self - } - - /// Set the `ef` parameter from the paper - /// - /// If the `efConstruction` parameter is not already set, it will be set - /// to the same value as `ef` by default. - pub fn ef(mut self, ef: usize) -> Self { - self.ef_search = Some(ef); - if self.ef_construction.is_none() { - self.ef_construction = Some(ef); - } - self - } - - /// Set the `mL` parameter from the paper - /// - /// If the `mL` parameter is not already set, it defaults to `ln(M)`. - pub fn ml(mut self, ml: f32) -> Self { - self.ml = Some(ml); - self - } - - /// Set the seed value for the random number generator used to generate a layer for each point - /// - /// If this value is left unset, a seed is generated from entropy (via `getrandom()`). - pub fn seed(mut self, seed: u64) -> Self { - self.seed = Some(seed); - self - } - - /// A `ProgressBar` to track `Hnsw` construction progress - #[cfg(feature = "indicatif")] - pub fn progress(mut self, bar: ProgressBar) -> Self { - self.progress = Some(bar); - self - } - - /// Build the `Hnsw` with the given set of points - pub fn build(self, points: &[P]) -> (Hnsw

, Vec) { - Hnsw::new(points, self) - } -} - -impl Layer for Vec { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].nearest, - } - } -} - -impl Layer for Vec { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].nearest, - } - } -} - -trait Layer { - /// Search this layer for nodes near the given `point` - /// - /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query - /// element; `search.candidates` contains the enter points `ep`. `points` contains all the - /// points, which is required to calculate distances between two points. - /// - /// The `links` argument represents the number of links from each candidate to consider. This - /// function may be called for a higher layer (with M links per node) or the zero layer (with - /// M * 2 links per node), but for performance reasons we often call this function on the data - /// representation matching the zero layer even when we're referring to a higher layer. In that - /// case, we use `links` to constrain the number of per-candidate links we consider for search. - fn search(&self, point: &P, search: &mut Search, points: &[P], links: usize) { - while let Some(Reverse(candidate)) = search.candidates.pop() { - if candidate.distance > search.furthest { - break; - } - - for pid in self.nearest_iter(candidate.pid).take(links) { - search.push(pid, point, points); - } - } - - search.nearest.sort_unstable(); - search.nearest.truncate(search.ef); - } - - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; -} - #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[derive(Clone, Copy, Debug, Default)] struct UpperNode {