This commit is contained in:
Dirkjan Ochtman 2021-05-18 10:21:52 +02:00
parent a2b1b7b726
commit 85d462da96
1 changed files with 166 additions and 134 deletions

View File

@ -1,6 +1,7 @@
use std::cmp::{max, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::collections::HashSet;
use std::ops::Range;
#[cfg(feature = "indicatif")]
use std::sync::atomic::{self, AtomicUsize};
@ -280,75 +281,19 @@ where
.map(|_| RwLock::new(ZeroNode::default()))
.collect::<Vec<_>>();
let pool = SearchPool::new(points.len());
#[cfg(feature = "indicatif")]
let done = AtomicUsize::new(0);
for (layer, range) in ranges {
let num = if layer.is_zero() { M * 2 } else { M };
Construction {
zero: zero.as_slice(),
pool: SearchPool::new(points.len()),
top,
points: &points,
heuristic,
ef_construction,
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
bar.set_message(format!("Building index (layer {})", layer.0));
}
let end = range.end;
nodes[range].into_iter().for_each(|(_, pid)| {
let node = zero.as_slice()[*pid].write();
let (mut search, mut insertion) = pool.pop();
let point = &points.as_slice()[*pid];
search.reset();
search.push(PointId(0), point, &points);
for cur in top.descend() {
search.ef = if cur <= layer { ef_construction } else { 1 };
match cur > layer {
true => {
search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
search.cull();
}
false => {
search.search(point, zero.as_slice(), &points, num);
break;
}
}
}
insertion.ef = ef_construction;
insert(
*pid,
node,
&mut insertion,
&mut search,
&zero,
&points,
&heuristic,
);
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
let value = done.fetch_add(1, atomic::Ordering::Relaxed);
if value % 1000 == 0 {
bar.set_position(value as u64);
}
}
pool.push((search, insertion));
});
// For layers above the zero layer, make a copy of the current state of the zero layer
// with `nearest` truncated to `M` elements.
if !layer.is_zero() {
let upper = (&zero[..end])
.into_iter()
.map(|zero| UpperNode::from_zero(&zero.read()))
.collect();
layers[layer.0 - 1] = upper;
}
}
#[cfg(feature = "indicatif")]
if let Some(bar) = progress {
bar.finish();
progress,
#[cfg(feature = "indicatif")]
done: AtomicUsize::new(0),
}
.build(nodes, ranges.into_iter(), &mut layers);
(
Self {
@ -407,77 +352,164 @@ where
}
}
/// Insert new node in the zero layer
///
/// * `new`: the `PointId` for the new node
/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
/// * `search`: the result for searching potential neighbors for the new node
/// * `layer` contains all the nodes at the current layer
/// * `points` is a slice of all the points in the index
///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer.
fn insert<P: Point>(
new: PointId,
mut node: parking_lot::RwLockWriteGuard<ZeroNode>,
insertion: &mut Search,
search: &mut Search,
layer: &[RwLock<ZeroNode>],
points: &[P],
heuristic: &Option<Heuristic>,
) {
let found = match heuristic {
None => {
let candidates = search.select_simple();
&candidates[..Ord::min(candidates.len(), M * 2)]
struct Construction<'a, P: Point> {
zero: &'a [RwLock<ZeroNode>],
pool: SearchPool,
top: LayerId,
points: &'a [P],
heuristic: Option<Heuristic>,
ef_construction: usize,
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
#[cfg(feature = "indicatif")]
done: AtomicUsize,
}
impl<'a, P: Point> Construction<'a, P> {
fn build(
&self,
nodes: Vec<(LayerId, PointId)>,
ranges: impl Iterator<Item = (LayerId, Range<usize>)>,
layers: &mut [Vec<UpperNode>],
) {
for (layer, range) in ranges {
#[cfg(feature = "indicatif")]
if let Some(bar) = &self.progress {
bar.set_message(format!("Building index (layer {})", layer.0));
}
let end = range.end;
nodes[range].into_iter().for_each(|(_, pid)| {
let node = self.zero[*pid].write();
self.insert(*pid, node, layer, &layers);
});
// For layers above the zero layer, make a copy of the current state of the zero layer
// with `nearest` truncated to `M` elements.
if !layer.is_zero() {
let upper = (&self.zero[..end])
.into_iter()
.map(|zero| UpperNode::from_zero(&zero.read()))
.collect();
layers[layer.0 - 1] = upper;
}
}
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
};
// Just make sure the candidates are all unique
debug_assert_eq!(
found.len(),
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
);
for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
if let Some(heuristic) = heuristic {
let found = insertion.add_neighbor_heuristic(
new,
layer.nearest_iter(pid),
layer,
&points[pid],
points,
*heuristic,
);
layer[pid]
.write()
.rewrite(found.iter().map(|candidate| candidate.pid));
node.set(i, pid);
} else {
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let old = &points[pid];
let idx = layer[pid]
.read()
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&points[third]).into())
})
.unwrap_or_else(|e| e);
layer[pid].write().insert(idx, new);
node.set(i, pid);
#[cfg(feature = "indicatif")]
if let Some(bar) = &self.progress {
bar.finish();
}
}
/// Insert new node in the zero layer
///
/// * `new`: the `PointId` for the new node
/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection)
/// * `search`: the result for searching potential neighbors for the new node
/// * `layer` contains all the nodes at the current layer
/// * `points` is a slice of all the points in the index
///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer.
#[inline(always)]
fn insert(
&self,
new: PointId,
mut node: parking_lot::RwLockWriteGuard<ZeroNode>,
layer: LayerId,
layers: &[Vec<UpperNode>],
) {
let (mut search, mut insertion) = self.pool.pop();
insertion.ef = self.ef_construction;
let point = &self.points[new];
search.reset();
search.push(PointId(0), point, &self.points);
let num = if layer.is_zero() { M * 2 } else { M };
for cur in self.top.descend() {
search.ef = if cur <= layer {
self.ef_construction
} else {
1
};
match cur > layer {
true => {
search.search(point, layers[cur.0 - 1].as_slice(), &self.points, num);
search.cull();
}
false => {
search.search(point, self.zero, &self.points, num);
break;
}
}
}
let found = match self.heuristic {
None => {
let candidates = search.select_simple();
&candidates[..Ord::min(candidates.len(), M * 2)]
}
Some(heuristic) => {
search.select_heuristic(&self.points[new], self.zero, self.points, heuristic)
}
};
// Just make sure the candidates are all unique
debug_assert_eq!(
found.len(),
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
);
for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
if let Some(heuristic) = self.heuristic {
let found = insertion.add_neighbor_heuristic(
new,
self.zero.nearest_iter(pid),
self.zero,
&self.points[pid],
self.points,
heuristic,
);
self.zero[pid]
.write()
.rewrite(found.iter().map(|candidate| candidate.pid));
node.set(i, pid);
} else {
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let old = &self.points[pid];
let idx = self.zero[pid]
.read()
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&self.points[third]).into())
})
.unwrap_or_else(|e| e);
self.zero[pid].write().insert(idx, new);
node.set(i, pid);
}
}
#[cfg(feature = "indicatif")]
if let Some(bar) = &self.progress {
let value = self.done.fetch_add(1, atomic::Ordering::Relaxed);
if value % 1000 == 0 {
bar.set_position(value as u64);
}
}
self.pool.push((search, insertion));
}
}
struct SearchPool {