Re-order some code
This commit is contained in:
parent
e6d200954e
commit
f01ed4a4a0
210
src/lib.rs
210
src/lib.rs
|
@ -13,6 +13,65 @@ use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
|||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Collects the optional tuning knobs used to construct an `Hnsw`
///
/// Every field starts out as `None` (via `Default`) and is filled in through
/// the chained setter methods on `Builder`.
#[derive(Default)]
pub struct Builder {
    /// The `ef` search parameter, set through `Builder::ef()`
    ef_search: Option<usize>,
    /// The `efConstruction` parameter, set through `Builder::ef_construction()`
    /// or implicitly by `Builder::ef()` when it has not been set yet
    ef_construction: Option<usize>,
    /// The `mL` parameter, set through `Builder::ml()`
    ml: Option<f32>,
    /// Seed for the random number generator, set through `Builder::seed()`
    seed: Option<u64>,
    /// Optional progress bar, only present with the `indicatif` feature
    #[cfg(feature = "indicatif")]
    progress: Option<ProgressBar>,
}
|
||||
|
||||
impl Builder {
|
||||
/// Set the `efConstruction` parameter from the paper
|
||||
pub fn ef_construction(mut self, ef_construction: usize) -> Self {
|
||||
self.ef_construction = Some(ef_construction);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `ef` parameter from the paper
|
||||
///
|
||||
/// If the `efConstruction` parameter is not already set, it will be set
|
||||
/// to the same value as `ef` by default.
|
||||
pub fn ef(mut self, ef: usize) -> Self {
|
||||
self.ef_search = Some(ef);
|
||||
if self.ef_construction.is_none() {
|
||||
self.ef_construction = Some(ef);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `mL` parameter from the paper
|
||||
///
|
||||
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
|
||||
pub fn ml(mut self, ml: f32) -> Self {
|
||||
self.ml = Some(ml);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the seed value for the random number generator used to generate a layer for each point
|
||||
///
|
||||
/// If this value is left unset, a seed is generated from entropy (via `getrandom()`).
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.seed = Some(seed);
|
||||
self
|
||||
}
|
||||
|
||||
/// A `ProgressBar` to track `Hnsw` construction progress
|
||||
#[cfg(feature = "indicatif")]
|
||||
pub fn progress(mut self, bar: ProgressBar) -> Self {
|
||||
self.progress = Some(bar);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the `Hnsw` with the given set of points
|
||||
pub fn build<P: Point>(self, points: &[P]) -> (Hnsw<P>, Vec<PointId>) {
|
||||
Hnsw::new(points, self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct Hnsw<P> {
|
||||
ef_search: usize,
|
||||
|
@ -341,6 +400,52 @@ impl SearchPool {
|
|||
}
|
||||
}
|
||||
|
||||
impl Layer for Vec<ZeroNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for Vec<UpperNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Layer {
|
||||
/// Search this layer for nodes near the given `point`
|
||||
///
|
||||
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
||||
/// points, which is required to calculate distances between two points.
|
||||
///
|
||||
/// The `links` argument represents the number of links from each candidate to consider. This
|
||||
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
||||
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
|
||||
while let Some(Reverse(candidate)) = search.candidates.pop() {
|
||||
if candidate.distance > search.furthest {
|
||||
break;
|
||||
}
|
||||
|
||||
for pid in self.nearest_iter(candidate.pid).take(links) {
|
||||
search.push(pid, point, points);
|
||||
}
|
||||
}
|
||||
|
||||
search.nearest.sort_unstable();
|
||||
search.nearest.truncate(search.ef);
|
||||
}
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||
}
|
||||
|
||||
/// Keeps mutable state for searching a point's nearest neighbors
|
||||
///
|
||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||
|
@ -435,111 +540,6 @@ impl Default for Search {
|
|||
}
|
||||
}
|
||||
|
||||
/// Collects the optional tuning knobs used to construct an `Hnsw`
///
/// Every field starts out as `None` (via `Default`) and is filled in through
/// the chained setter methods on `Builder`.
#[derive(Default)]
pub struct Builder {
    /// The `ef` search parameter, set through `Builder::ef()`
    ef_search: Option<usize>,
    /// The `efConstruction` parameter, set through `Builder::ef_construction()`
    /// or implicitly by `Builder::ef()` when it has not been set yet
    ef_construction: Option<usize>,
    /// The `mL` parameter, set through `Builder::ml()`
    ml: Option<f32>,
    /// Seed for the random number generator, set through `Builder::seed()`
    seed: Option<u64>,
    /// Optional progress bar, only present with the `indicatif` feature
    #[cfg(feature = "indicatif")]
    progress: Option<ProgressBar>,
}
|
||||
|
||||
impl Builder {
|
||||
/// Set the `efConstruction` parameter from the paper
|
||||
pub fn ef_construction(mut self, ef_construction: usize) -> Self {
|
||||
self.ef_construction = Some(ef_construction);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `ef` parameter from the paper
|
||||
///
|
||||
/// If the `efConstruction` parameter is not already set, it will be set
|
||||
/// to the same value as `ef` by default.
|
||||
pub fn ef(mut self, ef: usize) -> Self {
|
||||
self.ef_search = Some(ef);
|
||||
if self.ef_construction.is_none() {
|
||||
self.ef_construction = Some(ef);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `mL` parameter from the paper
|
||||
///
|
||||
/// If the `mL` parameter is not already set, it defaults to `ln(M)`.
|
||||
pub fn ml(mut self, ml: f32) -> Self {
|
||||
self.ml = Some(ml);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the seed value for the random number generator used to generate a layer for each point
|
||||
///
|
||||
/// If this value is left unset, a seed is generated from entropy (via `getrandom()`).
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.seed = Some(seed);
|
||||
self
|
||||
}
|
||||
|
||||
/// A `ProgressBar` to track `Hnsw` construction progress
|
||||
#[cfg(feature = "indicatif")]
|
||||
pub fn progress(mut self, bar: ProgressBar) -> Self {
|
||||
self.progress = Some(bar);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the `Hnsw` with the given set of points
|
||||
pub fn build<P: Point>(self, points: &[P]) -> (Hnsw<P>, Vec<PointId>) {
|
||||
Hnsw::new(points, self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for Vec<ZeroNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for Vec<UpperNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Layer {
|
||||
/// Search this layer for nodes near the given `point`
|
||||
///
|
||||
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
||||
/// points, which is required to calculate distances between two points.
|
||||
///
|
||||
/// The `links` argument represents the number of links from each candidate to consider. This
|
||||
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
||||
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
|
||||
while let Some(Reverse(candidate)) = search.candidates.pop() {
|
||||
if candidate.distance > search.furthest {
|
||||
break;
|
||||
}
|
||||
|
||||
for pid in self.nearest_iter(candidate.pid).take(links) {
|
||||
search.push(pid, point, points);
|
||||
}
|
||||
}
|
||||
|
||||
search.nearest.sort_unstable();
|
||||
search.nearest.truncate(search.ef);
|
||||
}
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
struct UpperNode {
|
||||
|
|
Loading…
Reference in New Issue