Implement parallel index construction
This commit is contained in:
parent
e1f6b666c9
commit
9d959fce06
|
@ -8,6 +8,7 @@ edition = "2018"
|
|||
[dependencies]
|
||||
ahash = "0.6.1"
|
||||
indicatif = { version = "0.15", optional = true }
|
||||
num_cpus = "1.13"
|
||||
ordered-float = "2.0"
|
||||
rand = { version = "0.7.3", features = ["small_rng"] }
|
||||
rayon = "1.5"
|
||||
|
|
91
src/lib.rs
91
src/lib.rs
|
@ -8,6 +8,7 @@ use indicatif::ProgressBar;
|
|||
use ordered_float::OrderedFloat;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
@ -32,6 +33,11 @@ where
|
|||
let ef_construction = builder.ef_construction.unwrap_or(100);
|
||||
#[cfg(feature = "indicatif")]
|
||||
let progress = builder.progress;
|
||||
#[cfg(feature = "indicatif")]
|
||||
if let Some(bar) = &progress {
|
||||
bar.set_draw_delta(1_000);
|
||||
bar.set_length(points.len() as u64);
|
||||
}
|
||||
|
||||
if points.is_empty() {
|
||||
return (
|
||||
|
@ -99,31 +105,51 @@ where
|
|||
let mut zero = Vec::with_capacity(points.len());
|
||||
zero.push(ZeroNode::default());
|
||||
|
||||
let mut search = Search::default();
|
||||
for (layer, range) in ranges {
|
||||
let mut pool = SearchPool::default();
|
||||
let mut batch = Vec::new();
|
||||
let mut done = Vec::new();
|
||||
let max_batch_len = num_cpus::get();
|
||||
for (layer, mut range) in ranges {
|
||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||
for &(_, pid) in &nodes[range] {
|
||||
|
||||
while range.start < range.end {
|
||||
let len = min(range.len(), max_batch_len);
|
||||
batch.clear();
|
||||
batch.extend(
|
||||
nodes[range.start..(range.start + len)]
|
||||
.iter()
|
||||
.map(|&(_, pid)| (pid, pool.pop())),
|
||||
);
|
||||
|
||||
batch.par_iter_mut().for_each(|(pid, search)| {
|
||||
let point = &points[*pid];
|
||||
search.push(PointId(0), point, &points);
|
||||
for cur in top.descend() {
|
||||
search.num = if cur <= layer { ef_construction } else { 1 };
|
||||
zero.search(point, search, &points, num);
|
||||
match cur > layer {
|
||||
true => search.cull(),
|
||||
false => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
done.clear();
|
||||
for (pid, mut search) in batch.drain(..) {
|
||||
for added in done.iter().copied() {
|
||||
search.push(added, &points[pid], &points);
|
||||
}
|
||||
zero.insert_node(pid, &search.nearest, &points);
|
||||
done.push(pid);
|
||||
pool.push(search);
|
||||
}
|
||||
|
||||
#[cfg(feature = "indicatif")]
|
||||
if pid.0 % 10_000 == 0 {
|
||||
if let Some(bar) = &progress {
|
||||
bar.set_position(pid.0 as u64);
|
||||
}
|
||||
if let Some(bar) = &progress {
|
||||
bar.inc(done.len() as u64);
|
||||
}
|
||||
|
||||
search.reset();
|
||||
let point = &points[pid];
|
||||
search.push(PointId(0), &points[pid], &points);
|
||||
|
||||
for cur in top.descend() {
|
||||
search.num = if cur <= layer { ef_construction } else { 1 };
|
||||
zero.search(point, &mut search, &points, num);
|
||||
match cur > layer {
|
||||
true => search.cull(),
|
||||
false => break,
|
||||
}
|
||||
}
|
||||
|
||||
zero.insert_node(pid, &search.nearest, &points);
|
||||
range.start += len;
|
||||
}
|
||||
|
||||
// For layers above the zero layer, make a copy of the current state of the zero layer
|
||||
|
@ -200,6 +226,27 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SearchPool {
|
||||
pool: Vec<Search>,
|
||||
}
|
||||
|
||||
impl SearchPool {
|
||||
fn pop(&mut self) -> Search {
|
||||
match self.pool.pop() {
|
||||
Some(mut search) => {
|
||||
search.reset();
|
||||
search
|
||||
}
|
||||
None => Search::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, search: Search) {
|
||||
self.pool.push(search);
|
||||
}
|
||||
}
|
||||
|
||||
/// Keeps mutable state for searching a point's nearest neighbors
|
||||
///
|
||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||
|
@ -555,7 +602,7 @@ impl Iterator for DescendingLayerIter {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait Point: Clone {
|
||||
pub trait Point: Clone + Sync {
|
||||
fn distance(&self, other: &Self) -> f32;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue