Test both heuristic and simple selection

This commit is contained in:
Dirkjan Ochtman 2021-01-19 11:00:29 +01:00
parent 6db0f151ec
commit cdc1242ba9
2 changed files with 18 additions and 15 deletions

View File

@ -840,4 +840,4 @@ impl IndexMut<PointId> for Vec<ZeroNode> {
/// The parameter `M` from the paper /// The parameter `M` from the paper
/// ///
/// This should become a generic argument to `Hnsw` when possible. /// This should become a generic argument to `Hnsw` when possible.
const M: usize = 12; const M: usize = 6;

View File

@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng}; use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use instant_distance::{Hnsw, Point as _, PointId, Search}; use instant_distance::{Builder, Heuristic, Hnsw, Point as _, PointId, Search};
#[test] #[test]
fn basic() { fn basic() {
@ -24,10 +24,21 @@ fn basic() {
} }
#[test] #[test]
fn randomized() { fn random_heuristic() {
let seed = ThreadRng::default().gen::<u64>(); let (seed, recall) = randomized(Builder::default().select_heuristic(Heuristic::default()));
println!("seed {}", seed); println!("heuristic (seed = {}) recall = {}", seed, recall);
assert!(recall > 98, "expected at least 98, got {}", recall);
}
/// Runs the randomized recall check against an index built from `Builder::default()`
/// (labelled "simple" in the output).
///
/// The seed and recall returned by `randomized` are printed together, and the test
/// fails unless the recall is strictly greater than 90.
#[test]
fn random_simple() {
    let (seed, recall) = randomized(Builder::default());
    println!("simple (seed = {}) recall = {}", seed, recall);
    // The comparison is strict, so the message says "more than" rather than "at least":
    // a recall of exactly 90 fails and must not be reported as "expected at least 90, got 90".
    assert!(recall > 90, "expected more than 90, got {}", recall);
}
fn randomized(builder: Builder) -> (u64, usize) {
let seed = ThreadRng::default().gen::<u64>();
let mut rng = StdRng::seed_from_u64(seed); let mut rng = StdRng::seed_from_u64(seed);
let points = (0..1024) let points = (0..1024)
.into_iter() .into_iter()
@ -35,8 +46,6 @@ fn randomized() {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let query = Point(rng.gen(), rng.gen()); let query = Point(rng.gen(), rng.gen());
println!("query: {:?}", query);
let mut nearest = Vec::with_capacity(256); let mut nearest = Vec::with_capacity(256);
for (i, p) in points.iter().enumerate() { for (i, p) in points.iter().enumerate() {
nearest.push((OrderedFloat::from(query.distance(p)), i)); nearest.push((OrderedFloat::from(query.distance(p)), i));
@ -46,10 +55,7 @@ fn randomized() {
} }
} }
let (hnsw, pids) = Hnsw::<Point>::builder() let (hnsw, pids) = builder.seed(seed).build(&points);
.seed(seed)
.select_heuristic(Default::default())
.build(&points);
let mut search = Search::default(); let mut search = Search::default();
let mut results = vec![PointId::default(); 100]; let mut results = vec![PointId::default(); 100];
let found = hnsw.search(&query, &mut results, &mut search); let found = hnsw.search(&query, &mut results, &mut search);
@ -62,10 +68,7 @@ fn randomized() {
.map(|(_, i)| pids[*i]) .map(|(_, i)| pids[*i])
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let found = results.into_iter().take(found).collect::<HashSet<_>>(); let found = results.into_iter().take(found).collect::<HashSet<_>>();
(seed, forced.intersection(&found).count())
let recall = forced.intersection(&found).count();
println!("{} matched", recall);
assert!(recall > 90);
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]