Use specialized set implementation for faster tracking of visited points

This commit is contained in:
Dirkjan Ochtman 2021-01-20 15:04:23 +01:00
parent 7b84aa8d45
commit f66bad132e
4 changed files with 69 additions and 9 deletions

View File

@ -6,7 +6,6 @@ authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
ahash = "0.6.1"
indicatif = { version = "0.15", optional = true } indicatif = { version = "0.15", optional = true }
num_cpus = "1.13" num_cpus = "1.13"
ordered-float = "2.0" ordered-float = "2.0"

View File

@ -10,7 +10,7 @@ benchmark_group!(benches, build_heuristic);
fn build_heuristic(bench: &mut Bencher) { fn build_heuristic(bench: &mut Bencher) {
let seed = ThreadRng::default().gen::<u64>(); let seed = ThreadRng::default().gen::<u64>();
let mut rng = StdRng::seed_from_u64(seed); let mut rng = StdRng::seed_from_u64(seed);
let points = (0..1024) let points = (0..16384)
.into_iter() .into_iter()
.map(|_| Point(rng.gen(), rng.gen())) .map(|_| Point(rng.gen(), rng.gen()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@ -1,7 +1,7 @@
use std::cmp::{max, min, Ordering, Reverse}; use std::cmp::{max, min, Ordering, Reverse};
use std::collections::BinaryHeap; use std::collections::BinaryHeap;
use std::collections::HashSet;
use ahash::AHashSet as HashSet;
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
use indicatif::ProgressBar; use indicatif::ProgressBar;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
mod types; mod types;
pub use types::PointId; pub use types::PointId;
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode}; use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode};
/// Parameters for building the `Hnsw` /// Parameters for building the `Hnsw`
pub struct Builder { pub struct Builder {
@ -225,10 +225,15 @@ where
let mut insertion = Search { let mut insertion = Search {
ef: ef_construction, ef: ef_construction,
visited: Visited::with_capacity(points.len()),
..Default::default() ..Default::default()
}; };
let mut pool = SearchPool::default(); let mut pool = SearchPool {
pool: Vec::new(),
len: points.len(),
};
let mut batch = Vec::new(); let mut batch = Vec::new();
let mut done = Vec::new(); let mut done = Vec::new();
let max_batch_len = num_cpus::get() * 4; let max_batch_len = num_cpus::get() * 4;
@ -331,6 +336,7 @@ where
return 0; return 0;
} }
search.visited.reserve_capacity(self.points.len());
search.reset(); search.reset();
search.push(PointId(0), point, &self.points); search.push(PointId(0), point, &self.points);
for cur in LayerId(self.layers.len()).descend() { for cur in LayerId(self.layers.len()).descend() {
@ -456,9 +462,9 @@ fn insert<P: Point>(
} }
} }
#[derive(Default)]
struct SearchPool { struct SearchPool {
pool: Vec<Search>, pool: Vec<Search>,
len: usize,
} }
impl SearchPool { impl SearchPool {
@ -468,7 +474,7 @@ impl SearchPool {
search.reset(); search.reset();
search search
} }
None => Search::default(), None => Search::new(self.len),
} }
} }
@ -515,7 +521,7 @@ trait Layer {
/// initialized by using `push()` to add the initial enter points. /// initialized by using `push()` to add the initial enter points.
pub struct Search { pub struct Search {
/// Nodes visited so far (`v` in the paper) /// Nodes visited so far (`v` in the paper)
visited: HashSet<PointId>, visited: Visited,
/// Candidates for further inspection (`C` in the paper) /// Candidates for further inspection (`C` in the paper)
candidates: BinaryHeap<Reverse<Candidate>>, candidates: BinaryHeap<Reverse<Candidate>>,
/// Nearest neighbors found so far (`W` in the paper) /// Nearest neighbors found so far (`W` in the paper)
@ -528,6 +534,13 @@ pub struct Search {
} }
impl Search { impl Search {
fn new(capacity: usize) -> Self {
Self {
visited: Visited::with_capacity(capacity),
..Default::default()
}
}
/// Resets the state to be ready for a new search /// Resets the state to be ready for a new search
fn reset(&mut self) { fn reset(&mut self) {
let Search { let Search {
@ -663,7 +676,7 @@ impl Search {
impl Default for Search { impl Default for Search {
fn default() -> Self { fn default() -> Self {
Self { Self {
visited: HashSet::new(), visited: Visited::with_capacity(0),
candidates: BinaryHeap::new(), candidates: BinaryHeap::new(),
nearest: Vec::new(), nearest: Vec::new(),
working: Vec::new(), working: Vec::new(),

View File

@ -9,6 +9,54 @@ use serde::{Deserialize, Serialize};
use crate::{Hnsw, Layer, Point, M}; use crate::{Hnsw, Layer, Point, M};
pub(crate) struct Visited {
store: Vec<u8>,
generation: u8,
}
impl Visited {
pub(crate) fn with_capacity(capacity: usize) -> Self {
Self {
store: vec![0; capacity],
generation: 1,
}
}
pub(crate) fn reserve_capacity(&mut self, capacity: usize) {
if self.store.len() != capacity {
self.store.resize(capacity, self.generation - 1);
}
}
pub(crate) fn insert(&mut self, pid: PointId) -> bool {
let slot = &mut self.store[pid.0 as usize];
if *slot != self.generation {
*slot = self.generation;
true
} else {
false
}
}
pub(crate) fn extend(&mut self, iter: impl Iterator<Item = PointId>) {
for pid in iter {
self.insert(pid);
}
}
pub(crate) fn clear(&mut self) {
if self.generation < 249 {
self.generation += 1;
return;
}
let len = self.store.len();
self.store.clear();
self.store.resize(len, 0);
self.generation = 1;
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Default)] #[derive(Clone, Copy, Debug, Default)]
pub(crate) struct UpperNode { pub(crate) struct UpperNode {