Test recall in randomized test

This commit is contained in:
Dirkjan Ochtman 2021-01-07 21:38:53 +01:00
parent 6def318423
commit df24ddc186
1 changed files with 19 additions and 20 deletions

View File

@ -1,3 +1,5 @@
use std::collections::HashSet;
use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng};
@ -35,35 +37,32 @@ fn randomized() {
let query = Point(rng.gen(), rng.gen());
println!("query: {:?}", query);
let mut nearest = Vec::with_capacity(256);
for (i, p) in points.iter().enumerate() {
println!("{:2} {:?} ({})", i, p, query.distance(p));
nearest.push((OrderedFloat::from(query.distance(p)), i));
if nearest.len() >= 200 {
nearest.sort_unstable();
nearest.truncate(100);
}
}
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
let mut search = Search::default();
let mut results = vec![PointId::default()];
let mut results = vec![PointId::default(); 100];
let found = hnsw.search(&query, &mut results, &mut search);
assert_eq!(found, 1);
assert_eq!(found, 100);
let nearest = points
nearest.sort_unstable();
nearest.truncate(100);
let forced = nearest
.iter()
.enumerate()
.map(|(i, other)| (OrderedFloat::from(query.distance(other)), i))
.min()
.unwrap();
println!(
"nearest (brute force): {:?} @ {:9.7}",
pids[nearest.1],
nearest.0.into_inner()
);
.map(|(_, i)| pids[*i])
.collect::<HashSet<_>>();
let found = results.into_iter().take(found).collect::<HashSet<_>>();
let index = pids.iter().position(|p| p == &results[0]).unwrap();
println!(
"nearest (hnsw): {:?} @ {:9.7}",
results[0],
query.distance(&points[index])
);
assert_eq!(pids[nearest.1], results[0]);
let recall = forced.intersection(&found).count();
println!("{} matched", recall);
assert!(recall > 95);
}
#[derive(Clone, Copy, Debug)]