From f66bad132ee1234332c838e21806d9bd2b3f69bc Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 20 Jan 2021 15:04:23 +0100 Subject: [PATCH] Use specialized set implementation for faster tracking of visited points --- Cargo.toml | 1 - benches/all.rs | 2 +- src/lib.rs | 27 ++++++++++++++++++++------- src/types.rs | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c95e82e..441d9bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ authors = ["Dirkjan Ochtman "] edition = "2018" [dependencies] -ahash = "0.6.1" indicatif = { version = "0.15", optional = true } num_cpus = "1.13" ordered-float = "2.0" diff --git a/benches/all.rs b/benches/all.rs index df88014..717a433 100644 --- a/benches/all.rs +++ b/benches/all.rs @@ -10,7 +10,7 @@ benchmark_group!(benches, build_heuristic); fn build_heuristic(bench: &mut Bencher) { let seed = ThreadRng::default().gen::<u64>(); let mut rng = StdRng::seed_from_u64(seed); - let points = (0..1024) + let points = (0..16384) .into_iter() .map(|_| Point(rng.gen(), rng.gen())) .collect::<Vec<_>>(); diff --git a/src/lib.rs b/src/lib.rs index 2397095..d3ab441 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min, Ordering, Reverse}; use std::collections::BinaryHeap; +use std::collections::HashSet; -use ahash::AHashSet as HashSet; #[cfg(feature = "indicatif")] use indicatif::ProgressBar; use ordered_float::OrderedFloat; @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; mod types; pub use types::PointId; -use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode}; +use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode}; /// Parameters for building the `Hnsw` pub struct Builder { @@ -225,10 +225,15 @@ where let mut insertion = Search { ef: ef_construction, + visited: Visited::with_capacity(points.len()), ..Default::default() }; - let mut pool = SearchPool::default(); + let mut pool = SearchPool 
{ + pool: Vec::new(), + len: points.len(), + }; + let mut batch = Vec::new(); let mut done = Vec::new(); let max_batch_len = num_cpus::get() * 4; @@ -331,6 +336,7 @@ where return 0; } + search.visited.reserve_capacity(self.points.len()); search.reset(); search.push(PointId(0), point, &self.points); for cur in LayerId(self.layers.len()).descend() { @@ -456,9 +462,9 @@ fn insert( } } -#[derive(Default)] struct SearchPool { pool: Vec<Search>, + len: usize, } impl SearchPool { @@ -468,7 +474,7 @@ impl SearchPool { search.reset(); search } - None => Search::default(), + None => Search::new(self.len), } } @@ -515,7 +521,7 @@ trait Layer { /// initialized by using `push()` to add the initial enter points. pub struct Search { /// Nodes visited so far (`v` in the paper) - visited: HashSet<PointId>, + visited: Visited, /// Candidates for further inspection (`C` in the paper) candidates: BinaryHeap<Reverse<Candidate>>, /// Nearest neighbors found so far (`W` in the paper) @@ -528,6 +534,13 @@ pub struct Search { } impl Search { + fn new(capacity: usize) -> Self { + Self { + visited: Visited::with_capacity(capacity), + ..Default::default() + } + } + /// Resets the state to be ready for a new search fn reset(&mut self) { let Search { @@ -663,7 +676,7 @@ impl Search { impl Default for Search { fn default() -> Self { Self { - visited: HashSet::new(), + visited: Visited::with_capacity(0), candidates: BinaryHeap::new(), nearest: Vec::new(), working: Vec::new(), diff --git a/src/types.rs b/src/types.rs index ab42581..aeead7d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -9,6 +9,54 @@ use serde::{Deserialize, Serialize}; use crate::{Hnsw, Layer, Point, M}; +pub(crate) struct Visited { + store: Vec<u8>, + generation: u8, +} + +impl Visited { + pub(crate) fn with_capacity(capacity: usize) -> Self { + Self { + store: vec![0; capacity], + generation: 1, + } + } + + pub(crate) fn reserve_capacity(&mut self, capacity: usize) { + if self.store.len() != capacity { + self.store.resize(capacity, self.generation - 1); + } + } + + 
pub(crate) fn insert(&mut self, pid: PointId) -> bool { + let slot = &mut self.store[pid.0 as usize]; + if *slot != self.generation { + *slot = self.generation; + true + } else { + false + } + } + + pub(crate) fn extend(&mut self, iter: impl Iterator<Item = PointId>) { + for pid in iter { + self.insert(pid); + } + } + + pub(crate) fn clear(&mut self) { + if self.generation < 249 { + self.generation += 1; + return; + } + + let len = self.store.len(); + self.store.clear(); + self.store.resize(len, 0); + self.generation = 1; + } +} + #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[derive(Clone, Copy, Debug, Default)] pub(crate) struct UpperNode {