Move search() method into Search
This commit is contained in:
parent
5cf83543db
commit
d80a2e3f67
72
src/lib.rs
72
src/lib.rs
|
@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
mod types;
|
||||
pub use types::PointId;
|
||||
use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode};
|
||||
use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode};
|
||||
|
||||
/// Parameters for building the `Hnsw`
|
||||
pub struct Builder {
|
||||
|
@ -260,11 +260,11 @@ where
|
|||
search.ef = if cur <= layer { ef_construction } else { 1 };
|
||||
match cur > layer {
|
||||
true => {
|
||||
layers[cur.0 - 1].search(point, search, &points, num);
|
||||
search.search(point, &layers[cur.0 - 1], &points, num);
|
||||
search.cull();
|
||||
}
|
||||
false => {
|
||||
zero.search(point, search, &points, num);
|
||||
search.search(point, &zero, &points, num);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -347,8 +347,8 @@ where
|
|||
|
||||
search.ef = ef;
|
||||
match cur.0 {
|
||||
0 => self.zero.search(point, search, &self.points, num),
|
||||
l => self.layers[l - 1].search(point, search, &self.points, num),
|
||||
0 => search.search(point, &self.zero, &self.points, num),
|
||||
l => search.search(point, &self.layers[l - 1], &self.points, num),
|
||||
}
|
||||
|
||||
if !cur.is_zero() {
|
||||
|
@ -410,7 +410,7 @@ fn insert<P: Point>(
|
|||
insertion.reset();
|
||||
let candidate_point = &points[pid];
|
||||
insertion.push(new, candidate_point, points);
|
||||
for hop in layer.nearest_iter(pid) {
|
||||
for hop in (&*layer).nearest_iter(pid) {
|
||||
insertion.push(hop, candidate_point, points);
|
||||
}
|
||||
|
||||
|
@ -484,37 +484,6 @@ impl SearchPool {
|
|||
}
|
||||
}
|
||||
|
||||
trait Layer {
|
||||
/// Search this layer for nodes near the given `point`
|
||||
///
|
||||
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
||||
/// points, which is required to calculate distances between two points.
|
||||
///
|
||||
/// The `links` argument represents the number of links from each candidate to consider. This
|
||||
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
||||
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
|
||||
while let Some(Reverse(candidate)) = search.candidates.pop() {
|
||||
if let Some(furthest) = search.nearest.last() {
|
||||
if candidate.distance > furthest.distance {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for pid in self.nearest_iter(candidate.pid).take(links) {
|
||||
search.push(pid, point, points);
|
||||
}
|
||||
}
|
||||
|
||||
search.nearest.truncate(search.ef);
|
||||
}
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||
}
|
||||
|
||||
/// Keeps mutable state for searching a point's nearest neighbors
|
||||
///
|
||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||
|
@ -543,6 +512,33 @@ impl Search {
|
|||
}
|
||||
}
|
||||
|
||||
/// Search the given layer for nodes near the given `point`
|
||||
///
|
||||
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
||||
/// points, which is required to calculate distances between two points.
|
||||
///
|
||||
/// The `links` argument represents the number of links from each candidate to consider. This
|
||||
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
||||
fn search<L: Layer, P: Point>(&mut self, point: &P, layer: L, points: &[P], links: usize) {
|
||||
while let Some(Reverse(candidate)) = self.candidates.pop() {
|
||||
if let Some(furthest) = self.nearest.last() {
|
||||
if candidate.distance > furthest.distance {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for pid in layer.nearest_iter(candidate.pid).take(links) {
|
||||
self.push(pid, point, points);
|
||||
}
|
||||
}
|
||||
|
||||
self.nearest.truncate(self.ef);
|
||||
}
|
||||
|
||||
/// Resets the state to be ready for a new search
|
||||
fn reset(&mut self) {
|
||||
let Search {
|
||||
|
@ -568,7 +564,7 @@ impl Search {
|
|||
|
||||
fn select_heuristic<P: Point>(
|
||||
&mut self,
|
||||
layer: &[ZeroNode],
|
||||
layer: &Vec<ZeroNode>,
|
||||
num: usize,
|
||||
point: &P,
|
||||
points: &[P],
|
||||
|
|
10
src/types.rs
10
src/types.rs
|
@ -7,7 +7,7 @@ use rand::Rng;
|
|||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Hnsw, Layer, Point, M};
|
||||
use crate::{Hnsw, Point, M};
|
||||
|
||||
pub(crate) struct Visited {
|
||||
store: Vec<u8>,
|
||||
|
@ -66,7 +66,7 @@ pub(crate) struct UpperNode {
|
|||
pub(crate) nearest: [PointId; M],
|
||||
}
|
||||
|
||||
impl Layer for [UpperNode] {
|
||||
impl Layer for &Vec<UpperNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
|
@ -91,7 +91,7 @@ impl Default for ZeroNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl Layer for [ZeroNode] {
|
||||
impl Layer for &Vec<ZeroNode> {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].nearest,
|
||||
|
@ -99,6 +99,10 @@ impl Layer for [ZeroNode] {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) trait Layer {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||
}
|
||||
|
||||
pub(crate) struct NearestIter<'a> {
|
||||
nearest: &'a [PointId],
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue