Use specialized set implementation for faster tracking of visited points
This commit is contained in:
parent
7b84aa8d45
commit
f66bad132e
|
@ -6,7 +6,6 @@ authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ahash = "0.6.1"
|
|
||||||
indicatif = { version = "0.15", optional = true }
|
indicatif = { version = "0.15", optional = true }
|
||||||
num_cpus = "1.13"
|
num_cpus = "1.13"
|
||||||
ordered-float = "2.0"
|
ordered-float = "2.0"
|
||||||
|
|
|
@ -10,7 +10,7 @@ benchmark_group!(benches, build_heuristic);
|
||||||
fn build_heuristic(bench: &mut Bencher) {
|
fn build_heuristic(bench: &mut Bencher) {
|
||||||
let seed = ThreadRng::default().gen::<u64>();
|
let seed = ThreadRng::default().gen::<u64>();
|
||||||
let mut rng = StdRng::seed_from_u64(seed);
|
let mut rng = StdRng::seed_from_u64(seed);
|
||||||
let points = (0..1024)
|
let points = (0..16384)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|_| Point(rng.gen(), rng.gen()))
|
.map(|_| Point(rng.gen(), rng.gen()))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
27
src/lib.rs
27
src/lib.rs
|
@ -1,7 +1,7 @@
|
||||||
use std::cmp::{max, min, Ordering, Reverse};
|
use std::cmp::{max, min, Ordering, Reverse};
|
||||||
use std::collections::BinaryHeap;
|
use std::collections::BinaryHeap;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use ahash::AHashSet as HashSet;
|
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
use indicatif::ProgressBar;
|
use indicatif::ProgressBar;
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
|
@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
mod types;
|
mod types;
|
||||||
pub use types::PointId;
|
pub use types::PointId;
|
||||||
use types::{Candidate, LayerId, NearestIter, UpperNode, ZeroNode};
|
use types::{Candidate, LayerId, NearestIter, UpperNode, Visited, ZeroNode};
|
||||||
|
|
||||||
/// Parameters for building the `Hnsw`
|
/// Parameters for building the `Hnsw`
|
||||||
pub struct Builder {
|
pub struct Builder {
|
||||||
|
@ -225,10 +225,15 @@ where
|
||||||
|
|
||||||
let mut insertion = Search {
|
let mut insertion = Search {
|
||||||
ef: ef_construction,
|
ef: ef_construction,
|
||||||
|
visited: Visited::with_capacity(points.len()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut pool = SearchPool::default();
|
let mut pool = SearchPool {
|
||||||
|
pool: Vec::new(),
|
||||||
|
len: points.len(),
|
||||||
|
};
|
||||||
|
|
||||||
let mut batch = Vec::new();
|
let mut batch = Vec::new();
|
||||||
let mut done = Vec::new();
|
let mut done = Vec::new();
|
||||||
let max_batch_len = num_cpus::get() * 4;
|
let max_batch_len = num_cpus::get() * 4;
|
||||||
|
@ -331,6 +336,7 @@ where
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
search.visited.reserve_capacity(self.points.len());
|
||||||
search.reset();
|
search.reset();
|
||||||
search.push(PointId(0), point, &self.points);
|
search.push(PointId(0), point, &self.points);
|
||||||
for cur in LayerId(self.layers.len()).descend() {
|
for cur in LayerId(self.layers.len()).descend() {
|
||||||
|
@ -456,9 +462,9 @@ fn insert<P: Point>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct SearchPool {
|
struct SearchPool {
|
||||||
pool: Vec<Search>,
|
pool: Vec<Search>,
|
||||||
|
len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SearchPool {
|
impl SearchPool {
|
||||||
|
@ -468,7 +474,7 @@ impl SearchPool {
|
||||||
search.reset();
|
search.reset();
|
||||||
search
|
search
|
||||||
}
|
}
|
||||||
None => Search::default(),
|
None => Search::new(self.len),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -515,7 +521,7 @@ trait Layer {
|
||||||
/// initialized by using `push()` to add the initial enter points.
|
/// initialized by using `push()` to add the initial enter points.
|
||||||
pub struct Search {
|
pub struct Search {
|
||||||
/// Nodes visited so far (`v` in the paper)
|
/// Nodes visited so far (`v` in the paper)
|
||||||
visited: HashSet<PointId>,
|
visited: Visited,
|
||||||
/// Candidates for further inspection (`C` in the paper)
|
/// Candidates for further inspection (`C` in the paper)
|
||||||
candidates: BinaryHeap<Reverse<Candidate>>,
|
candidates: BinaryHeap<Reverse<Candidate>>,
|
||||||
/// Nearest neighbors found so far (`W` in the paper)
|
/// Nearest neighbors found so far (`W` in the paper)
|
||||||
|
@ -528,6 +534,13 @@ pub struct Search {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Search {
|
impl Search {
|
||||||
|
fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
visited: Visited::with_capacity(capacity),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Resets the state to be ready for a new search
|
/// Resets the state to be ready for a new search
|
||||||
fn reset(&mut self) {
|
fn reset(&mut self) {
|
||||||
let Search {
|
let Search {
|
||||||
|
@ -663,7 +676,7 @@ impl Search {
|
||||||
impl Default for Search {
|
impl Default for Search {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
visited: HashSet::new(),
|
visited: Visited::with_capacity(0),
|
||||||
candidates: BinaryHeap::new(),
|
candidates: BinaryHeap::new(),
|
||||||
nearest: Vec::new(),
|
nearest: Vec::new(),
|
||||||
working: Vec::new(),
|
working: Vec::new(),
|
||||||
|
|
48
src/types.rs
48
src/types.rs
|
@ -9,6 +9,54 @@ use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{Hnsw, Layer, Point, M};
|
use crate::{Hnsw, Layer, Point, M};
|
||||||
|
|
||||||
|
pub(crate) struct Visited {
|
||||||
|
store: Vec<u8>,
|
||||||
|
generation: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Visited {
|
||||||
|
pub(crate) fn with_capacity(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
store: vec![0; capacity],
|
||||||
|
generation: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn reserve_capacity(&mut self, capacity: usize) {
|
||||||
|
if self.store.len() != capacity {
|
||||||
|
self.store.resize(capacity, self.generation - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn insert(&mut self, pid: PointId) -> bool {
|
||||||
|
let slot = &mut self.store[pid.0 as usize];
|
||||||
|
if *slot != self.generation {
|
||||||
|
*slot = self.generation;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn extend(&mut self, iter: impl Iterator<Item = PointId>) {
|
||||||
|
for pid in iter {
|
||||||
|
self.insert(pid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn clear(&mut self) {
|
||||||
|
if self.generation < 249 {
|
||||||
|
self.generation += 1;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = self.store.len();
|
||||||
|
self.store.clear();
|
||||||
|
self.store.resize(len, 0);
|
||||||
|
self.generation = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
#[derive(Clone, Copy, Debug, Default)]
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
pub(crate) struct UpperNode {
|
pub(crate) struct UpperNode {
|
||||||
|
|
Loading…
Reference in New Issue