Add a high-level HnswMap type (fixes #7)
This commit is contained in:
parent
afd7f928f9
commit
b1bd3525a1
|
@ -41,7 +41,7 @@ impl Hnsw {
|
||||||
.map(FloatArray::try_from)
|
.map(FloatArray::try_from)
|
||||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
let (inner, ids) = instant_distance::Builder::from(config).build(points);
|
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
||||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||||
Ok((Self { inner }, ids))
|
Ok((Self { inner }, ids))
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ fn build_heuristic(bench: &mut Bencher) {
|
||||||
.map(|_| Point(rng.gen(), rng.gen()))
|
.map(|_| Point(rng.gen(), rng.gen()))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
bench.iter(|| Builder::default().seed(seed).build(points.clone()))
|
bench.iter(|| Builder::default().seed(seed).build_hnsw(points.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -74,8 +74,13 @@ impl Builder {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build an `HnswMap` with the given sets of points and values
|
||||||
|
pub fn build<P: Point, V: Clone>(self, points: Vec<P>, values: Vec<V>) -> HnswMap<P, V> {
|
||||||
|
HnswMap::new(points, values, self)
|
||||||
|
}
|
||||||
|
|
||||||
/// Build the `Hnsw` with the given set of points
|
/// Build the `Hnsw` with the given set of points
///
/// Consumes the builder and returns the index together with a `Vec<PointId>`.
/// NOTE(review): presumably entry `i` of that vector is the id assigned to `points[i]` --
/// confirm against `Hnsw::new`.
|
||||||
pub fn build<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) {
|
pub fn build_hnsw<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) {
|
||||||
Hnsw::new(points, self)
|
Hnsw::new(points, self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,6 +127,42 @@ impl Default for Heuristic {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
|
pub struct HnswMap<P, V> {
|
||||||
|
hnsw: Hnsw<P>,
|
||||||
|
values: Vec<V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P, V> HnswMap<P, V>
|
||||||
|
where
|
||||||
|
P: Point,
|
||||||
|
V: Clone,
|
||||||
|
{
|
||||||
|
fn new(points: Vec<P>, values: Vec<V>, builder: Builder) -> Self {
|
||||||
|
let (hnsw, ids) = Hnsw::new(points, builder);
|
||||||
|
|
||||||
|
let mut sorted = ids.into_iter().enumerate().collect::<Vec<_>>();
|
||||||
|
sorted.sort_unstable_by(|a, b| a.1.cmp(&b.1));
|
||||||
|
let new = sorted
|
||||||
|
.into_iter()
|
||||||
|
.map(|(src, _)| values[src].clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self { hnsw, values: new }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn search<'a>(
|
||||||
|
&'a self,
|
||||||
|
point: &P,
|
||||||
|
search: &'a mut Search,
|
||||||
|
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a {
|
||||||
|
self.hnsw.search(point, search).map(move |candidate| {
|
||||||
|
let value = &self.values[candidate.pid.0 as usize];
|
||||||
|
(candidate.distance.into(), value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
pub struct Hnsw<P> {
|
pub struct Hnsw<P> {
|
||||||
ef_search: usize,
|
ef_search: usize,
|
||||||
|
|
|
@ -6,6 +6,25 @@ use rand::{Rng, SeedableRng};
|
||||||
|
|
||||||
use instant_distance::{Builder, Point as _, Search};
|
use instant_distance::{Builder, Point as _, Search};
|
||||||
|
|
||||||
|
/// Checks that `HnswMap` keeps every value attached to the point it was inserted with.
///
/// Point `i` sits at `(i, i)` and carries the English name of `i`. For every search hit
/// the reported distance must therefore match the distance between the query and the point
/// that the returned name belongs to, regardless of how the index reordered the points.
#[test]
fn map() {
    const NAMES: [&str; 16] = [
        "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
        "eleven", "twelve", "thirteen", "fourteen", "fifteen",
    ];

    let points = (0..NAMES.len())
        .map(|i| Point(i as f32, i as f32))
        .collect::<Vec<_>>();
    let values = NAMES.to_vec();

    // The seed is printed so that a failing run can be reproduced.
    let seed = ThreadRng::default().gen::<u64>();
    println!("map (seed = {})", seed);
    let map = Builder::default().seed(seed).build(points, values);
    let mut search = Search::default();

    let query = Point(2.0, 2.0);
    let mut hits = 0;
    for (distance, name) in map.search(&query, &mut search) {
        let idx = NAMES
            .iter()
            .position(|candidate| candidate == name)
            .expect("search returned a value that was never inserted");
        let expected = Point(idx as f32, idx as f32).distance(&query);
        assert!(
            (distance - expected).abs() <= 1e-4,
            "value {:?} reported at distance {}, but its point is at distance {}",
            name,
            distance,
            expected
        );
        hits += 1;
    }

    assert!(hits > 0, "search returned no results");
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn random_heuristic() {
|
fn random_heuristic() {
|
||||||
let (seed, recall) = randomized(Builder::default());
|
let (seed, recall) = randomized(Builder::default());
|
||||||
|
@ -38,7 +57,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (hnsw, pids) = builder.seed(seed).build(points);
|
let (hnsw, pids) = builder.seed(seed).build_hnsw(points);
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
let results = hnsw.search(&query, &mut search);
|
let results = hnsw.search(&query, &mut search);
|
||||||
assert!(results.len() >= 100);
|
assert!(results.len() >= 100);
|
||||||
|
|
Loading…
Reference in New Issue