Move search() method into Search
This commit is contained in:
parent
5cf83543db
commit
d80a2e3f67
72
src/lib.rs
72
src/lib.rs
|
@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
mod types;
|
mod types;
|
||||||
pub use types::PointId;
|
pub use types::PointId;
|
||||||
use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode};
|
use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode};
|
||||||
|
|
||||||
/// Parameters for building the `Hnsw`
|
/// Parameters for building the `Hnsw`
|
||||||
pub struct Builder {
|
pub struct Builder {
|
||||||
|
@ -260,11 +260,11 @@ where
|
||||||
search.ef = if cur <= layer { ef_construction } else { 1 };
|
search.ef = if cur <= layer { ef_construction } else { 1 };
|
||||||
match cur > layer {
|
match cur > layer {
|
||||||
true => {
|
true => {
|
||||||
layers[cur.0 - 1].search(point, search, &points, num);
|
search.search(point, &layers[cur.0 - 1], &points, num);
|
||||||
search.cull();
|
search.cull();
|
||||||
}
|
}
|
||||||
false => {
|
false => {
|
||||||
zero.search(point, search, &points, num);
|
search.search(point, &zero, &points, num);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -347,8 +347,8 @@ where
|
||||||
|
|
||||||
search.ef = ef;
|
search.ef = ef;
|
||||||
match cur.0 {
|
match cur.0 {
|
||||||
0 => self.zero.search(point, search, &self.points, num),
|
0 => search.search(point, &self.zero, &self.points, num),
|
||||||
l => self.layers[l - 1].search(point, search, &self.points, num),
|
l => search.search(point, &self.layers[l - 1], &self.points, num),
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cur.is_zero() {
|
if !cur.is_zero() {
|
||||||
|
@ -410,7 +410,7 @@ fn insert<P: Point>(
|
||||||
insertion.reset();
|
insertion.reset();
|
||||||
let candidate_point = &points[pid];
|
let candidate_point = &points[pid];
|
||||||
insertion.push(new, candidate_point, points);
|
insertion.push(new, candidate_point, points);
|
||||||
for hop in layer.nearest_iter(pid) {
|
for hop in (&*layer).nearest_iter(pid) {
|
||||||
insertion.push(hop, candidate_point, points);
|
insertion.push(hop, candidate_point, points);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -484,37 +484,6 @@ impl SearchPool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Layer {
|
|
||||||
/// Search this layer for nodes near the given `point`
|
|
||||||
///
|
|
||||||
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
|
||||||
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
|
||||||
/// points, which is required to calculate distances between two points.
|
|
||||||
///
|
|
||||||
/// The `links` argument represents the number of links from each candidate to consider. This
|
|
||||||
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
|
||||||
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
|
||||||
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
|
||||||
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
|
||||||
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], links: usize) {
|
|
||||||
while let Some(Reverse(candidate)) = search.candidates.pop() {
|
|
||||||
if let Some(furthest) = search.nearest.last() {
|
|
||||||
if candidate.distance > furthest.distance {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for pid in self.nearest_iter(candidate.pid).take(links) {
|
|
||||||
search.push(pid, point, points);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
search.nearest.truncate(search.ef);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Keeps mutable state for searching a point's nearest neighbors
|
/// Keeps mutable state for searching a point's nearest neighbors
|
||||||
///
|
///
|
||||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||||
|
@ -543,6 +512,33 @@ impl Search {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Search the given layer for nodes near the given `point`
|
||||||
|
///
|
||||||
|
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||||
|
/// element; `self.candidates` contains the enter points `ep`. `points` contains all the
|
||||||
|
/// points, which is required to calculate distances between two points.
|
||||||
|
///
|
||||||
|
/// The `links` argument represents the number of links from each candidate to consider. This
|
||||||
|
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||||
|
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||||
|
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||||
|
/// case, we use `links` to constrain the number of per-candidate links we consider for search.
|
||||||
|
fn search<L: Layer, P: Point>(&mut self, point: &P, layer: L, points: &[P], links: usize) {
|
||||||
|
while let Some(Reverse(candidate)) = self.candidates.pop() {
|
||||||
|
if let Some(furthest) = self.nearest.last() {
|
||||||
|
if candidate.distance > furthest.distance {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for pid in layer.nearest_iter(candidate.pid).take(links) {
|
||||||
|
self.push(pid, point, points);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.nearest.truncate(self.ef);
|
||||||
|
}
|
||||||
|
|
||||||
/// Resets the state to be ready for a new search
|
/// Resets the state to be ready for a new search
|
||||||
fn reset(&mut self) {
|
fn reset(&mut self) {
|
||||||
let Search {
|
let Search {
|
||||||
|
@ -568,7 +564,7 @@ impl Search {
|
||||||
|
|
||||||
fn select_heuristic<P: Point>(
|
fn select_heuristic<P: Point>(
|
||||||
&mut self,
|
&mut self,
|
||||||
layer: &[ZeroNode],
|
layer: &Vec<ZeroNode>,
|
||||||
num: usize,
|
num: usize,
|
||||||
point: &P,
|
point: &P,
|
||||||
points: &[P],
|
points: &[P],
|
||||||
|
|
10
src/types.rs
10
src/types.rs
|
@ -7,7 +7,7 @@ use rand::Rng;
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{Hnsw, Layer, Point, M};
|
use crate::{Hnsw, Point, M};
|
||||||
|
|
||||||
pub(crate) struct Visited {
|
pub(crate) struct Visited {
|
||||||
store: Vec<u8>,
|
store: Vec<u8>,
|
||||||
|
@ -66,7 +66,7 @@ pub(crate) struct UpperNode {
|
||||||
pub(crate) nearest: [PointId; M],
|
pub(crate) nearest: [PointId; M],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for [UpperNode] {
|
impl Layer for &Vec<UpperNode> {
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||||
NearestIter {
|
NearestIter {
|
||||||
nearest: &self[pid.0 as usize].nearest,
|
nearest: &self[pid.0 as usize].nearest,
|
||||||
|
@ -91,7 +91,7 @@ impl Default for ZeroNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for [ZeroNode] {
|
impl Layer for &Vec<ZeroNode> {
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||||
NearestIter {
|
NearestIter {
|
||||||
nearest: &self[pid.0 as usize].nearest,
|
nearest: &self[pid.0 as usize].nearest,
|
||||||
|
@ -99,6 +99,10 @@ impl Layer for [ZeroNode] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) trait Layer {
|
||||||
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct NearestIter<'a> {
|
pub(crate) struct NearestIter<'a> {
|
||||||
nearest: &'a [PointId],
|
nearest: &'a [PointId],
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue