Improve parallelization with fine-grained locks
This commit is contained in:
parent
f388fd0a46
commit
58bb0f315a
|
@ -9,6 +9,7 @@ edition = "2018"
|
|||
indicatif = { version = "0.15", optional = true }
|
||||
num_cpus = "1.13"
|
||||
ordered-float = "2.0"
|
||||
parking_lot = "0.11"
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
rayon = "1.5"
|
||||
serde = { version = "1.0.118", features = ["derive"], optional = true }
|
||||
|
|
165
src/lib.rs
165
src/lib.rs
|
@ -1,13 +1,16 @@
|
|||
use std::cmp::{max, min, Ordering, Reverse};
|
||||
use std::cmp::{max, Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::collections::HashSet;
|
||||
#[cfg(feature = "indicatif")]
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
|
||||
#[cfg(feature = "indicatif")]
|
||||
use indicatif::ProgressBar;
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::SeedableRng;
|
||||
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
@ -125,6 +128,7 @@ where
|
|||
let ef_search = builder.ef_search.unwrap_or(100);
|
||||
let ef_construction = builder.ef_construction.unwrap_or(100);
|
||||
let ml = builder.ml.unwrap_or_else(|| (M as f32).ln());
|
||||
let heuristic = builder.heuristic;
|
||||
let mut rng = match builder.seed {
|
||||
Some(seed) => SmallRng::seed_from_u64(seed),
|
||||
None => SmallRng::from_entropy(),
|
||||
|
@ -136,6 +140,7 @@ where
|
|||
if let Some(bar) = &progress {
|
||||
bar.set_draw_delta(1_000);
|
||||
bar.set_length(points.len() as u64);
|
||||
bar.set_message("Build index (preparation)");
|
||||
}
|
||||
|
||||
if points.is_empty() {
|
||||
|
@ -220,88 +225,67 @@ where
|
|||
// Insert the first point so that we have an enter point to start searches with.
|
||||
|
||||
let mut layers = vec![vec![]; top.0];
|
||||
let mut zero = Vec::with_capacity(points.len());
|
||||
zero.push(ZeroNode::default());
|
||||
let zero = points
|
||||
.iter()
|
||||
.map(|_| RwLock::new(ZeroNode::default()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut insertion = Search {
|
||||
ef: ef_construction,
|
||||
visited: Visited::with_capacity(points.len()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut pool = SearchPool {
|
||||
pool: Vec::new(),
|
||||
len: points.len(),
|
||||
};
|
||||
|
||||
let mut batch = Vec::new();
|
||||
let mut done = Vec::new();
|
||||
let max_batch_len = num_cpus::get() * 4;
|
||||
for (layer, mut range) in ranges {
|
||||
let pool = SearchPool::new(points.len());
|
||||
#[cfg(feature = "indicatif")]
|
||||
let done = AtomicUsize::new(0);
|
||||
for (layer, range) in ranges {
|
||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||
#[cfg(feature = "indicatif")]
|
||||
if let Some(bar) = &progress {
|
||||
bar.set_message(&format!("Building index (layer {})", layer.0));
|
||||
}
|
||||
|
||||
while range.start < range.end {
|
||||
let len = min(range.len(), max_batch_len);
|
||||
batch.clear();
|
||||
batch.extend(
|
||||
nodes[range.start..(range.start + len)]
|
||||
.iter()
|
||||
.map(|&(_, pid)| (pid, pool.pop())),
|
||||
);
|
||||
nodes[range].into_par_iter().for_each(|(_, pid)| {
|
||||
let (mut search, mut insertion) = pool.pop();
|
||||
let point = &points.as_slice()[*pid];
|
||||
search.reset();
|
||||
search.push(PointId(0), point, &points);
|
||||
|
||||
batch.par_iter_mut().for_each(|(pid, search)| {
|
||||
let point = &points[*pid];
|
||||
search.push(PointId(0), point, &points);
|
||||
for cur in top.descend() {
|
||||
search.ef = if cur <= layer { ef_construction } else { 1 };
|
||||
match cur > layer {
|
||||
true => {
|
||||
search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
|
||||
search.cull();
|
||||
}
|
||||
false => {
|
||||
search.search(point, zero.as_slice(), &points, num);
|
||||
break;
|
||||
}
|
||||
for cur in top.descend() {
|
||||
search.ef = if cur <= layer { ef_construction } else { 1 };
|
||||
match cur > layer {
|
||||
true => {
|
||||
search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
|
||||
search.cull();
|
||||
}
|
||||
false => {
|
||||
search.search(point, zero.as_slice(), &points, num);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
done.clear();
|
||||
for (pid, mut search) in batch.drain(..) {
|
||||
for added in done.iter().copied() {
|
||||
search.push(added, &points[pid], &points);
|
||||
}
|
||||
|
||||
insert(
|
||||
pid,
|
||||
&mut insertion,
|
||||
&mut search,
|
||||
&mut zero,
|
||||
&points,
|
||||
&builder.heuristic,
|
||||
);
|
||||
done.push(pid);
|
||||
pool.push(search);
|
||||
}
|
||||
|
||||
insertion.ef = ef_construction;
|
||||
insert(
|
||||
*pid,
|
||||
&mut insertion,
|
||||
&mut search,
|
||||
&zero,
|
||||
&points,
|
||||
&heuristic,
|
||||
);
|
||||
|
||||
#[cfg(feature = "indicatif")]
|
||||
if let Some(bar) = &progress {
|
||||
bar.inc(done.len() as u64);
|
||||
let value = done.fetch_add(1, atomic::Ordering::Relaxed);
|
||||
if value % 1000 == 0 {
|
||||
bar.set_position(value as u64);
|
||||
}
|
||||
}
|
||||
|
||||
range.start += len;
|
||||
}
|
||||
pool.push((search, insertion));
|
||||
});
|
||||
|
||||
// For layers above the zero layer, make a copy of the current state of the zero layer
|
||||
// with `nearest` truncated to `M` elements.
|
||||
if layer.0 > 0 {
|
||||
let mut upper = Vec::with_capacity(zero.len());
|
||||
upper.extend(zero.iter().map(UpperNode::from_zero));
|
||||
upper.extend(zero.iter().map(|zero| UpperNode::from_zero(&zero.read())));
|
||||
layers[layer.0 - 1] = upper;
|
||||
}
|
||||
}
|
||||
|
@ -314,7 +298,7 @@ where
|
|||
(
|
||||
Self {
|
||||
ef_search,
|
||||
zero,
|
||||
zero: zero.into_iter().map(|node| node.into_inner()).collect(),
|
||||
points,
|
||||
layers,
|
||||
},
|
||||
|
@ -382,14 +366,14 @@ fn insert<P: Point>(
|
|||
new: PointId,
|
||||
insertion: &mut Search,
|
||||
search: &mut Search,
|
||||
layer: &mut Vec<ZeroNode>,
|
||||
layer: &[RwLock<ZeroNode>],
|
||||
points: &[P],
|
||||
heuristic: &Option<Heuristic>,
|
||||
) {
|
||||
layer.push(ZeroNode::default());
|
||||
let mut node = layer[new].write();
|
||||
let found = match heuristic {
|
||||
None => search.select_simple(M * 2),
|
||||
Some(heuristic) => search.select_heuristic(&points[new], &layer, points, *heuristic),
|
||||
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
|
||||
};
|
||||
|
||||
// Just make sure the candidates are all unique
|
||||
|
@ -404,19 +388,22 @@ fn insert<P: Point>(
|
|||
if let Some(heuristic) = heuristic {
|
||||
let found = insertion.add_neighbor_heuristic(
|
||||
new,
|
||||
layer.as_slice().nearest_iter(pid),
|
||||
layer.nearest_iter(pid),
|
||||
layer,
|
||||
&points[pid],
|
||||
points,
|
||||
*heuristic,
|
||||
);
|
||||
|
||||
layer[pid].rewrite(found.iter().map(|candidate| candidate.pid));
|
||||
layer[new].set(i, pid);
|
||||
layer[pid]
|
||||
.write()
|
||||
.rewrite(found.iter().map(|candidate| candidate.pid));
|
||||
node.set(i, pid);
|
||||
} else {
|
||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||
let old = &points[pid];
|
||||
let idx = layer[pid]
|
||||
.read()
|
||||
.binary_search_by(|third| {
|
||||
// `third` here is one of the neighbors of the new node's neighbor.
|
||||
let third = match third {
|
||||
|
@ -429,30 +416,34 @@ fn insert<P: Point>(
|
|||
})
|
||||
.unwrap_or_else(|e| e);
|
||||
|
||||
layer[pid].insert(idx, new);
|
||||
layer[new].set(i, pid);
|
||||
layer[pid].write().insert(idx, new);
|
||||
node.set(i, pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchPool {
|
||||
pool: Vec<Search>,
|
||||
pool: Mutex<Vec<(Search, Search)>>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl SearchPool {
|
||||
fn pop(&mut self) -> Search {
|
||||
match self.pool.pop() {
|
||||
Some(mut search) => {
|
||||
search.reset();
|
||||
search
|
||||
}
|
||||
None => Search::new(self.len),
|
||||
fn new(len: usize) -> Self {
|
||||
Self {
|
||||
pool: Mutex::new(Vec::new()),
|
||||
len,
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, search: Search) {
|
||||
self.pool.push(search);
|
||||
fn pop(&self) -> (Search, Search) {
|
||||
match self.pool.lock().pop() {
|
||||
Some(res) => res,
|
||||
None => (Search::new(self.len), Search::new(self.len)),
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&self, item: (Search, Search)) {
|
||||
self.pool.lock().push(item);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -516,11 +507,11 @@ impl Search {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_neighbor_heuristic<P: Point>(
|
||||
fn add_neighbor_heuristic<L: Layer, P: Point>(
|
||||
&mut self,
|
||||
new: PointId,
|
||||
current: impl Iterator<Item = PointId>,
|
||||
layer: &[ZeroNode],
|
||||
layer: L,
|
||||
point: &P,
|
||||
points: &[P],
|
||||
params: Heuristic,
|
||||
|
@ -530,16 +521,16 @@ impl Search {
|
|||
for pid in current {
|
||||
self.push(pid, point, points);
|
||||
}
|
||||
self.select_heuristic(point, &layer, points, params)
|
||||
self.select_heuristic(point, layer, points, params)
|
||||
}
|
||||
|
||||
/// Heuristically sort and truncate neighbors in `self.nearest`
|
||||
///
|
||||
/// Invariant: `self.nearest` must be in sorted (nearest first) order.
|
||||
fn select_heuristic<P: Point>(
|
||||
fn select_heuristic<L: Layer, P: Point>(
|
||||
&mut self,
|
||||
point: &P,
|
||||
layer: &[ZeroNode],
|
||||
layer: L,
|
||||
points: &[P],
|
||||
params: Heuristic,
|
||||
) -> &[Candidate] {
|
||||
|
|
88
src/types.rs
88
src/types.rs
|
@ -1,7 +1,8 @@
|
|||
use std::hash::Hash;
|
||||
use std::ops::{Deref, Index, IndexMut};
|
||||
use std::ops::{Deref, Index};
|
||||
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
|
@ -69,11 +70,11 @@ impl UpperNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl Layer for &[UpperNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].0,
|
||||
}
|
||||
impl<'a> Layer for &'a [UpperNode] {
|
||||
type Slice = &'a [PointId];
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||
NearestIter::new(&self[pid.0 as usize].0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,32 +129,63 @@ impl Deref for ZeroNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl Layer for &[ZeroNode] {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
||||
NearestIter {
|
||||
nearest: &self[pid.0 as usize].0,
|
||||
}
|
||||
impl<'a> Layer for &'a [ZeroNode] {
|
||||
type Slice = &'a [PointId];
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||
NearestIter::new(&self[pid.0 as usize])
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Layer for &'a [RwLock<ZeroNode>] {
|
||||
type Slice = MappedRwLockReadGuard<'a, [PointId]>;
|
||||
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||
NearestIter::new(RwLockReadGuard::map(
|
||||
self[pid.0 as usize].read(),
|
||||
Deref::deref,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait Layer {
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
||||
type Slice: Deref<Target = [PointId]>;
|
||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice>;
|
||||
}
|
||||
|
||||
pub(crate) struct NearestIter<'a> {
|
||||
nearest: &'a [PointId],
|
||||
pub(crate) struct NearestIter<T> {
|
||||
node: T,
|
||||
cur: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for NearestIter<'a> {
|
||||
impl<T> NearestIter<T>
|
||||
where
|
||||
T: Deref<Target = [PointId]>,
|
||||
{
|
||||
fn new(node: T) -> Self {
|
||||
Self { node, cur: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Iterator for NearestIter<T>
|
||||
where
|
||||
T: Deref<Target = [PointId]>,
|
||||
{
|
||||
type Item = PointId;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (&first, rest) = self.nearest.split_first()?;
|
||||
if !first.is_valid() {
|
||||
if self.cur >= self.node.len() {
|
||||
return None;
|
||||
}
|
||||
self.nearest = rest;
|
||||
Some(first)
|
||||
|
||||
let item = self.node[self.cur];
|
||||
if !item.is_valid() {
|
||||
self.cur = self.node.len();
|
||||
return None;
|
||||
}
|
||||
|
||||
self.cur += 1;
|
||||
Some(item)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,14 +262,6 @@ impl<P> Index<PointId> for Hnsw<P> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<P: Point> Index<PointId> for Vec<P> {
|
||||
type Output = P;
|
||||
|
||||
fn index(&self, index: PointId) -> &Self::Output {
|
||||
&self[index.0 as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: Point> Index<PointId> for [P] {
|
||||
type Output = P;
|
||||
|
||||
|
@ -246,18 +270,12 @@ impl<P: Point> Index<PointId> for [P] {
|
|||
}
|
||||
}
|
||||
|
||||
impl Index<PointId> for Vec<ZeroNode> {
|
||||
type Output = ZeroNode;
|
||||
impl Index<PointId> for [RwLock<ZeroNode>] {
|
||||
type Output = RwLock<ZeroNode>;
|
||||
|
||||
fn index(&self, index: PointId) -> &Self::Output {
|
||||
&self[index.0 as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl IndexMut<PointId> for Vec<ZeroNode> {
|
||||
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
|
||||
&mut self[index.0 as usize]
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const INVALID: PointId = PointId(u32::MAX);
|
||||
|
|
Loading…
Reference in New Issue