Implement parallel index construction

This commit is contained in:
Dirkjan Ochtman 2020-12-16 13:37:39 +01:00
parent e1f6b666c9
commit 9d959fce06
2 changed files with 70 additions and 22 deletions

View File

@ -8,6 +8,7 @@ edition = "2018"
[dependencies] [dependencies]
ahash = "0.6.1" ahash = "0.6.1"
indicatif = { version = "0.15", optional = true } indicatif = { version = "0.15", optional = true }
num_cpus = "1.13"
ordered-float = "2.0" ordered-float = "2.0"
rand = { version = "0.7.3", features = ["small_rng"] } rand = { version = "0.7.3", features = ["small_rng"] }
rayon = "1.5" rayon = "1.5"

View File

@ -8,6 +8,7 @@ use indicatif::ProgressBar;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use rand::rngs::SmallRng; use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng}; use rand::{RngCore, SeedableRng};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -32,6 +33,11 @@ where
let ef_construction = builder.ef_construction.unwrap_or(100); let ef_construction = builder.ef_construction.unwrap_or(100);
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
let progress = builder.progress; let progress = builder.progress;
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
bar.set_draw_delta(1_000);
bar.set_length(points.len() as u64);
}
if points.is_empty() { if points.is_empty() {
return ( return (
@ -99,31 +105,51 @@ where
let mut zero = Vec::with_capacity(points.len()); let mut zero = Vec::with_capacity(points.len());
zero.push(ZeroNode::default()); zero.push(ZeroNode::default());
let mut search = Search::default(); let mut pool = SearchPool::default();
for (layer, range) in ranges { let mut batch = Vec::new();
let mut done = Vec::new();
let max_batch_len = num_cpus::get();
for (layer, mut range) in ranges {
let num = if layer.0 > 0 { M } else { M * 2 }; let num = if layer.0 > 0 { M } else { M * 2 };
for &(_, pid) in &nodes[range] {
while range.start < range.end {
let len = min(range.len(), max_batch_len);
batch.clear();
batch.extend(
nodes[range.start..(range.start + len)]
.iter()
.map(|&(_, pid)| (pid, pool.pop())),
);
batch.par_iter_mut().for_each(|(pid, search)| {
let point = &points[*pid];
search.push(PointId(0), point, &points);
for cur in top.descend() {
search.num = if cur <= layer { ef_construction } else { 1 };
zero.search(point, search, &points, num);
match cur > layer {
true => search.cull(),
false => break,
}
}
});
done.clear();
for (pid, mut search) in batch.drain(..) {
for added in done.iter().copied() {
search.push(added, &points[pid], &points);
}
zero.insert_node(pid, &search.nearest, &points);
done.push(pid);
pool.push(search);
}
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
if pid.0 % 10_000 == 0 { if let Some(bar) = &progress {
if let Some(bar) = &progress { bar.inc(done.len() as u64);
bar.set_position(pid.0 as u64);
}
} }
search.reset(); range.start += len;
let point = &points[pid];
search.push(PointId(0), &points[pid], &points);
for cur in top.descend() {
search.num = if cur <= layer { ef_construction } else { 1 };
zero.search(point, &mut search, &points, num);
match cur > layer {
true => search.cull(),
false => break,
}
}
zero.insert_node(pid, &search.nearest, &points);
} }
// For layers above the zero layer, make a copy of the current state of the zero layer // For layers above the zero layer, make a copy of the current state of the zero layer
@ -200,6 +226,27 @@ where
} }
} }
#[derive(Default)]
struct SearchPool {
pool: Vec<Search>,
}
impl SearchPool {
fn pop(&mut self) -> Search {
match self.pool.pop() {
Some(mut search) => {
search.reset();
search
}
None => Search::default(),
}
}
fn push(&mut self, search: Search) {
self.pool.push(search);
}
}
/// Keeps mutable state for searching a point's nearest neighbors /// Keeps mutable state for searching a point's nearest neighbors
/// ///
/// In particular, this contains most of the state used in algorithm 2. The structure is /// In particular, this contains most of the state used in algorithm 2. The structure is
@ -555,7 +602,7 @@ impl Iterator for DescendingLayerIter {
} }
} }
pub trait Point: Clone { pub trait Point: Clone + Sync {
fn distance(&self, other: &Self) -> f32; fn distance(&self, other: &Self) -> f32;
} }