Use specialized set implementation for faster tracking of visited points
This commit is contained in:
parent
7b84aa8d45
commit
f66bad132e
|
@ -6,7 +6,6 @@ authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
|||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
ahash = "0.6.1"
|
||||
indicatif = { version = "0.15", optional = true }
|
||||
num_cpus = "1.13"
|
||||
ordered-float = "2.0"
|
||||
|
|
|
@ -10,7 +10,7 @@ benchmark_group!(benches, build_heuristic);
|
|||
fn build_heuristic(bench: &mut Bencher) {
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let points = (0..1024)
|
||||
let points = (0..16384)
|
||||
.into_iter()
|
||||
.map(|_| Point(rng.gen(), rng.gen()))
|
||||
.collect::<Vec<_>>();
|
||||
|
|
27
src/lib.rs
27
src/lib.rs
|
@ -1,7 +1,7 @@
|
|||
use std::cmp::{max, min, Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use ahash::AHashSet as HashSet;
|
||||
#[cfg(feature = "indicatif")]
|
||||
use indicatif::ProgressBar;
|
||||
use ordered_float::OrderedFloat;
|
||||
|
@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
mod types;
|
||||
pub use types::PointId;
|
||||
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode};
|
||||
use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode};
|
||||
|
||||
/// Parameters for building the `Hnsw`
|
||||
pub struct Builder {
|
||||
|
@ -225,10 +225,15 @@ where
|
|||
|
||||
let mut insertion = Search {
|
||||
ef: ef_construction,
|
||||
visited: Visited::with_capacity(points.len()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut pool = SearchPool::default();
|
||||
let mut pool = SearchPool {
|
||||
pool: Vec::new(),
|
||||
len: points.len(),
|
||||
};
|
||||
|
||||
let mut batch = Vec::new();
|
||||
let mut done = Vec::new();
|
||||
let max_batch_len = num_cpus::get() * 4;
|
||||
|
@ -331,6 +336,7 @@ where
|
|||
return 0;
|
||||
}
|
||||
|
||||
search.visited.reserve_capacity(self.points.len());
|
||||
search.reset();
|
||||
search.push(PointId(0), point, &self.points);
|
||||
for cur in LayerId(self.layers.len()).descend() {
|
||||
|
@ -456,9 +462,9 @@ fn insert<P: Point>(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SearchPool {
|
||||
pool: Vec<Search>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl SearchPool {
|
||||
|
@ -468,7 +474,7 @@ impl SearchPool {
|
|||
search.reset();
|
||||
search
|
||||
}
|
||||
None => Search::default(),
|
||||
None => Search::new(self.len),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -515,7 +521,7 @@ trait Layer {
|
|||
/// initialized by using `push()` to add the initial enter points.
|
||||
pub struct Search {
|
||||
/// Nodes visited so far (`v` in the paper)
|
||||
visited: HashSet<PointId>,
|
||||
visited: Visited,
|
||||
/// Candidates for further inspection (`C` in the paper)
|
||||
candidates: BinaryHeap<Reverse<Candidate>>,
|
||||
/// Nearest neighbors found so far (`W` in the paper)
|
||||
|
@ -528,6 +534,13 @@ pub struct Search {
|
|||
}
|
||||
|
||||
impl Search {
|
||||
/// Builds a `Search` whose `visited` set is created with
/// `Visited::with_capacity(capacity)`; every other field takes the value
/// provided by `Default::default()`.
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
visited: Visited::with_capacity(capacity),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Resets the state to be ready for a new search
|
||||
fn reset(&mut self) {
|
||||
let Search {
|
||||
|
@ -663,7 +676,7 @@ impl Search {
|
|||
impl Default for Search {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
visited: HashSet::new(),
|
||||
visited: Visited::with_capacity(0),
|
||||
candidates: BinaryHeap::new(),
|
||||
nearest: Vec::new(),
|
||||
working: Vec::new(),
|
||||
|
|
48
src/types.rs
48
src/types.rs
|
@ -9,6 +9,54 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
use crate::{Hnsw, Layer, Point, M};
|
||||
|
||||
pub(crate) struct Visited {
|
||||
store: Vec<u8>,
|
||||
generation: u8,
|
||||
}
|
||||
|
||||
impl Visited {
|
||||
pub(crate) fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
store: vec![0; capacity],
|
||||
generation: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reserve_capacity(&mut self, capacity: usize) {
|
||||
if self.store.len() != capacity {
|
||||
self.store.resize(capacity, self.generation - 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&mut self, pid: PointId) -> bool {
|
||||
let slot = &mut self.store[pid.0 as usize];
|
||||
if *slot != self.generation {
|
||||
*slot = self.generation;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn extend(&mut self, iter: impl Iterator<Item = PointId>) {
|
||||
for pid in iter {
|
||||
self.insert(pid);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn clear(&mut self) {
|
||||
if self.generation < 249 {
|
||||
self.generation += 1;
|
||||
return;
|
||||
}
|
||||
|
||||
let len = self.store.len();
|
||||
self.store.clear();
|
||||
self.store.resize(len, 0);
|
||||
self.generation = 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) struct UpperNode {
|
||||
|
|
Loading…
Reference in New Issue