Use u32 for PointIds to save memory
This commit is contained in:
parent
db925603d0
commit
0e0ffe201e
74
src/lib.rs
74
src/lib.rs
|
@ -38,6 +38,7 @@ where
|
||||||
// construction. This allows us to copy higher layers to lower layers as construction
|
// construction. This allows us to copy higher layers to lower layers as construction
|
||||||
// progresses, while preserving randomness in each point's layer and insertion order.
|
// progresses, while preserving randomness in each point's layer and insertion order.
|
||||||
|
|
||||||
|
assert!(points.len() < u32::MAX as usize);
|
||||||
let mut rng = SmallRng::from_entropy();
|
let mut rng = SmallRng::from_entropy();
|
||||||
let mut nodes = (0..points.len())
|
let mut nodes = (0..points.len())
|
||||||
.map(|i| (LayerId::random(&mut rng), i))
|
.map(|i| (LayerId::random(&mut rng), i))
|
||||||
|
@ -51,7 +52,7 @@ where
|
||||||
let mut new_nodes = Vec::with_capacity(points.len());
|
let mut new_nodes = Vec::with_capacity(points.len());
|
||||||
let mut out = vec![PointId::invalid(); points.len()];
|
let mut out = vec![PointId::invalid(); points.len()];
|
||||||
for (i, &(layer, idx)) in nodes.iter().enumerate() {
|
for (i, &(layer, idx)) in nodes.iter().enumerate() {
|
||||||
let pid = PointId(i);
|
let pid = PointId(i as u32);
|
||||||
new_points.push(points[idx].clone());
|
new_points.push(points[idx].clone());
|
||||||
new_nodes.push((layer, pid));
|
new_nodes.push((layer, pid));
|
||||||
out[idx] = pid;
|
out[idx] = pid;
|
||||||
|
@ -136,8 +137,6 @@ where
|
||||||
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
||||||
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
||||||
/// in the return value.
|
/// in the return value.
|
||||||
///
|
|
||||||
/// `PointId` values can be initialized with `PointId::invalid()`.
|
|
||||||
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
|
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
|
||||||
if self.points.is_empty() {
|
if self.points.is_empty() {
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -249,7 +248,7 @@ impl<P> Index<PointId> for Hnsw<P> {
|
||||||
type Output = P;
|
type Output = P;
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
&self.points[index.0]
|
&self.points[index.0 as usize]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +256,7 @@ impl<P: Point> Index<PointId> for Vec<P> {
|
||||||
type Output = P;
|
type Output = P;
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
&self[index.0]
|
&self[index.0 as usize]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,7 +264,7 @@ impl<P: Point> Index<PointId> for [P] {
|
||||||
type Output = P;
|
type Output = P;
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
&self[index.0]
|
&self[index.0 as usize]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,7 +328,7 @@ trait Layer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) {
|
for pid in self.nodes()[candidate.pid.0 as usize].nearest_iter().take(num) {
|
||||||
search.push(pid, point, points);
|
search.push(pid, point, points);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -356,22 +355,22 @@ trait Layer {
|
||||||
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
|
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
|
||||||
// `candidate` here is the new node's neighbor
|
// `candidate` here is the new node's neighbor
|
||||||
let &Candidate { distance, pid } = candidate;
|
let &Candidate { distance, pid } = candidate;
|
||||||
new_nearest[i] = Some(pid); // Update the new node's `nearest`
|
new_nearest[i] = pid; // Update the new node's `nearest`
|
||||||
|
|
||||||
let old = &points[pid];
|
let old = &points[pid];
|
||||||
let nearest = self.nodes()[pid.0].nearest();
|
let nearest = self.nodes()[pid.0 as usize].nearest();
|
||||||
|
|
||||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||||
let idx = nearest
|
let idx = nearest
|
||||||
.binary_search_by(|third| {
|
.binary_search_by(|third| {
|
||||||
// `third` here is one of the neighbors of the new node's neighbor.
|
// `third` here is one of the neighbors of the new node's neighbor.
|
||||||
let third = match third {
|
let third = match third {
|
||||||
Some(nid) => *nid,
|
pid if pid.is_valid() => *pid,
|
||||||
// if `third` is `None`, our new `node` is always "closer"
|
// if `third` is an invalid (unused) slot, our new `node` is always "closer"
|
||||||
None => return Ordering::Greater,
|
_ => return Ordering::Greater,
|
||||||
};
|
};
|
||||||
|
|
||||||
let third_distance = OrderedFloat::from(old.distance(&points[third.0]));
|
let third_distance = OrderedFloat::from(old.distance(&points[third.0 as usize]));
|
||||||
distance.cmp(&third_distance)
|
distance.cmp(&third_distance)
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|e| e);
|
.unwrap_or_else(|e| e);
|
||||||
|
@ -382,15 +381,15 @@ trait Layer {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let nearest = self.nodes_mut()[pid.0].nearest_mut();
|
let nearest = self.nodes_mut()[pid.0 as usize].nearest_mut();
|
||||||
if nearest[idx].is_none() {
|
if !nearest[idx].is_valid() {
|
||||||
nearest[idx] = Some(new);
|
nearest[idx] = new;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let end = Self::LINKS - 1;
|
let end = Self::LINKS - 1;
|
||||||
nearest.copy_within(idx..end, idx + 1);
|
nearest.copy_within(idx..end, idx + 1);
|
||||||
nearest[idx] = Some(new);
|
nearest[idx] = new;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.push(node);
|
self.push(node);
|
||||||
|
@ -409,15 +408,15 @@ struct UpperNode {
|
||||||
/// The nearest neighbors on this layer
|
/// The nearest neighbors on this layer
|
||||||
///
|
///
|
||||||
/// This is always kept in sorted order (near to far).
|
/// This is always kept in sorted order (near to far).
|
||||||
nearest: [Option<PointId>; M],
|
nearest: [PointId; M],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Node for UpperNode {
|
impl Node for UpperNode {
|
||||||
fn nearest(&self) -> &[Option<PointId>] {
|
fn nearest(&self) -> &[PointId] {
|
||||||
&self.nearest
|
&self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
|
fn nearest_mut(&mut self) -> &mut [PointId] {
|
||||||
&mut self.nearest
|
&mut self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,15 +433,15 @@ struct ZeroNode {
|
||||||
/// The nearest neighbors on this layer
|
/// The nearest neighbors on this layer
|
||||||
///
|
///
|
||||||
/// This is always kept in sorted order (near to far).
|
/// This is always kept in sorted order (near to far).
|
||||||
nearest: [Option<PointId>; M * 2],
|
nearest: [PointId; M * 2],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Node for ZeroNode {
|
impl Node for ZeroNode {
|
||||||
fn nearest(&self) -> &[Option<PointId>] {
|
fn nearest(&self) -> &[PointId] {
|
||||||
&self.nearest
|
&self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<PointId>] {
|
fn nearest_mut(&mut self) -> &mut [PointId] {
|
||||||
&mut self.nearest
|
&mut self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -454,13 +453,13 @@ impl Node for ZeroNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Node: Default {
|
trait Node: Default {
|
||||||
fn nearest(&self) -> &[Option<PointId>];
|
fn nearest(&self) -> &[PointId];
|
||||||
fn nearest_mut(&mut self) -> &mut [Option<PointId>];
|
fn nearest_mut(&mut self) -> &mut [PointId];
|
||||||
fn nearest_iter(&self) -> NearestIter<'_>;
|
fn nearest_iter(&self) -> NearestIter<'_>;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NearestIter<'a> {
|
struct NearestIter<'a> {
|
||||||
nearest: &'a [Option<PointId>],
|
nearest: &'a [PointId],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Iterator for NearestIter<'a> {
|
impl<'a> Iterator for NearestIter<'a> {
|
||||||
|
@ -468,11 +467,11 @@ impl<'a> Iterator for NearestIter<'a> {
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let (&first, rest) = self.nearest.split_first()?;
|
let (&first, rest) = self.nearest.split_first()?;
|
||||||
self.nearest = rest;
|
if !first.is_valid() {
|
||||||
if first.is_none() {
|
return None;
|
||||||
self.nearest = &[];
|
|
||||||
}
|
}
|
||||||
first
|
self.nearest = rest;
|
||||||
|
Some(first)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -530,11 +529,22 @@ struct Candidate {
|
||||||
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
pub struct PointId(usize);
|
pub struct PointId(u32);
|
||||||
|
|
||||||
impl PointId {
|
impl PointId {
|
||||||
pub fn invalid() -> Self {
|
fn invalid() -> Self {
|
||||||
PointId(usize::MAX)
|
PointId(u32::MAX)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this value represents a valid point
|
||||||
|
pub fn is_valid(self) -> bool {
|
||||||
|
self.0 != u32::MAX
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PointId {
|
||||||
|
fn default() -> Self {
|
||||||
|
PointId::invalid()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue