Add a high-level HnswMap type (fixes #7)

This commit is contained in:
Dirkjan Ochtman 2021-05-17 16:52:51 +02:00
parent afd7f928f9
commit b1bd3525a1
4 changed files with 64 additions and 4 deletions

View File

@ -41,7 +41,7 @@ impl Hnsw {
.map(FloatArray::try_from) .map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let (inner, ids) = instant_distance::Builder::from(config).build(points); let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
Ok((Self { inner }, ids)) Ok((Self { inner }, ids))
} }

View File

@ -15,7 +15,7 @@ fn build_heuristic(bench: &mut Bencher) {
.map(|_| Point(rng.gen(), rng.gen())) .map(|_| Point(rng.gen(), rng.gen()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
bench.iter(|| Builder::default().seed(seed).build(points.clone())) bench.iter(|| Builder::default().seed(seed).build_hnsw(points.clone()))
} }
/* /*

View File

@ -74,8 +74,13 @@ impl Builder {
self self
} }
/// Build an `HnswMap` with the given sets of points and values
pub fn build<P: Point, V: Clone>(self, points: Vec<P>, values: Vec<V>) -> HnswMap<P, V> {
HnswMap::new(points, values, self)
}
/// Build the `Hnsw` with the given set of points /// Build the `Hnsw` with the given set of points
pub fn build<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) { pub fn build_hnsw<P: Point>(self, points: Vec<P>) -> (Hnsw<P>, Vec<PointId>) {
Hnsw::new(points, self) Hnsw::new(points, self)
} }
@ -122,6 +127,42 @@ impl Default for Heuristic {
} }
} }
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct HnswMap<P, V> {
hnsw: Hnsw<P>,
values: Vec<V>,
}
impl<P, V> HnswMap<P, V>
where
P: Point,
V: Clone,
{
fn new(points: Vec<P>, values: Vec<V>, builder: Builder) -> Self {
let (hnsw, ids) = Hnsw::new(points, builder);
let mut sorted = ids.into_iter().enumerate().collect::<Vec<_>>();
sorted.sort_unstable_by(|a, b| a.1.cmp(&b.1));
let new = sorted
.into_iter()
.map(|(src, _)| values[src].clone())
.collect();
Self { hnsw, values: new }
}
pub fn search<'a>(
&'a self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a {
self.hnsw.search(point, search).map(move |candidate| {
let value = &self.values[candidate.pid.0 as usize];
(candidate.distance.into(), value)
})
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> { pub struct Hnsw<P> {
ef_search: usize, ef_search: usize,

View File

@ -6,6 +6,25 @@ use rand::{Rng, SeedableRng};
use instant_distance::{Builder, Point as _, Search}; use instant_distance::{Builder, Point as _, Search};
#[test]
fn map() {
let points = (0..16)
.into_iter()
.map(|i| Point(i as f32, i as f32))
.collect::<Vec<_>>();
let values = vec![
"zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
"eleven", "twelve", "thirteen", "fourteen", "fifteen",
];
let seed = ThreadRng::default().gen::<u64>();
println!("map (seed = {})", seed);
let map = Builder::default().seed(seed).build(points, values);
let mut search = Search::default();
let _ = map.search(&Point(2.0, 2.0), &mut search);
}
#[test] #[test]
fn random_heuristic() { fn random_heuristic() {
let (seed, recall) = randomized(Builder::default()); let (seed, recall) = randomized(Builder::default());
@ -38,7 +57,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
} }
} }
let (hnsw, pids) = builder.seed(seed).build(points); let (hnsw, pids) = builder.seed(seed).build_hnsw(points);
let mut search = Search::default(); let mut search = Search::default();
let results = hnsw.search(&query, &mut search); let results = hnsw.search(&query, &mut search);
assert!(results.len() >= 100); assert!(results.len() >= 100);