From f8d1941d7c69047028e1997b4f54243a39ae36c6 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 14 Dec 2020 15:13:10 +0100 Subject: [PATCH] Restructure in preparation for parallelism --- Cargo.lock | 128 +++++++++++++- Cargo.toml | 1 + src/lib.rs | 508 ++++++++++++++++++++++++++--------------------------- 3 files changed, 372 insertions(+), 265 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a3cabc..ececa89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,13 +23,77 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "const_fn" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd51eab21ab4fd6a3bf889e2d0958c0a6e3a61ad04260325e919e652a2a62826" + +[[package]] +name = "crossbeam-channel" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1aaa739f95311c2c7887a76863f500026092fb1dce0161dab577e559ef3569d" +dependencies = [ + "cfg-if 1.0.0", + "const_fn", + "crossbeam-utils", + "lazy_static", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d96d1e189ef58269ebe5b97953da3274d83a93af647c2ddd6f9dab28cedb8d" +dependencies = [ + "autocfg", + "cfg-if 1.0.0", + "lazy_static", +] + +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "getrandom" version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "wasi", ] @@ -40,11 +104,20 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "wasi", ] +[[package]] +name = "hermit-abi" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8" +dependencies = [ + "libc", +] + [[package]] name = "hinasmawo" version = "0.1.0" @@ -52,6 +125,7 @@ dependencies = [ "ahash", "ordered-float", "rand", + "rayon", ] [[package]] @@ -66,6 +140,15 @@ version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d58d1b70b004888f764dfbf6a26a3b0342a1632d33968e4a179d8011c760614" +[[package]] +name = "memoffset" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87" +dependencies = [ + "autocfg", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -75,6 +158,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "ordered-float" version = "2.0.0" @@ -141,6 +234,37 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rayon" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674" +dependencies = [ + "autocfg", + "crossbeam-deque", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "lazy_static", + "num_cpus", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "version_check" version = "0.9.2" diff --git a/Cargo.toml b/Cargo.toml index fd4d6d0..9f31fcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,5 @@ edition = "2018" [dependencies] ahash = "0.6.1" rand = { version = "0.7.3", features = ["small_rng"] } +rayon = "1.5" ordered-float = "2.0" diff --git a/src/lib.rs b/src/lib.rs index 2126955..514fad9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ -use std::cmp::{max, Ordering}; -use std::ops::{Index, IndexMut}; +use std::cmp::{max, min, Ordering}; +use std::hash::Hash; +use std::ops::Index; use ahash::AHashSet as HashSet; use ordered_float::OrderedFloat; @@ -7,153 +8,159 @@ use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; pub struct Hnsw

{ - ef_construction: usize, + ef_search: usize, points: Vec

, zero: Vec, layers: Vec>, - rng: SmallRng, } impl

Hnsw

where - P: Point, + P: Point + std::fmt::Debug, { - pub fn new(ef_construction: usize) -> Self { - Self { - ef_construction, - points: Vec::new(), - zero: Vec::new(), - layers: Vec::new(), - rng: SmallRng::from_entropy(), - } - } - - /// Insert a point into the `Hnsw`, returning a `PointId` - /// - /// `PointId` implements `Hash`, `Eq` and friends, so it can be linked to some value. - pub fn insert(&mut self, point: P, search: &mut Search) -> PointId { - let layer = self.rng.next_u32() as f32 / u32::MAX as f32; - let layer = LayerId((-(layer.ln() * (M as f32).ln())).floor() as usize); - self.insert_at(point, layer, search) - } - - /// Deterministic implementation of insertion that takes the `layer` as an argument - /// - /// Implements the paper's algorithm 1, although there is a slight difference in that - /// new elements are always inserted from their selected layer, rather than delaying the - /// addition of new layers until after the selection of a particular layer. - fn insert_at(&mut self, point: P, layer: LayerId, search: &mut Search) -> PointId { - let empty = self.points.is_empty(); - let pid = PointId(self.points.len()); - self.points.push(point); - - let top = LayerId(self.layers.len()); - if layer > top { - self.layers.resize_with(layer.0, Default::default); + pub fn new(points: &[P], ef_construction: usize, ef_search: usize) -> (Self, Vec) { + if points.is_empty() { + return ( + Self { + ef_search, + zero: Vec::new(), + points: Vec::new(), + layers: Vec::new(), + }, + Vec::new(), + ); } - search.reset(1, top); - for cur in max(top, layer).descend() { - search.num = if cur <= layer { - self.ef_construction - } else { - 1 - }; + // Give all points a random layer and sort the list of nodes by descending order for + // construction. This allows us to copy higher layers to lower layers as construction + // progresses, while preserving randomness in each point's layer and insertion order. - // If this layer already existed, search it for the 1 nearest neighbor - // (this roughly corresponds to the first loop in the paper's algorithm 1). - if cur <= top { - debug_assert_eq!(search.layer, cur); + let mut rng = SmallRng::from_entropy(); + let mut nodes = (0..points.len()) + .map(|i| (LayerId::random(&mut rng), i)) + .collect::>(); + nodes.sort_unstable_by(|l, r| r.cmp(&l)); - // At the first layer that already existed, insert the first element as an initial - // candidate. Because the zero-th layer always exists, also check if it was empty. - if cur == top && !empty { - search.push(NodeId(0), &self[pid], self); + // Sort the original `points` in layer order. + // TODO: maybe optimize this? https://crates.io/crates/permutation + + let mut new_points = Vec::with_capacity(points.len()); + let mut new_nodes = Vec::with_capacity(points.len()); + let mut out = vec![PointId::invalid(); points.len()]; + for (i, &(layer, idx)) in nodes.iter().enumerate() { + let pid = PointId(i); + new_points.push(points[idx].clone()); + new_nodes.push((layer, pid)); + out[idx] = pid; + } + let (points, nodes) = (new_points, new_nodes); + + // The layer from the first node is our top layer, or the zero layer if we have no nodes. + + let top = match nodes.first() { + Some((top, _)) => *top, + None => LayerId(0), + }; + + // Figure out how many nodes will go on each layer. This helps us allocate memory capacity + // for each layer in advance, and also helps enable batch insertion of points. + + let mut sizes = vec![0; top.0 + 1]; + for (layer, _) in nodes.iter().copied() { + sizes[layer.0] += 1; + } + + let mut start = 0; + let mut ranges = Vec::with_capacity(top.0); + for (i, size) in sizes.into_iter().enumerate().rev() { + // Skip the first point, since we insert the enter point separately + ranges.push((LayerId(i), max(start, 1)..start + size)); + start += size; + } + + // Insert the first point so that we have an enter point to start searches with. + + let mut layers = vec![vec![]; top.0]; + let mut zero = Vec::with_capacity(points.len()); + zero.push(ZeroNode::default()); + + let mut search = Search::default(); + for (layer, range) in ranges { + let num = if layer.0 > 0 { M } else { M * 2 }; + for &(_, pid) in &nodes[range] { + search.reset(); + let point = &points[pid]; + search.push(PointId(0), &points[pid], &points); + + for cur in top.descend() { + search.num = if cur <= layer { ef_construction } else { 1 }; + zero.search(point, &mut search, &points, num); + match cur > layer { + true => search.cull(), + false => break, + } } - self.search_layer(cur, pid, search); - // If we're still above the layer to insert at, we're going to skip the - // insertion code below and continue to the next iteration. Before we do so, - // we update the `Search` so it's ready for the next layer coming up. - if cur > layer { - search.lower(self); - } + zero.insert_node(pid, &search.nearest, &points); } - // If we're above the layer to start inserting links at, skip the rest of this loop. - if cur > layer { - continue; - } - - if cur.is_zero() { - let nid = NodeId(self.zero.len()); - let mut node = ZeroNode { - nearest: Default::default(), - }; - self.link(cur, (nid, &mut node.nearest), &search.nearest); - self.zero.push(node); - } else { - let nid = NodeId(self.layers[cur.0 - 1].len()); - let lower = match cur.0 == 1 { - false => NodeId(self.layers[cur.0 - 2].len()), - true => NodeId(self.zero.len()), - }; - - let mut node = UpperNode { - pid, - lower, - nearest: Default::default(), - }; - - self.link(cur, (nid, &mut node.nearest), &search.nearest); - self.layers[cur.0 - 1].push(node); - } - - if search.layer == cur && !cur.is_zero() { - search.lower(self); + // For layers above the zero layer, make a copy of the current state of the zero layer + // with `nearest` truncated to `M` elements. + if layer.0 > 0 { + let mut upper = Vec::with_capacity(zero.len()); + upper.extend(zero.iter().map(|zero| { + let mut upper = UpperNode::default(); + upper.nearest.copy_from_slice(&zero.nearest[..M]); + upper + })); + layers[layer.0 - 1] = upper; } } - pid + ( + Self { + ef_search, + zero, + points, + layers, + }, + out, + ) } - /// Bidirectionally insert links between newly detected neighbors + /// Search the index for the points nearest to the reference point `point` /// - /// `layer` is the layer we're at; `new` contains the `NodeId` for the new `Node` (which has - /// not yet been added to the layer) and its still-empty list of nearest neighbors; `found` is - /// a slice containing the `Candidate`s found during searching (ordered from near to far). + /// The results are returned in the `out` parameter; the number of neighbors to search for + /// is limited by the size of the `out` parameter, and the number of results found is returned + /// in the return value. /// - /// This just defers to the `Layer`'s `link()` implementation, which specializes on layer type. - fn link(&mut self, layer: LayerId, new: (NodeId, &mut [Option]), found: &[Candidate]) { - match layer.0 { - 0 => self.zero.link(new, found, &self.points), - l => self.layers[l - 1].link(new, found, &self.points), + /// `PointId` values can be initialized with `PointId::invalid()`. + pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize { + if self.points.is_empty() { + return 0; } - } - /// Search the given `layer` for neighbors closed to the point identified by `pid` - /// - /// This implements the outer loop of algorithm 2 from the paper, deferring the state mutation - /// in the inner loop to the `Search::push()` implementation. - fn search_layer(&self, layer: LayerId, pid: PointId, search: &mut Search) { - debug_assert_eq!(search.layer, layer); - let point = &self[pid]; - while let Some(candidate) = search.candidates.pop() { - if let Some(found) = search.nearest.last() { - if candidate.distance > found.distance { - break; - } + search.reset(); + search.push(PointId(0), point, &self.points); + for cur in LayerId(self.layers.len()).descend() { + search.num = if cur.is_zero() { self.ef_search } else { 1 }; + + let num = if cur.0 > 0 { M } else { M * 2 }; + match cur.0 { + 0 => self.zero.search(point, search, &self.points, num), + l => self.layers[l - 1].search(point, search, &self.points, num), } - let iter = match layer.0 { - 0 => self.zero[candidate.nid].nearest_iter(), - l => self.layers[l - 1][candidate.nid].nearest_iter(), - }; - - for nid in iter { - search.push(nid, point, self); + if !cur.is_zero() { + search.cull(); } } + + let found = min(search.nearest.len(), out.len()); + for (i, candidate) in search.nearest.iter().take(found).enumerate() { + out[i] = candidate.pid; + } + found } } @@ -163,42 +170,33 @@ where /// initialized by using `push()` to add the initial enter points. pub struct Search { /// Nodes visited so far (`v` in the paper) - visited: HashSet, + visited: HashSet, /// Candidates for further inspection (`C` in the paper) candidates: Vec, /// Nearest neighbors found so far (`W` in the paper) nearest: Vec, /// Maximum number of nearest neighbors to retain (`ef` in the paper) num: usize, - /// Current layer - layer: LayerId, } impl Search { /// Resets the state to be ready for a new search - fn reset(&mut self, num: usize, layer: LayerId) { + fn reset(&mut self) { self.visited.clear(); self.candidates.clear(); self.nearest.clear(); - self.num = num; - self.layer = layer; } - /// Track node `nid` as a potential new neighbor for the given `point` + /// Track node `pid` as a potential new neighbor for the given `point` /// /// Will immediately return if the node has been considered before. This implements /// the inner loop from the paper's algorithm 2. - fn push(&mut self, nid: NodeId, point: &P, hnsw: &Hnsw

) { - if !self.visited.insert(nid) { + fn push(&mut self, pid: PointId, point: &P, points: &[P]) { + if !self.visited.insert(pid) { return; } - let pid = match self.layer.0 { - 0 => hnsw.zero.pid(nid), - l => hnsw.layers[l - 1].pid(nid), - }; - - let other = &hnsw[pid]; + let other = &points[pid]; let distance = OrderedFloat::from(point.distance(other)); if self.nearest.len() >= self.num { if let Some(found) = self.nearest.last() { @@ -212,7 +210,7 @@ impl Search { self.nearest.pop(); } - let new = Candidate { distance, nid }; + let new = Candidate { distance, pid }; let idx = self.candidates.binary_search(&new).unwrap_or_else(|e| e); self.candidates.insert(idx, new); @@ -222,25 +220,14 @@ impl Search { /// Lower the search to the next lower level /// - /// Resets `visited`, `candidates` to match `nearest`. - /// - /// Panics if called while the `Search` is at level 0. - fn lower(&mut self, hnsw: &Hnsw

) { - debug_assert!(!self.layer.is_zero()); - + /// Re-initialize the `Search`: `nearest`, the output `W` from the last round, now becomes + /// the set of enter points, which we use to initialize both `candidates` and `visited`. + fn cull(&mut self) { self.nearest.truncate(self.num); // Limit size of the set of nearest neighbors - let old = hnsw.layers[self.layer.0 - 1].nodes(); - for cur in self.nearest.iter_mut() { - cur.nid = old[cur.nid].lower; - } - - // Re-initialize the `Search`: `nearest`, the output `W` from the last round, now becomes - // the set of enter points, which we use to initialize both `candidates` and `visited`. - self.layer = self.layer.lower(); self.candidates.clear(); self.candidates.extend(&self.nearest); self.visited.clear(); - self.visited.extend(self.nearest.iter().map(|c| c.nid)); + self.visited.extend(self.nearest.iter().map(|c| c.pid)); } } @@ -250,7 +237,6 @@ impl Default for Search { visited: HashSet::new(), candidates: Vec::new(), nearest: Vec::new(), - layer: LayerId(0), num: 1, } } @@ -264,7 +250,7 @@ impl

Index for Hnsw

{ } } -impl Index for [P] { +impl Index for Vec

{ type Output = P; fn index(&self, index: PointId) -> &Self::Output { @@ -272,46 +258,10 @@ impl Index for [P] { } } -impl Index for Vec { - type Output = UpperNode; +impl Index for [P] { + type Output = P; - fn index(&self, index: NodeId) -> &Self::Output { - &self[index.0] - } -} - -impl IndexMut for Vec { - fn index_mut(&mut self, index: NodeId) -> &mut Self::Output { - &mut self[index.0] - } -} - -impl Index for [UpperNode] { - type Output = UpperNode; - - fn index(&self, index: NodeId) -> &Self::Output { - &self[index.0] - } -} - -impl Index for Vec { - type Output = ZeroNode; - - fn index(&self, index: NodeId) -> &Self::Output { - &self[index.0] - } -} - -impl IndexMut for Vec { - fn index_mut(&mut self, index: NodeId) -> &mut Self::Output { - &mut self[index.0] - } -} - -impl Index for [ZeroNode] { - type Output = ZeroNode; - - fn index(&self, index: NodeId) -> &Self::Output { + fn index(&self, index: PointId) -> &Self::Output { &self[index.0] } } @@ -321,17 +271,17 @@ impl Layer for Vec { type Node = ZeroNode; - fn pid(&self, nid: NodeId) -> PointId { - PointId(nid.0) - } - - fn nodes(&self) -> &[Self::Node] { - self + fn push(&mut self, new: ZeroNode) { + self.push(new); } fn nodes_mut(&mut self) -> &mut [Self::Node] { self } + + fn nodes(&self) -> &[Self::Node] { + self + } } impl Layer for Vec { @@ -339,17 +289,17 @@ impl Layer for Vec { type Node = UpperNode; - fn pid(&self, nid: NodeId) -> PointId { - self.nodes()[nid].pid - } - - fn nodes(&self) -> &[Self::Node] { - self + fn push(&mut self, new: UpperNode) { + self.push(new); } fn nodes_mut(&mut self) -> &mut [Self::Node] { self } + + fn nodes(&self) -> &[Self::Node] { + self + } } trait Layer { @@ -357,41 +307,56 @@ trait Layer { type Node: Node; - fn pid(&self, nid: NodeId) -> PointId; - - fn nodes(&self) -> &[Self::Node]; - - fn nodes_mut(&mut self) -> &mut [Self::Node]; - - /// Bidirectionally insert links between newly detected neighbors + /// Search this layer for nodes near the given `point` /// - /// `new` contains the `NodeId` for the new `Node` (which has not yet been added to the layer) - /// and its still-empty list of nearest neighbors; `found` is a slice containing all + /// This contains the loops from the paper's algorithm 2. `point` represents `q`, the query + /// element; `search.candidates` contains the enter points `ep`. `points` contains all the + /// points, which is required to calculate distances between two points. + /// + /// The `num` argument represents the number of links from each candidate to consider. This + /// function may be called for a higher layer (with M links per node) or the zero layer (with + /// M * 2 links per node), but for performance reasons we often call this function on the data + /// representation matching the zero layer even when we're referring to a higher layer. In that + /// case, we use `num` to constrain the number of per-candidate links we consider for search. + fn search(&self, point: &P, search: &mut Search, points: &[P], num: usize) { + while let Some(candidate) = search.candidates.pop() { + if let Some(found) = search.nearest.last() { + if candidate.distance > found.distance { + break; + } + } + + for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) { + search.push(pid, point, points); + } + } + } + + /// Insert new node in this layer + /// + /// `new` contains the `PointId` for the new node; `found` is a slice containing all /// `Candidate`s found during searching (ordered from near to far). /// - /// Initializes both the new node's neighbors (in `new.1`) and updates the nearest neighbors + /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors /// for the new node's neighbors if necessary. - fn link( - &mut self, - new: (NodeId, &mut [Option]), - found: &[Candidate], - points: &[P], - ) { + fn insert_node(&mut self, new: PointId, found: &[Candidate], points: &[P]) { + let mut node = Self::Node::default(); + let new_nearest = node.nearest_mut(); + // Just make sure the candidates are all unique debug_assert_eq!( found.len(), - found.iter().map(|c| c.nid).collect::>().len() + found.iter().map(|c| c.pid).collect::>().len() ); // Only use the `Self::LINKS` nearest candidates found for (i, candidate) in found.iter().take(Self::LINKS).enumerate() { // `candidate` here is the new node's neighbor - let &Candidate { distance, nid } = candidate; - new.1[i] = Some(nid); // Update the new node's `nearest` + let &Candidate { distance, pid } = candidate; + new_nearest[i] = Some(pid); // Update the new node's `nearest` - let pid = self.pid(nid); - let old = &points[pid.0]; - let nearest = self.nodes()[nid.0].nearest(); + let old = &points[pid]; + let nearest = self.nodes()[pid.0].nearest(); // Find the correct index to insert at to keep the neighbor's neighbors sorted let idx = nearest @@ -403,8 +368,7 @@ trait Layer { None => return Ordering::Greater, }; - let pid = self.pid(third); - let third_distance = OrderedFloat::from(old.distance(&points[pid.0])); + let third_distance = OrderedFloat::from(old.distance(&points[third.0])); distance.cmp(&third_distance) }) .unwrap_or_else(|e| e); @@ -415,39 +379,41 @@ trait Layer { continue; } - let nearest = self.nodes_mut()[nid.0].nearest_mut(); + let nearest = self.nodes_mut()[pid.0].nearest_mut(); if nearest[idx].is_none() { - nearest[idx] = Some(new.0); + nearest[idx] = Some(new); continue; } let end = Self::LINKS - 1; nearest.copy_within(idx..end, idx + 1); - nearest[idx] = Some(new.0); + nearest[idx] = Some(new); } + + self.push(node); } + + fn push(&mut self, new: Self::Node); + + fn nodes_mut(&mut self) -> &mut [Self::Node]; + + fn nodes(&self) -> &[Self::Node]; } -#[derive(Debug)] +#[derive(Clone, Copy, Debug, Default)] struct UpperNode { - /// This node's point - pid: PointId, - /// The point's node on the next level down - /// - /// This is only used when lowering the search. - lower: NodeId, /// The nearest neighbors on this layer /// /// This is always kept in sorted order (near to far). - nearest: [Option; M], + nearest: [Option; M], } impl Node for UpperNode { - fn nearest(&self) -> &[Option] { + fn nearest(&self) -> &[Option] { &self.nearest } - fn nearest_mut(&mut self) -> &mut [Option] { + fn nearest_mut(&mut self) -> &mut [Option] { &mut self.nearest } @@ -458,20 +424,20 @@ impl Node for UpperNode { } } -#[derive(Debug)] +#[derive(Clone, Copy, Debug, Default)] struct ZeroNode { /// The nearest neighbors on this layer /// /// This is always kept in sorted order (near to far). - nearest: [Option; M * 2], + nearest: [Option; M * 2], } impl Node for ZeroNode { - fn nearest(&self) -> &[Option] { + fn nearest(&self) -> &[Option] { &self.nearest } - fn nearest_mut(&mut self) -> &mut [Option] { + fn nearest_mut(&mut self) -> &mut [Option] { &mut self.nearest } @@ -482,18 +448,18 @@ impl Node for ZeroNode { } } -trait Node { - fn nearest(&self) -> &[Option]; - fn nearest_mut(&mut self) -> &mut [Option]; +trait Node: Default { + fn nearest(&self) -> &[Option]; + fn nearest_mut(&mut self) -> &mut [Option]; fn nearest_iter(&self) -> NearestIter<'_>; } struct NearestIter<'a> { - nearest: &'a [Option], + nearest: &'a [Option], } impl<'a> Iterator for NearestIter<'a> { - type Item = NodeId; + type Item = PointId; fn next(&mut self) -> Option { let (&first, rest) = self.nearest.split_first()?; @@ -509,11 +475,9 @@ impl<'a> Iterator for NearestIter<'a> { struct LayerId(usize); impl LayerId { - /// Return a `LayerId` for the layer one lower - /// - /// Panics when called for `LayerId(0)`. - fn lower(&self) -> LayerId { - LayerId(self.0 - 1) + fn random(rng: &mut SmallRng) -> Self { + let layer = rng.next_u32() as f32 / u32::MAX as f32; + LayerId((-(layer.ln() * (M as f32).ln())).floor() as usize) } fn descend(&self) -> DescendingLayerIter { @@ -546,14 +510,14 @@ impl Iterator for DescendingLayerIter { } } -pub trait Point { +pub trait Point: Clone { fn distance(&self, other: &Self) -> f32; } #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] struct Candidate { distance: OrderedFloat, - nid: NodeId, + pid: PointId, } /// References a node in a particular layer (usually the same layer) @@ -566,6 +530,12 @@ struct NodeId(usize); #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct PointId(usize); +impl PointId { + pub fn invalid() -> Self { + PointId(usize::MAX) + } +} + /// The parameter `M` from the paper /// /// This should become a generic argument to `Hnsw` when possible. @@ -576,15 +546,27 @@ mod tests { use super::*; #[test] - fn test_insertion() { + fn basic() { + let (hnsw, pids) = Hnsw::new( + &[ + Point(0.1, 0.4), + Point(-0.324, 0.543), + Point(0.87, -0.33), + Point(0.452, 0.932), + ], + 100, + 100, + ); + let mut search = Search::default(); - let mut hnsw = Hnsw::new(100); - hnsw.insert(Point(0.1, 0.4), &mut search); - hnsw.insert(Point(-0.324, 0.543), &mut search); - hnsw.insert(Point(0.87, -0.33), &mut search); - hnsw.insert(Point(0.452, 0.932), &mut search); + let mut results = vec![PointId::invalid()]; + let p = Point(0.1, 0.35); + let found = hnsw.search(&p, &mut results, &mut search); + assert_eq!(found, 1); + assert_eq!(&results, &[pids[0]]); } + #[derive(Clone, Copy, Debug)] struct Point(f32, f32); impl super::Point for Point {