diff --git a/src/lib.rs b/src/lib.rs index 5ecdcd1..1aa4095 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; mod types; pub use types::PointId; -use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode}; +use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode}; /// Parameters for building the `Hnsw` pub struct Builder { @@ -260,11 +260,11 @@ where search.ef = if cur <= layer { ef_construction } else { 1 }; match cur > layer { true => { - layers[cur.0 - 1].search(point, search, &points, num); + search.search(point, &layers[cur.0 - 1], &points, num); search.cull(); } false => { - zero.search(point, search, &points, num); + search.search(point, &zero, &points, num); break; } } @@ -347,8 +347,8 @@ where search.ef = ef; match cur.0 { - 0 => self.zero.search(point, search, &self.points, num), - l => self.layers[l - 1].search(point, search, &self.points, num), + 0 => search.search(point, &self.zero, &self.points, num), + l => search.search(point, &self.layers[l - 1], &self.points, num), } if !cur.is_zero() { @@ -410,7 +410,7 @@ fn insert( insertion.reset(); let candidate_point = &points[pid]; insertion.push(new, candidate_point, points); - for hop in layer.nearest_iter(pid) { + for hop in (&*layer).nearest_iter(pid) { insertion.push(hop, candidate_point, points); } @@ -484,37 +484,6 @@ impl SearchPool { } } -trait Layer { - /// Search this layer for nodes near the given `point` - /// - /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query - /// element; `search.candidates` contains the enter points `ep`. `points` contains all the - /// points, which is required to calculate distances between two points. - /// - /// The `links` argument represents the number of links from each candidate to consider. This - /// function may be called for a higher layer (with M links per node) or the zero layer (with - /// M * 2 links per node), but for performance reasons we often call this function on the data - /// representation matching the zero layer even when we're referring to a higher layer. In that - /// case, we use `links` to constrain the number of per-candidate links we consider for search. - fn search(&self, point: &P, search: &mut Search, points: &[P], links: usize) { - while let Some(Reverse(candidate)) = search.candidates.pop() { - if let Some(furthest) = search.nearest.last() { - if candidate.distance > furthest.distance { - break; - } - } - - for pid in self.nearest_iter(candidate.pid).take(links) { - search.push(pid, point, points); - } - } - - search.nearest.truncate(search.ef); - } - - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; -} - /// Keeps mutable state for searching a point's nearest neighbors /// /// In particular, this contains most of the state used in algorithm 2. The structure is @@ -543,6 +512,33 @@ impl Search { } } + /// Search the given layer for nodes near the given `point` + /// + /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query + /// element; `search.candidates` contains the enter points `ep`. `points` contains all the + /// points, which is required to calculate distances between two points. + /// + /// The `links` argument represents the number of links from each candidate to consider. This + /// function may be called for a higher layer (with M links per node) or the zero layer (with + /// M * 2 links per node), but for performance reasons we often call this function on the data + /// representation matching the zero layer even when we're referring to a higher layer. In that + /// case, we use `links` to constrain the number of per-candidate links we consider for search. + fn search(&mut self, point: &P, layer: L, points: &[P], links: usize) { + while let Some(Reverse(candidate)) = self.candidates.pop() { + if let Some(furthest) = self.nearest.last() { + if candidate.distance > furthest.distance { + break; + } + } + + for pid in layer.nearest_iter(candidate.pid).take(links) { + self.push(pid, point, points); + } + } + + self.nearest.truncate(self.ef); + } + /// Resets the state to be ready for a new search fn reset(&mut self) { let Search { @@ -568,7 +564,7 @@ impl Search { fn select_heuristic( &mut self, - layer: &[ZeroNode], + layer: &Vec, num: usize, point: &P, points: &[P], diff --git a/src/types.rs b/src/types.rs index aeead7d..668047b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -7,7 +7,7 @@ use rand::Rng; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::{Hnsw, Layer, Point, M}; +use crate::{Hnsw, Point, M}; pub(crate) struct Visited { store: Vec, @@ -66,7 +66,7 @@ pub(crate) struct UpperNode { pub(crate) nearest: [PointId; M], } -impl Layer for [UpperNode] { +impl Layer for &Vec { fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { NearestIter { nearest: &self[pid.0 as usize].nearest, @@ -91,7 +91,7 @@ impl Default for ZeroNode { } } -impl Layer for [ZeroNode] { +impl Layer for &Vec { fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { NearestIter { nearest: &self[pid.0 as usize].nearest, @@ -99,6 +99,10 @@ impl Layer for [ZeroNode] { } } +pub(crate) trait Layer { + fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; +} + pub(crate) struct NearestIter<'a> { nearest: &'a [PointId], }