Implement parallel index construction
This commit is contained in:
parent
e1f6b666c9
commit
9d959fce06
|
@ -8,6 +8,7 @@ edition = "2018"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ahash = "0.6.1"
|
ahash = "0.6.1"
|
||||||
indicatif = { version = "0.15", optional = true }
|
indicatif = { version = "0.15", optional = true }
|
||||||
|
num_cpus = "1.13"
|
||||||
ordered-float = "2.0"
|
ordered-float = "2.0"
|
||||||
rand = { version = "0.7.3", features = ["small_rng"] }
|
rand = { version = "0.7.3", features = ["small_rng"] }
|
||||||
rayon = "1.5"
|
rayon = "1.5"
|
||||||
|
|
91
src/lib.rs
91
src/lib.rs
|
@ -8,6 +8,7 @@ use indicatif::ProgressBar;
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::{RngCore, SeedableRng};
|
use rand::{RngCore, SeedableRng};
|
||||||
|
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -32,6 +33,11 @@ where
|
||||||
let ef_construction = builder.ef_construction.unwrap_or(100);
|
let ef_construction = builder.ef_construction.unwrap_or(100);
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
let progress = builder.progress;
|
let progress = builder.progress;
|
||||||
|
#[cfg(feature = "indicatif")]
|
||||||
|
if let Some(bar) = &progress {
|
||||||
|
bar.set_draw_delta(1_000);
|
||||||
|
bar.set_length(points.len() as u64);
|
||||||
|
}
|
||||||
|
|
||||||
if points.is_empty() {
|
if points.is_empty() {
|
||||||
return (
|
return (
|
||||||
|
@ -99,31 +105,51 @@ where
|
||||||
let mut zero = Vec::with_capacity(points.len());
|
let mut zero = Vec::with_capacity(points.len());
|
||||||
zero.push(ZeroNode::default());
|
zero.push(ZeroNode::default());
|
||||||
|
|
||||||
let mut search = Search::default();
|
let mut pool = SearchPool::default();
|
||||||
for (layer, range) in ranges {
|
let mut batch = Vec::new();
|
||||||
|
let mut done = Vec::new();
|
||||||
|
let max_batch_len = num_cpus::get();
|
||||||
|
for (layer, mut range) in ranges {
|
||||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||||
for &(_, pid) in &nodes[range] {
|
|
||||||
|
while range.start < range.end {
|
||||||
|
let len = min(range.len(), max_batch_len);
|
||||||
|
batch.clear();
|
||||||
|
batch.extend(
|
||||||
|
nodes[range.start..(range.start + len)]
|
||||||
|
.iter()
|
||||||
|
.map(|&(_, pid)| (pid, pool.pop())),
|
||||||
|
);
|
||||||
|
|
||||||
|
batch.par_iter_mut().for_each(|(pid, search)| {
|
||||||
|
let point = &points[*pid];
|
||||||
|
search.push(PointId(0), point, &points);
|
||||||
|
for cur in top.descend() {
|
||||||
|
search.num = if cur <= layer { ef_construction } else { 1 };
|
||||||
|
zero.search(point, search, &points, num);
|
||||||
|
match cur > layer {
|
||||||
|
true => search.cull(),
|
||||||
|
false => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
done.clear();
|
||||||
|
for (pid, mut search) in batch.drain(..) {
|
||||||
|
for added in done.iter().copied() {
|
||||||
|
search.push(added, &points[pid], &points);
|
||||||
|
}
|
||||||
|
zero.insert_node(pid, &search.nearest, &points);
|
||||||
|
done.push(pid);
|
||||||
|
pool.push(search);
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
if pid.0 % 10_000 == 0 {
|
if let Some(bar) = &progress {
|
||||||
if let Some(bar) = &progress {
|
bar.inc(done.len() as u64);
|
||||||
bar.set_position(pid.0 as u64);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
search.reset();
|
range.start += len;
|
||||||
let point = &points[pid];
|
|
||||||
search.push(PointId(0), &points[pid], &points);
|
|
||||||
|
|
||||||
for cur in top.descend() {
|
|
||||||
search.num = if cur <= layer { ef_construction } else { 1 };
|
|
||||||
zero.search(point, &mut search, &points, num);
|
|
||||||
match cur > layer {
|
|
||||||
true => search.cull(),
|
|
||||||
false => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
zero.insert_node(pid, &search.nearest, &points);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For layers above the zero layer, make a copy of the current state of the zero layer
|
// For layers above the zero layer, make a copy of the current state of the zero layer
|
||||||
|
@ -200,6 +226,27 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct SearchPool {
|
||||||
|
pool: Vec<Search>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SearchPool {
|
||||||
|
fn pop(&mut self) -> Search {
|
||||||
|
match self.pool.pop() {
|
||||||
|
Some(mut search) => {
|
||||||
|
search.reset();
|
||||||
|
search
|
||||||
|
}
|
||||||
|
None => Search::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push(&mut self, search: Search) {
|
||||||
|
self.pool.push(search);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Keeps mutable state for searching a point's nearest neighbors
|
/// Keeps mutable state for searching a point's nearest neighbors
|
||||||
///
|
///
|
||||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||||
|
@ -555,7 +602,7 @@ impl Iterator for DescendingLayerIter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Point: Clone {
|
pub trait Point: Clone + Sync {
|
||||||
fn distance(&self, other: &Self) -> f32;
|
fn distance(&self, other: &Self) -> f32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue