Move search() method into Search

This commit is contained in:
Dirkjan Ochtman 2021-01-20 16:42:12 +01:00
parent 5cf83543db
commit d80a2e3f67
2 changed files with 41 additions and 41 deletions

View File

@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
mod types; mod types;
pub use types::PointId; pub use types::PointId;
use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode}; use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode};
/// Parameters for building the `Hnsw` /// Parameters for building the `Hnsw`
pub struct Builder { pub struct Builder {
@ -260,11 +260,11 @@ where
search.ef = if cur <= layer { ef_construction } else { 1 }; search.ef = if cur <= layer { ef_construction } else { 1 };
match cur > layer { match cur > layer {
true => { true => {
layers[cur.0 - 1].search(point, search, &points, num); search.search(point, &layers[cur.0 - 1], &points, num);
search.cull(); search.cull();
} }
false => { false => {
zero.search(point, search, &points, num); search.search(point, &zero, &points, num);
break; break;
} }
} }
@ -347,8 +347,8 @@ where
search.ef = ef; search.ef = ef;
match cur.0 { match cur.0 {
0 => self.zero.search(point, search, &self.points, num), 0 => search.search(point, &self.zero, &self.points, num),
l => self.layers[l - 1].search(point, search, &self.points, num), l => search.search(point, &self.layers[l - 1], &self.points, num),
} }
if !cur.is_zero() { if !cur.is_zero() {
@ -410,7 +410,7 @@ fn insert<P: Point>(
insertion.reset(); insertion.reset();
let candidate_point = &points[pid]; let candidate_point = &points[pid];
insertion.push(new, candidate_point, points); insertion.push(new, candidate_point, points);
for hop in layer.nearest_iter(pid) { for hop in (&*layer).nearest_iter(pid) {
insertion.push(hop, candidate_point, points); insertion.push(hop, candidate_point, points);
} }
@ -484,37 +484,6 @@ impl SearchPool {
} }
} }
trait Layer {
/// Search this layer for nodes near the given `point`
///
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
/// points, which is required to calculate distances between two points.
///
/// The `links` argument represents the number of links from each candidate to consider. This
/// function may be called for a higher layer (with M links per node) or the zero layer (with
/// M * 2 links per node), but for performance reasons we often call this function on the data
/// representation matching the zero layer even when we're referring to a higher layer. In that
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
while let Some(Reverse(candidate)) = search.candidates.pop() {
if let Some(furthest) = search.nearest.last() {
if candidate.distance > furthest.distance {
break;
}
}
for pid in self.nearest_iter(candidate.pid).take(links) {
search.push(pid, point, points);
}
}
search.nearest.truncate(search.ef);
}
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
}
/// Keeps mutable state for searching a point's nearest neighbors /// Keeps mutable state for searching a point's nearest neighbors
/// ///
/// In particular, this contains most of the state used in algorithm 2. The structure is /// In particular, this contains most of the state used in algorithm 2. The structure is
@ -543,6 +512,33 @@ impl Search {
} }
} }
/// Search the given layer for nodes near the given `point`
///
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
/// points, which is required to calculate distances between two points.
///
/// The `links` argument represents the number of links from each candidate to consider. This
/// function may be called for a higher layer (with M links per node) or the zero layer (with
/// M * 2 links per node), but for performance reasons we often call this function on the data
/// representation matching the zero layer even when we're referring to a higher layer. In that
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
fn search<L: Layer, P: Point>(&mut self, point: &P, layer: L, points: &[P], links: usize) {
while let Some(Reverse(candidate)) = self.candidates.pop() {
if let Some(furthest) = self.nearest.last() {
if candidate.distance > furthest.distance {
break;
}
}
for pid in layer.nearest_iter(candidate.pid).take(links) {
self.push(pid, point, points);
}
}
self.nearest.truncate(self.ef);
}
/// Resets the state to be ready for a new search /// Resets the state to be ready for a new search
fn reset(&mut self) { fn reset(&mut self) {
let Search { let Search {
@ -568,7 +564,7 @@ impl Search {
fn select_heuristic<P: Point>( fn select_heuristic<P: Point>(
&mut self, &mut self,
layer: &[ZeroNode], layer: &Vec<ZeroNode>,
num: usize, num: usize,
point: &P, point: &P,
points: &[P], points: &[P],

View File

@ -7,7 +7,7 @@ use rand::Rng;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{Hnsw, Layer, Point, M}; use crate::{Hnsw, Point, M};
pub(crate) struct Visited { pub(crate) struct Visited {
store: Vec<u8>, store: Vec<u8>,
@ -66,7 +66,7 @@ pub(crate) struct UpperNode {
pub(crate) nearest: [PointId; M], pub(crate) nearest: [PointId; M],
} }
impl Layer for [UpperNode] { impl Layer for &Vec<UpperNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter { NearestIter {
nearest: &self[pid.0 as usize].nearest, nearest: &self[pid.0 as usize].nearest,
@ -91,7 +91,7 @@ impl Default for ZeroNode {
} }
} }
impl Layer for [ZeroNode] { impl Layer for &Vec<ZeroNode> {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
NearestIter { NearestIter {
nearest: &self[pid.0 as usize].nearest, nearest: &self[pid.0 as usize].nearest,
@ -99,6 +99,10 @@ impl Layer for [ZeroNode] {
} }
} }
pub(crate) trait Layer {
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
}
pub(crate) struct NearestIter<'a> { pub(crate) struct NearestIter<'a> {
nearest: &'a [PointId], nearest: &'a [PointId],
} }