Add a high-level HnswMap type (fixes #7)
This commit is contained in:
parent
afd7f928f9
commit
b1bd3525a1
|
@ -41,7 +41,7 @@ impl Hnsw {
|
|||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let (inner, ids) = instant_distance::Builder::from(config).build(points);
|
||||
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||
Ok((Self { inner }, ids))
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ fn build_heuristic(bench: &mut Bencher) {
|
|||
.map(|_| Point(rng.gen(), rng.gen()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
bench.iter(|| Builder::default().seed(seed).build(points.clone()))
|
||||
bench.iter(|| Builder::default().seed(seed).build_hnsw(points.clone()))
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -74,8 +74,13 @@ impl Builder {
|
|||
self
|
||||
}
|
||||
|
||||
/// Build an `HnswMap` with the given sets of points and values
|
||||
pub fn build<P: Point, V: Clone>(self, points: Vec<P>, values: Vec<V>) -> HnswMap<P, V> {
|
||||
HnswMap::new(points, values, self)
|
||||
}
|
||||
|
||||
/// Build the `Hnsw` with the given set of points
|
||||
pub fn build<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) {
|
||||
pub fn build_hnsw<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) {
|
||||
Hnsw::new(points, self)
|
||||
}
|
||||
|
||||
|
@ -122,6 +127,42 @@ impl Default for Heuristic {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct HnswMap<P, V> {
|
||||
hnsw: Hnsw<P>,
|
||||
values: Vec<V>,
|
||||
}
|
||||
|
||||
impl<P, V> HnswMap<P, V>
|
||||
where
|
||||
P: Point,
|
||||
V: Clone,
|
||||
{
|
||||
fn new(points: Vec<P>, values: Vec<V>, builder: Builder) -> Self {
|
||||
let (hnsw, ids) = Hnsw::new(points, builder);
|
||||
|
||||
let mut sorted = ids.into_iter().enumerate().collect::<Vec<_>>();
|
||||
sorted.sort_unstable_by(|a, b| a.1.cmp(&b.1));
|
||||
let new = sorted
|
||||
.into_iter()
|
||||
.map(|(src, _)| values[src].clone())
|
||||
.collect();
|
||||
|
||||
Self { hnsw, values: new }
|
||||
}
|
||||
|
||||
pub fn search<'a>(
|
||||
&'a self,
|
||||
point: &P,
|
||||
search: &'a mut Search,
|
||||
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a {
|
||||
self.hnsw.search(point, search).map(move |candidate| {
|
||||
let value = &self.values[candidate.pid.0 as usize];
|
||||
(candidate.distance.into(), value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct Hnsw<P> {
|
||||
ef_search: usize,
|
||||
|
|
|
@ -6,6 +6,25 @@ use rand::{Rng, SeedableRng};
|
|||
|
||||
use instant_distance::{Builder, Point as _, Search};
|
||||
|
||||
#[test]
|
||||
fn map() {
|
||||
let points = (0..16)
|
||||
.into_iter()
|
||||
.map(|i| Point(i as f32, i as f32))
|
||||
.collect::<Vec<_>>();
|
||||
let values = vec![
|
||||
"zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
|
||||
"eleven", "twelve", "thirteen", "fourteen", "fifteen",
|
||||
];
|
||||
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
println!("map (seed = {})", seed);
|
||||
let map = Builder::default().seed(seed).build(points, values);
|
||||
let mut search = Search::default();
|
||||
|
||||
let _ = map.search(&Point(2.0, 2.0), &mut search);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random_heuristic() {
|
||||
let (seed, recall) = randomized(Builder::default());
|
||||
|
@ -38,7 +57,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
|||
}
|
||||
}
|
||||
|
||||
let (hnsw, pids) = builder.seed(seed).build(points);
|
||||
let (hnsw, pids) = builder.seed(seed).build_hnsw(points);
|
||||
let mut search = Search::default();
|
||||
let results = hnsw.search(&query, &mut search);
|
||||
assert!(results.len() >= 100);
|
||||
|
|
Loading…
Reference in New Issue