Restructure in preparation for parallelism
This commit is contained in:
parent
98a673fea2
commit
f8d1941d7c
|
@ -23,13 +23,77 @@ version = "0.1.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "const_fn"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cd51eab21ab4fd6a3bf889e2d0958c0a6e3a61ad04260325e919e652a2a62826"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1aaa739f95311c2c7887a76863f500026092fb1dce0161dab577e559ef3569d"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"const_fn",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"lazy_static",
|
||||||
|
"memoffset",
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "02d96d1e189ef58269ebe5b97953da3274d83a93af647c2ddd6f9dab28cedb8d"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"cfg-if 1.0.0",
|
||||||
|
"lazy_static",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
version = "0.1.15"
|
version = "0.1.15"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6"
|
checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if 0.1.10",
|
||||||
"libc",
|
"libc",
|
||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
@ -40,11 +104,20 @@ version = "0.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4"
|
checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if 0.1.10",
|
||||||
"libc",
|
"libc",
|
||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hermit-abi"
|
||||||
|
version = "0.1.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hinasmawo"
|
name = "hinasmawo"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -52,6 +125,7 @@ dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
"rand",
|
"rand",
|
||||||
|
"rayon",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -66,6 +140,15 @@ version = "0.2.80"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4d58d1b70b004888f764dfbf6a26a3b0342a1632d33968e4a179d8011c760614"
|
checksum = "4d58d1b70b004888f764dfbf6a26a3b0342a1632d33968e4a179d8011c760614"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-traits"
|
name = "num-traits"
|
||||||
version = "0.2.14"
|
version = "0.2.14"
|
||||||
|
@ -75,6 +158,16 @@ dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num_cpus"
|
||||||
|
version = "1.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
|
||||||
|
dependencies = [
|
||||||
|
"hermit-abi",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ordered-float"
|
name = "ordered-float"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
|
@ -141,6 +234,37 @@ dependencies = [
|
||||||
"rand_core",
|
"rand_core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"crossbeam-deque",
|
||||||
|
"either",
|
||||||
|
"rayon-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-core"
|
||||||
|
version = "1.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-channel",
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"lazy_static",
|
||||||
|
"num_cpus",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "version_check"
|
name = "version_check"
|
||||||
version = "0.9.2"
|
version = "0.9.2"
|
||||||
|
|
|
@ -7,4 +7,5 @@ edition = "2018"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ahash = "0.6.1"
|
ahash = "0.6.1"
|
||||||
rand = { version = "0.7.3", features = ["small_rng"] }
|
rand = { version = "0.7.3", features = ["small_rng"] }
|
||||||
|
rayon = "1.5"
|
||||||
ordered-float = "2.0"
|
ordered-float = "2.0"
|
||||||
|
|
508
src/lib.rs
508
src/lib.rs
|
@ -1,5 +1,6 @@
|
||||||
use std::cmp::{max, Ordering};
|
use std::cmp::{max, min, Ordering};
|
||||||
use std::ops::{Index, IndexMut};
|
use std::hash::Hash;
|
||||||
|
use std::ops::Index;
|
||||||
|
|
||||||
use ahash::AHashSet as HashSet;
|
use ahash::AHashSet as HashSet;
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
|
@ -7,153 +8,159 @@ use rand::rngs::SmallRng;
|
||||||
use rand::{RngCore, SeedableRng};
|
use rand::{RngCore, SeedableRng};
|
||||||
|
|
||||||
pub struct Hnsw<P> {
|
pub struct Hnsw<P> {
|
||||||
ef_construction: usize,
|
ef_search: usize,
|
||||||
points: Vec<P>,
|
points: Vec<P>,
|
||||||
zero: Vec<ZeroNode>,
|
zero: Vec<ZeroNode>,
|
||||||
layers: Vec<Vec<UpperNode>>,
|
layers: Vec<Vec<UpperNode>>,
|
||||||
rng: SmallRng,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P> Hnsw<P>
|
impl<P> Hnsw<P>
|
||||||
where
|
where
|
||||||
P: Point,
|
P: Point + std::fmt::Debug,
|
||||||
{
|
{
|
||||||
pub fn new(ef_construction: usize) -> Self {
|
pub fn new(points: &[P], ef_construction: usize, ef_search: usize) -> (Self, Vec<PointId>) {
|
||||||
Self {
|
if points.is_empty() {
|
||||||
ef_construction,
|
return (
|
||||||
points: Vec::new(),
|
Self {
|
||||||
zero: Vec::new(),
|
ef_search,
|
||||||
layers: Vec::new(),
|
zero: Vec::new(),
|
||||||
rng: SmallRng::from_entropy(),
|
points: Vec::new(),
|
||||||
}
|
layers: Vec::new(),
|
||||||
}
|
},
|
||||||
|
Vec::new(),
|
||||||
/// Insert a point into the `Hnsw`, returning a `PointId`
|
);
|
||||||
///
|
|
||||||
/// `PointId` implements `Hash`, `Eq` and friends, so it can be linked to some value.
|
|
||||||
pub fn insert(&mut self, point: P, search: &mut Search) -> PointId {
|
|
||||||
let layer = self.rng.next_u32() as f32 / u32::MAX as f32;
|
|
||||||
let layer = LayerId((-(layer.ln() * (M as f32).ln())).floor() as usize);
|
|
||||||
self.insert_at(point, layer, search)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Deterministic implementation of insertion that takes the `layer` as an argument
|
|
||||||
///
|
|
||||||
/// Implements the paper's algorithm 1, although there is a slight difference in that
|
|
||||||
/// new elements are always inserted from their selected layer, rather than delaying the
|
|
||||||
/// addition of new layers until after the selection of a particular layer.
|
|
||||||
fn insert_at(&mut self, point: P, layer: LayerId, search: &mut Search) -> PointId {
|
|
||||||
let empty = self.points.is_empty();
|
|
||||||
let pid = PointId(self.points.len());
|
|
||||||
self.points.push(point);
|
|
||||||
|
|
||||||
let top = LayerId(self.layers.len());
|
|
||||||
if layer > top {
|
|
||||||
self.layers.resize_with(layer.0, Default::default);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
search.reset(1, top);
|
// Give all points a random layer and sort the list of nodes by descending order for
|
||||||
for cur in max(top, layer).descend() {
|
// construction. This allows us to copy higher layers to lower layers as construction
|
||||||
search.num = if cur <= layer {
|
// progresses, while preserving randomness in each point's layer and insertion order.
|
||||||
self.ef_construction
|
|
||||||
} else {
|
|
||||||
1
|
|
||||||
};
|
|
||||||
|
|
||||||
// If this layer already existed, search it for the 1 nearest neighbor
|
let mut rng = SmallRng::from_entropy();
|
||||||
// (this roughly corresponds to the first loop in the paper's algorithm 1).
|
let mut nodes = (0..points.len())
|
||||||
if cur <= top {
|
.map(|i| (LayerId::random(&mut rng), i))
|
||||||
debug_assert_eq!(search.layer, cur);
|
.collect::<Vec<_>>();
|
||||||
|
nodes.sort_unstable_by(|l, r| r.cmp(&l));
|
||||||
|
|
||||||
// At the first layer that already existed, insert the first element as an initial
|
// Sort the original `points` in layer order.
|
||||||
// candidate. Because the zero-th layer always exists, also check if it was empty.
|
// TODO: maybe optimize this? https://crates.io/crates/permutation
|
||||||
if cur == top && !empty {
|
|
||||||
search.push(NodeId(0), &self[pid], self);
|
let mut new_points = Vec::with_capacity(points.len());
|
||||||
|
let mut new_nodes = Vec::with_capacity(points.len());
|
||||||
|
let mut out = vec![PointId::invalid(); points.len()];
|
||||||
|
for (i, &(layer, idx)) in nodes.iter().enumerate() {
|
||||||
|
let pid = PointId(i);
|
||||||
|
new_points.push(points[idx].clone());
|
||||||
|
new_nodes.push((layer, pid));
|
||||||
|
out[idx] = pid;
|
||||||
|
}
|
||||||
|
let (points, nodes) = (new_points, new_nodes);
|
||||||
|
|
||||||
|
// The layer from the first node is our top layer, or the zero layer if we have no nodes.
|
||||||
|
|
||||||
|
let top = match nodes.first() {
|
||||||
|
Some((top, _)) => *top,
|
||||||
|
None => LayerId(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Figure out how many nodes will go on each layer. This helps us allocate memory capacity
|
||||||
|
// for each layer in advance, and also helps enable batch insertion of points.
|
||||||
|
|
||||||
|
let mut sizes = vec![0; top.0 + 1];
|
||||||
|
for (layer, _) in nodes.iter().copied() {
|
||||||
|
sizes[layer.0] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut start = 0;
|
||||||
|
let mut ranges = Vec::with_capacity(top.0);
|
||||||
|
for (i, size) in sizes.into_iter().enumerate().rev() {
|
||||||
|
// Skip the first point, since we insert the enter point separately
|
||||||
|
ranges.push((LayerId(i), max(start, 1)..start + size));
|
||||||
|
start += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the first point so that we have an enter point to start searches with.
|
||||||
|
|
||||||
|
let mut layers = vec![vec![]; top.0];
|
||||||
|
let mut zero = Vec::with_capacity(points.len());
|
||||||
|
zero.push(ZeroNode::default());
|
||||||
|
|
||||||
|
let mut search = Search::default();
|
||||||
|
for (layer, range) in ranges {
|
||||||
|
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||||
|
for &(_, pid) in &nodes[range] {
|
||||||
|
search.reset();
|
||||||
|
let point = &points[pid];
|
||||||
|
search.push(PointId(0), &points[pid], &points);
|
||||||
|
|
||||||
|
for cur in top.descend() {
|
||||||
|
search.num = if cur <= layer { ef_construction } else { 1 };
|
||||||
|
zero.search(point, &mut search, &points, num);
|
||||||
|
match cur > layer {
|
||||||
|
true => search.cull(),
|
||||||
|
false => break,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.search_layer(cur, pid, search);
|
zero.insert_node(pid, &search.nearest, &points);
|
||||||
// If we're still above the layer to insert at, we're going to skip the
|
|
||||||
// insertion code below and continue to the next iteration. Before we do so,
|
|
||||||
// we update the `Search` so it's ready for the next layer coming up.
|
|
||||||
if cur > layer {
|
|
||||||
search.lower(self);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we're above the layer to start inserting links at, skip the rest of this loop.
|
// For layers above the zero layer, make a copy of the current state of the zero layer
|
||||||
if cur > layer {
|
// with `nearest` truncated to `M` elements.
|
||||||
continue;
|
if layer.0 > 0 {
|
||||||
}
|
let mut upper = Vec::with_capacity(zero.len());
|
||||||
|
upper.extend(zero.iter().map(|zero| {
|
||||||
if cur.is_zero() {
|
let mut upper = UpperNode::default();
|
||||||
let nid = NodeId(self.zero.len());
|
upper.nearest.copy_from_slice(&zero.nearest[..M]);
|
||||||
let mut node = ZeroNode {
|
upper
|
||||||
nearest: Default::default(),
|
}));
|
||||||
};
|
layers[layer.0 - 1] = upper;
|
||||||
self.link(cur, (nid, &mut node.nearest), &search.nearest);
|
|
||||||
self.zero.push(node);
|
|
||||||
} else {
|
|
||||||
let nid = NodeId(self.layers[cur.0 - 1].len());
|
|
||||||
let lower = match cur.0 == 1 {
|
|
||||||
false => NodeId(self.layers[cur.0 - 2].len()),
|
|
||||||
true => NodeId(self.zero.len()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut node = UpperNode {
|
|
||||||
pid,
|
|
||||||
lower,
|
|
||||||
nearest: Default::default(),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.link(cur, (nid, &mut node.nearest), &search.nearest);
|
|
||||||
self.layers[cur.0 - 1].push(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
if search.layer == cur && !cur.is_zero() {
|
|
||||||
search.lower(self);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pid
|
(
|
||||||
|
Self {
|
||||||
|
ef_search,
|
||||||
|
zero,
|
||||||
|
points,
|
||||||
|
layers,
|
||||||
|
},
|
||||||
|
out,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Bidirectionally insert links between newly detected neighbors
|
/// Search the index for the points nearest to the reference point `point`
|
||||||
///
|
///
|
||||||
/// `layer` is the layer we're at; `new` contains the `NodeId` for the new `Node` (which has
|
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
||||||
/// not yet been added to the layer) and its still-empty list of nearest neighbors; `found` is
|
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
||||||
/// a slice containing the `Candidate`s found during searching (ordered from near to far).
|
/// in the return value.
|
||||||
///
|
///
|
||||||
/// This just defers to the `Layer`'s `link()` implementation, which specializes on layer type.
|
/// `PointId` values can be initialized with `PointId::invalid()`.
|
||||||
fn link(&mut self, layer: LayerId, new: (NodeId, &mut [Option<NodeId>]), found: &[Candidate]) {
|
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
|
||||||
match layer.0 {
|
if self.points.is_empty() {
|
||||||
0 => self.zero.link(new, found, &self.points),
|
return 0;
|
||||||
l => self.layers[l - 1].link(new, found, &self.points),
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Search the given `layer` for neighbors closest to the point identified by `pid`
|
search.reset();
|
||||||
///
|
search.push(PointId(0), point, &self.points);
|
||||||
/// This implements the outer loop of algorithm 2 from the paper, deferring the state mutation
|
for cur in LayerId(self.layers.len()).descend() {
|
||||||
/// in the inner loop to the `Search::push()` implementation.
|
search.num = if cur.is_zero() { self.ef_search } else { 1 };
|
||||||
fn search_layer(&self, layer: LayerId, pid: PointId, search: &mut Search) {
|
|
||||||
debug_assert_eq!(search.layer, layer);
|
let num = if cur.0 > 0 { M } else { M * 2 };
|
||||||
let point = &self[pid];
|
match cur.0 {
|
||||||
while let Some(candidate) = search.candidates.pop() {
|
0 => self.zero.search(point, search, &self.points, num),
|
||||||
if let Some(found) = search.nearest.last() {
|
l => self.layers[l - 1].search(point, search, &self.points, num),
|
||||||
if candidate.distance > found.distance {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let iter = match layer.0 {
|
if !cur.is_zero() {
|
||||||
0 => self.zero[candidate.nid].nearest_iter(),
|
search.cull();
|
||||||
l => self.layers[l - 1][candidate.nid].nearest_iter(),
|
|
||||||
};
|
|
||||||
|
|
||||||
for nid in iter {
|
|
||||||
search.push(nid, point, self);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let found = min(search.nearest.len(), out.len());
|
||||||
|
for (i, candidate) in search.nearest.iter().take(found).enumerate() {
|
||||||
|
out[i] = candidate.pid;
|
||||||
|
}
|
||||||
|
found
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,42 +170,33 @@ where
|
||||||
/// initialized by using `push()` to add the initial enter points.
|
/// initialized by using `push()` to add the initial enter points.
|
||||||
pub struct Search {
|
pub struct Search {
|
||||||
/// Nodes visited so far (`v` in the paper)
|
/// Nodes visited so far (`v` in the paper)
|
||||||
visited: HashSet<NodeId>,
|
visited: HashSet<PointId>,
|
||||||
/// Candidates for further inspection (`C` in the paper)
|
/// Candidates for further inspection (`C` in the paper)
|
||||||
candidates: Vec<Candidate>,
|
candidates: Vec<Candidate>,
|
||||||
/// Nearest neighbors found so far (`W` in the paper)
|
/// Nearest neighbors found so far (`W` in the paper)
|
||||||
nearest: Vec<Candidate>,
|
nearest: Vec<Candidate>,
|
||||||
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
||||||
num: usize,
|
num: usize,
|
||||||
/// Current layer
|
|
||||||
layer: LayerId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Search {
|
impl Search {
|
||||||
/// Resets the state to be ready for a new search
|
/// Resets the state to be ready for a new search
|
||||||
fn reset(&mut self, num: usize, layer: LayerId) {
|
fn reset(&mut self) {
|
||||||
self.visited.clear();
|
self.visited.clear();
|
||||||
self.candidates.clear();
|
self.candidates.clear();
|
||||||
self.nearest.clear();
|
self.nearest.clear();
|
||||||
self.num = num;
|
|
||||||
self.layer = layer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Track node `nid` as a potential new neighbor for the given `point`
|
/// Track node `pid` as a potential new neighbor for the given `point`
|
||||||
///
|
///
|
||||||
/// Will immediately return if the node has been considered before. This implements
|
/// Will immediately return if the node has been considered before. This implements
|
||||||
/// the inner loop from the paper's algorithm 2.
|
/// the inner loop from the paper's algorithm 2.
|
||||||
fn push<P: Point>(&mut self, nid: NodeId, point: &P, hnsw: &Hnsw<P>) {
|
fn push<P: Point>(&mut self, pid: PointId, point: &P, points: &[P]) {
|
||||||
if !self.visited.insert(nid) {
|
if !self.visited.insert(pid) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let pid = match self.layer.0 {
|
let other = &points[pid];
|
||||||
0 => hnsw.zero.pid(nid),
|
|
||||||
l => hnsw.layers[l - 1].pid(nid),
|
|
||||||
};
|
|
||||||
|
|
||||||
let other = &hnsw[pid];
|
|
||||||
let distance = OrderedFloat::from(point.distance(other));
|
let distance = OrderedFloat::from(point.distance(other));
|
||||||
if self.nearest.len() >= self.num {
|
if self.nearest.len() >= self.num {
|
||||||
if let Some(found) = self.nearest.last() {
|
if let Some(found) = self.nearest.last() {
|
||||||
|
@ -212,7 +210,7 @@ impl Search {
|
||||||
self.nearest.pop();
|
self.nearest.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
let new = Candidate { distance, nid };
|
let new = Candidate { distance, pid };
|
||||||
let idx = self.candidates.binary_search(&new).unwrap_or_else(|e| e);
|
let idx = self.candidates.binary_search(&new).unwrap_or_else(|e| e);
|
||||||
self.candidates.insert(idx, new);
|
self.candidates.insert(idx, new);
|
||||||
|
|
||||||
|
@ -222,25 +220,14 @@ impl Search {
|
||||||
|
|
||||||
/// Lower the search to the next lower level
|
/// Lower the search to the next lower level
|
||||||
///
|
///
|
||||||
/// Resets `visited`, `candidates` to match `nearest`.
|
/// Re-initialize the `Search`: `nearest`, the output `W` from the last round, now becomes
|
||||||
///
|
/// the set of enter points, which we use to initialize both `candidates` and `visited`.
|
||||||
/// Panics if called while the `Search` is at level 0.
|
fn cull(&mut self) {
|
||||||
fn lower<P: Point>(&mut self, hnsw: &Hnsw<P>) {
|
|
||||||
debug_assert!(!self.layer.is_zero());
|
|
||||||
|
|
||||||
self.nearest.truncate(self.num); // Limit size of the set of nearest neighbors
|
self.nearest.truncate(self.num); // Limit size of the set of nearest neighbors
|
||||||
let old = hnsw.layers[self.layer.0 - 1].nodes();
|
|
||||||
for cur in self.nearest.iter_mut() {
|
|
||||||
cur.nid = old[cur.nid].lower;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-initialize the `Search`: `nearest`, the output `W` from the last round, now becomes
|
|
||||||
// the set of enter points, which we use to initialize both `candidates` and `visited`.
|
|
||||||
self.layer = self.layer.lower();
|
|
||||||
self.candidates.clear();
|
self.candidates.clear();
|
||||||
self.candidates.extend(&self.nearest);
|
self.candidates.extend(&self.nearest);
|
||||||
self.visited.clear();
|
self.visited.clear();
|
||||||
self.visited.extend(self.nearest.iter().map(|c| c.nid));
|
self.visited.extend(self.nearest.iter().map(|c| c.pid));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +237,6 @@ impl Default for Search {
|
||||||
visited: HashSet::new(),
|
visited: HashSet::new(),
|
||||||
candidates: Vec::new(),
|
candidates: Vec::new(),
|
||||||
nearest: Vec::new(),
|
nearest: Vec::new(),
|
||||||
layer: LayerId(0),
|
|
||||||
num: 1,
|
num: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -264,7 +250,7 @@ impl<P> Index<PointId> for Hnsw<P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P: Point> Index<PointId> for [P] {
|
impl<P: Point> Index<PointId> for Vec<P> {
|
||||||
type Output = P;
|
type Output = P;
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
|
@ -272,46 +258,10 @@ impl<P: Point> Index<PointId> for [P] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index<NodeId> for Vec<UpperNode> {
|
impl<P: Point> Index<PointId> for [P] {
|
||||||
type Output = UpperNode;
|
type Output = P;
|
||||||
|
|
||||||
fn index(&self, index: NodeId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
&self[index.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexMut<NodeId> for Vec<UpperNode> {
|
|
||||||
fn index_mut(&mut self, index: NodeId) -> &mut Self::Output {
|
|
||||||
&mut self[index.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Index<NodeId> for [UpperNode] {
|
|
||||||
type Output = UpperNode;
|
|
||||||
|
|
||||||
fn index(&self, index: NodeId) -> &Self::Output {
|
|
||||||
&self[index.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Index<NodeId> for Vec<ZeroNode> {
|
|
||||||
type Output = ZeroNode;
|
|
||||||
|
|
||||||
fn index(&self, index: NodeId) -> &Self::Output {
|
|
||||||
&self[index.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexMut<NodeId> for Vec<ZeroNode> {
|
|
||||||
fn index_mut(&mut self, index: NodeId) -> &mut Self::Output {
|
|
||||||
&mut self[index.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Index<NodeId> for [ZeroNode] {
|
|
||||||
type Output = ZeroNode;
|
|
||||||
|
|
||||||
fn index(&self, index: NodeId) -> &Self::Output {
|
|
||||||
&self[index.0]
|
&self[index.0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -321,17 +271,17 @@ impl Layer for Vec<ZeroNode> {
|
||||||
|
|
||||||
type Node = ZeroNode;
|
type Node = ZeroNode;
|
||||||
|
|
||||||
fn pid(&self, nid: NodeId) -> PointId {
|
fn push(&mut self, new: ZeroNode) {
|
||||||
PointId(nid.0)
|
self.push(new);
|
||||||
}
|
|
||||||
|
|
||||||
fn nodes(&self) -> &[Self::Node] {
|
|
||||||
self
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nodes_mut(&mut self) -> &mut [Self::Node] {
|
fn nodes_mut(&mut self) -> &mut [Self::Node] {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn nodes(&self) -> &[Self::Node] {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for Vec<UpperNode> {
|
impl Layer for Vec<UpperNode> {
|
||||||
|
@ -339,17 +289,17 @@ impl Layer for Vec<UpperNode> {
|
||||||
|
|
||||||
type Node = UpperNode;
|
type Node = UpperNode;
|
||||||
|
|
||||||
fn pid(&self, nid: NodeId) -> PointId {
|
fn push(&mut self, new: UpperNode) {
|
||||||
self.nodes()[nid].pid
|
self.push(new);
|
||||||
}
|
|
||||||
|
|
||||||
fn nodes(&self) -> &[Self::Node] {
|
|
||||||
self
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nodes_mut(&mut self) -> &mut [Self::Node] {
|
fn nodes_mut(&mut self) -> &mut [Self::Node] {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn nodes(&self) -> &[Self::Node] {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Layer {
|
trait Layer {
|
||||||
|
@ -357,41 +307,56 @@ trait Layer {
|
||||||
|
|
||||||
type Node: Node;
|
type Node: Node;
|
||||||
|
|
||||||
fn pid(&self, nid: NodeId) -> PointId;
|
/// Search this layer for nodes near the given `point`
|
||||||
|
|
||||||
fn nodes(&self) -> &[Self::Node];
|
|
||||||
|
|
||||||
fn nodes_mut(&mut self) -> &mut [Self::Node];
|
|
||||||
|
|
||||||
/// Bidirectionally insert links between newly detected neighbors
|
|
||||||
///
|
///
|
||||||
/// `new` contains the `NodeId` for the new `Node` (which has not yet been added to the layer)
|
/// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query
|
||||||
/// and its still-empty list of nearest neighbors; `found` is a slice containing all
|
/// element; `search.candidates` contains the enter points `ep`. `points` contains all the
|
||||||
|
/// points, which is required to calculate distances between two points.
|
||||||
|
///
|
||||||
|
/// The `num` argument represents the number of links from each candidate to consider. This
|
||||||
|
/// function may be called for a higher layer (with M links per node) or the zero layer (with
|
||||||
|
/// M * 2 links per node), but for performance reasons we often call this function on the data
|
||||||
|
/// representation matching the zero layer even when we're referring to a higher layer. In that
|
||||||
|
/// case, we use `num` to constrain the number of per-candidate links we consider for search.
|
||||||
|
fn search<P: Point>(&self, point: &P, search: &mut Search, points: &[P], num: usize) {
|
||||||
|
while let Some(candidate) = search.candidates.pop() {
|
||||||
|
if let Some(found) = search.nearest.last() {
|
||||||
|
if candidate.distance > found.distance {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) {
|
||||||
|
search.push(pid, point, points);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert new node in this layer
|
||||||
|
///
|
||||||
|
/// `new` contains the `PointId` for the new node; `found` is a slice containing all
|
||||||
/// `Candidate`s found during searching (ordered from near to far).
|
/// `Candidate`s found during searching (ordered from near to far).
|
||||||
///
|
///
|
||||||
/// Initializes both the new node's neighbors (in `new.1`) and updates the nearest neighbors
|
/// Creates the new node, initializes its `nearest` array, and updates the nearest neighbors
|
||||||
/// for the new node's neighbors if necessary.
|
/// for the new node's neighbors if necessary.
|
||||||
fn link<P: Point>(
|
fn insert_node<P: Point>(&mut self, new: PointId, found: &[Candidate], points: &[P]) {
|
||||||
&mut self,
|
let mut node = Self::Node::default();
|
||||||
new: (NodeId, &mut [Option<NodeId>]),
|
let new_nearest = node.nearest_mut();
|
||||||
found: &[Candidate],
|
|
||||||
points: &[P],
|
|
||||||
) {
|
|
||||||
// Just make sure the candidates are all unique
|
// Just make sure the candidates are all unique
|
||||||
debug_assert_eq!(
|
debug_assert_eq!(
|
||||||
found.len(),
|
found.len(),
|
||||||
found.iter().map(|c| c.nid).collect::<HashSet<_>>().len()
|
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Only use the `Self::LINKS` nearest candidates found
|
// Only use the `Self::LINKS` nearest candidates found
|
||||||
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
|
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
|
||||||
// `candidate` here is the new node's neighbor
|
// `candidate` here is the new node's neighbor
|
||||||
let &Candidate { distance, nid } = candidate;
|
let &Candidate { distance, pid } = candidate;
|
||||||
new.1[i] = Some(nid); // Update the new node's `nearest`
|
new_nearest[i] = Some(pid); // Update the new node's `nearest`
|
||||||
|
|
||||||
let pid = self.pid(nid);
|
let old = &points[pid];
|
||||||
let old = &points[pid.0];
|
let nearest = self.nodes()[pid.0].nearest();
|
||||||
let nearest = self.nodes()[nid.0].nearest();
|
|
||||||
|
|
||||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||||
let idx = nearest
|
let idx = nearest
|
||||||
|
@ -403,8 +368,7 @@ trait Layer {
|
||||||
None => return Ordering::Greater,
|
None => return Ordering::Greater,
|
||||||
};
|
};
|
||||||
|
|
||||||
let pid = self.pid(third);
|
let third_distance = OrderedFloat::from(old.distance(&points[third.0]));
|
||||||
let third_distance = OrderedFloat::from(old.distance(&points[pid.0]));
|
|
||||||
distance.cmp(&third_distance)
|
distance.cmp(&third_distance)
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|e| e);
|
.unwrap_or_else(|e| e);
|
||||||
|
@ -415,39 +379,41 @@ trait Layer {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let nearest = self.nodes_mut()[nid.0].nearest_mut();
|
let nearest = self.nodes_mut()[pid.0].nearest_mut();
|
||||||
if nearest[idx].is_none() {
|
if nearest[idx].is_none() {
|
||||||
nearest[idx] = Some(new.0);
|
nearest[idx] = Some(new);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let end = Self::LINKS - 1;
|
let end = Self::LINKS - 1;
|
||||||
nearest.copy_within(idx..end, idx + 1);
|
nearest.copy_within(idx..end, idx + 1);
|
||||||
nearest[idx] = Some(new.0);
|
nearest[idx] = Some(new);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.push(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn push(&mut self, new: Self::Node);
|
||||||
|
|
||||||
|
fn nodes_mut(&mut self) -> &mut [Self::Node];
|
||||||
|
|
||||||
|
fn nodes(&self) -> &[Self::Node];
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
struct UpperNode {
|
struct UpperNode {
|
||||||
/// This node's point
|
|
||||||
pid: PointId,
|
|
||||||
/// The point's node on the next level down
|
|
||||||
///
|
|
||||||
/// This is only used when lowering the search.
|
|
||||||
lower: NodeId,
|
|
||||||
/// The nearest neighbors on this layer
|
/// The nearest neighbors on this layer
|
||||||
///
|
///
|
||||||
/// This is always kept in sorted order (near to far).
|
/// This is always kept in sorted order (near to far).
|
||||||
nearest: [Option<NodeId>; M],
|
nearest: [Option<PointId>; M],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Node for UpperNode {
|
impl Node for UpperNode {
|
||||||
fn nearest(&self) -> &[Option<NodeId>] {
|
fn nearest(&self) -> &[Option<PointId>] {
|
||||||
&self.nearest
|
&self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<NodeId>] {
|
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
|
||||||
&mut self.nearest
|
&mut self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -458,20 +424,20 @@ impl Node for UpperNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
struct ZeroNode {
|
struct ZeroNode {
|
||||||
/// The nearest neighbors on this layer
|
/// The nearest neighbors on this layer
|
||||||
///
|
///
|
||||||
/// This is always kept in sorted order (near to far).
|
/// This is always kept in sorted order (near to far).
|
||||||
nearest: [Option<NodeId>; M * 2],
|
nearest: [Option<PointId>; M * 2],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Node for ZeroNode {
|
impl Node for ZeroNode {
|
||||||
fn nearest(&self) -> &[Option<NodeId>] {
|
fn nearest(&self) -> &[Option<PointId>] {
|
||||||
&self.nearest
|
&self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<NodeId>] {
|
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
|
||||||
&mut self.nearest
|
&mut self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -482,18 +448,18 @@ impl Node for ZeroNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Node {
|
trait Node: Default {
|
||||||
fn nearest(&self) -> &[Option<NodeId>];
|
fn nearest(&self) -> &[Option<PointId>];
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<NodeId>];
|
fn nearest_mut(&mut self) -> &mut [Option<PointId>];
|
||||||
fn nearest_iter(&self) -> NearestIter<'_>;
|
fn nearest_iter(&self) -> NearestIter<'_>;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NearestIter<'a> {
|
struct NearestIter<'a> {
|
||||||
nearest: &'a [Option<NodeId>],
|
nearest: &'a [Option<PointId>],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Iterator for NearestIter<'a> {
|
impl<'a> Iterator for NearestIter<'a> {
|
||||||
type Item = NodeId;
|
type Item = PointId;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let (&first, rest) = self.nearest.split_first()?;
|
let (&first, rest) = self.nearest.split_first()?;
|
||||||
|
@ -509,11 +475,9 @@ impl<'a> Iterator for NearestIter<'a> {
|
||||||
struct LayerId(usize);
|
struct LayerId(usize);
|
||||||
|
|
||||||
impl LayerId {
|
impl LayerId {
|
||||||
/// Return a `LayerId` for the layer one lower
|
fn random(rng: &mut SmallRng) -> Self {
|
||||||
///
|
let layer = rng.next_u32() as f32 / u32::MAX as f32;
|
||||||
/// Panics when called for `LayerId(0)`.
|
LayerId((-(layer.ln() * (M as f32).ln())).floor() as usize)
|
||||||
fn lower(&self) -> LayerId {
|
|
||||||
LayerId(self.0 - 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn descend(&self) -> DescendingLayerIter {
|
fn descend(&self) -> DescendingLayerIter {
|
||||||
|
@ -546,14 +510,14 @@ impl Iterator for DescendingLayerIter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Point {
|
pub trait Point: Clone {
|
||||||
fn distance(&self, other: &Self) -> f32;
|
fn distance(&self, other: &Self) -> f32;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||||
struct Candidate {
|
struct Candidate {
|
||||||
distance: OrderedFloat<f32>,
|
distance: OrderedFloat<f32>,
|
||||||
nid: NodeId,
|
pid: PointId,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// References a node in a particular layer (usually the same layer)
|
/// References a node in a particular layer (usually the same layer)
|
||||||
|
@ -566,6 +530,12 @@ struct NodeId(usize);
|
||||||
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
pub struct PointId(usize);
|
pub struct PointId(usize);
|
||||||
|
|
||||||
|
impl PointId {
|
||||||
|
pub fn invalid() -> Self {
|
||||||
|
PointId(usize::MAX)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The parameter `M` from the paper
|
/// The parameter `M` from the paper
|
||||||
///
|
///
|
||||||
/// This should become a generic argument to `Hnsw` when possible.
|
/// This should become a generic argument to `Hnsw` when possible.
|
||||||
|
@ -576,15 +546,27 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_insertion() {
|
fn basic() {
|
||||||
|
let (hnsw, pids) = Hnsw::new(
|
||||||
|
&[
|
||||||
|
Point(0.1, 0.4),
|
||||||
|
Point(-0.324, 0.543),
|
||||||
|
Point(0.87, -0.33),
|
||||||
|
Point(0.452, 0.932),
|
||||||
|
],
|
||||||
|
100,
|
||||||
|
100,
|
||||||
|
);
|
||||||
|
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
let mut hnsw = Hnsw::new(100);
|
let mut results = vec![PointId::invalid()];
|
||||||
hnsw.insert(Point(0.1, 0.4), &mut search);
|
let p = Point(0.1, 0.35);
|
||||||
hnsw.insert(Point(-0.324, 0.543), &mut search);
|
let found = hnsw.search(&p, &mut results, &mut search);
|
||||||
hnsw.insert(Point(0.87, -0.33), &mut search);
|
assert_eq!(found, 1);
|
||||||
hnsw.insert(Point(0.452, 0.932), &mut search);
|
assert_eq!(&results, &[pids[0]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
struct Point(f32, f32);
|
struct Point(f32, f32);
|
||||||
|
|
||||||
impl super::Point for Point {
|
impl super::Point for Point {
|
||||||
|
|
Loading…
Reference in New Issue