Use u32 for PointIds to save memory

This commit is contained in:
Dirkjan Ochtman 2020-12-14 16:37:15 +01:00
parent db925603d0
commit 0e0ffe201e
1 changed file with 42 additions and 32 deletions

View File

@ -38,6 +38,7 @@ where
// construction. This allows us to copy higher layers to lower layers as construction
// progresses, while preserving randomness in each point's layer and insertion order.
assert!(points.len() < u32::MAX as usize);
let mut rng = SmallRng::from_entropy();
let mut nodes = (0..points.len())
.map(|i| (LayerId::random(&mut rng), i))
@ -51,7 +52,7 @@ where
let mut new_nodes = Vec::with_capacity(points.len());
let mut out = vec![PointId::invalid(); points.len()];
for (i, &(layer, idx)) in nodes.iter().enumerate() {
let pid = PointId(i);
let pid = PointId(i as u32);
new_points.push(points[idx].clone());
new_nodes.push((layer, pid));
out[idx] = pid;
@ -136,8 +137,6 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value.
///
/// `PointId` values can be initialized with `PointId::invalid()`.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
if self.points.is_empty() {
return 0;
@ -249,7 +248,7 @@ impl<P> Index<PointId> for Hnsw<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self.points[index.0]
&self.points[index.0 as usize]
}
}
@ -257,7 +256,7 @@ impl<P: Point> Index<PointId> for Vec<P> {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0]
&self[index.0 as usize]
}
}
@ -265,7 +264,7 @@ impl<P: Point> Index<PointId> for [P] {
type Output = P;
fn index(&self, index: PointId) -> &Self::Output {
&self[index.0]
&self[index.0 as usize]
}
}
@ -329,7 +328,7 @@ trait Layer {
}
}
for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) {
for pid in self.nodes()[candidate.pid.0 as usize].nearest_iter().take(num) {
search.push(pid, point, points);
}
}
@ -356,22 +355,22 @@ trait Layer {
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
// `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate;
new_nearest[i] = Some(pid); // Update the new node's `nearest`
new_nearest[i] = pid; // Update the new node's `nearest`
let old = &points[pid];
let nearest = self.nodes()[pid.0].nearest();
let nearest = self.nodes()[pid.0 as usize].nearest();
// Find the correct index to insert at to keep the neighbor's neighbors sorted
let idx = nearest
.binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor.
let third = match third {
Some(nid) => *nid,
pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer"
None => return Ordering::Greater,
_ => return Ordering::Greater,
};
let third_distance = OrderedFloat::from(old.distance(&points[third.0]));
let third_distance = OrderedFloat::from(old.distance(&points[third.0 as usize]));
distance.cmp(&third_distance)
})
.unwrap_or_else(|e| e);
@ -382,15 +381,15 @@ trait Layer {
continue;
}
let nearest = self.nodes_mut()[pid.0].nearest_mut();
if nearest[idx].is_none() {
nearest[idx] = Some(new);
let nearest = self.nodes_mut()[pid.0 as usize].nearest_mut();
if !nearest[idx].is_valid() {
nearest[idx] = new;
continue;
}
let end = Self::LINKS - 1;
nearest.copy_within(idx..end, idx + 1);
nearest[idx] = Some(new);
nearest[idx] = new;
}
self.push(node);
@ -409,15 +408,15 @@ struct UpperNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
nearest: [Option<PointId>; M],
nearest: [PointId; M],
}
impl Node for UpperNode {
fn nearest(&self) -> &[Option<PointId>] {
fn nearest(&self) -> &[PointId] {
&self.nearest
}
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest
}
@ -434,15 +433,15 @@ struct ZeroNode {
/// The nearest neighbors on this layer
///
/// This is always kept in sorted order (near to far).
nearest: [Option<PointId>; M * 2],
nearest: [PointId; M * 2],
}
impl Node for ZeroNode {
fn nearest(&self) -> &[Option<PointId>] {
fn nearest(&self) -> &[PointId] {
&self.nearest
}
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest
}
@ -454,13 +453,13 @@ impl Node for ZeroNode {
}
trait Node: Default {
fn nearest(&self) -> &[Option<PointId>];
fn nearest_mut(&mut self) -> &mut [Option<PointId>];
fn nearest(&self) -> &[PointId];
fn nearest_mut(&mut self) -> &mut [PointId];
fn nearest_iter(&self) -> NearestIter<'_>;
}
struct NearestIter<'a> {
nearest: &'a [Option<PointId>],
nearest: &'a [PointId],
}
impl<'a> Iterator for NearestIter<'a> {
@ -468,11 +467,11 @@ impl<'a> Iterator for NearestIter<'a> {
fn next(&mut self) -> Option<Self::Item> {
let (&first, rest) = self.nearest.split_first()?;
self.nearest = rest;
if first.is_none() {
self.nearest = &[];
if !first.is_valid() {
return None;
}
first
self.nearest = rest;
Some(first)
}
}
@ -530,11 +529,22 @@ struct Candidate {
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PointId(usize);
pub struct PointId(u32);
impl PointId {
pub fn invalid() -> Self {
PointId(usize::MAX)
fn invalid() -> Self {
PointId(u32::MAX)
}
/// Whether this value represents a valid point
pub fn is_valid(self) -> bool {
self.0 != u32::MAX
}
}
/// The default `PointId` is the invalid sentinel produced by `PointId::invalid()`;
/// `is_valid()` returns `false` for it.
// NOTE(review): presumably this is how callers are now expected to pre-fill the
// `out` slice passed to `search`, since `invalid()` is no longer `pub` — confirm.
impl Default for PointId {
fn default() -> Self {
PointId::invalid()
}
}