Test both heuristic and simple selection
This commit is contained in:
parent
6db0f151ec
commit
cdc1242ba9
|
@ -840,4 +840,4 @@ impl IndexMut<PointId> for Vec<ZeroNode> {
|
|||
/// The parameter `M` from the paper
|
||||
///
|
||||
/// This should become a generic argument to `Hnsw` when possible.
|
||||
const M: usize = 12;
|
||||
const M: usize = 6;
|
||||
|
|
31
tests/all.rs
31
tests/all.rs
|
@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
|
|||
use rand::rngs::{StdRng, ThreadRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::{Hnsw, Point as _, PointId, Search};
|
||||
use instant_distance::{Builder, Heuristic, Hnsw, Point as _, PointId, Search};
|
||||
|
||||
#[test]
|
||||
fn basic() {
|
||||
|
@ -24,10 +24,21 @@ fn basic() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn randomized() {
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
println!("seed {}", seed);
|
||||
fn random_heuristic() {
|
||||
let (seed, recall) = randomized(Builder::default().select_heuristic(Heuristic::default()));
|
||||
println!("heuristic (seed = {}) recall = {}", seed, recall);
|
||||
assert!(recall > 98, "expected at least 98, got {}", recall);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random_simple() {
|
||||
let (seed, recall) = randomized(Builder::default());
|
||||
println!("simple (seed = {}) recall = {}", seed, recall);
|
||||
assert!(recall > 90, "expected at least 90, got {}", recall);
|
||||
}
|
||||
|
||||
fn randomized(builder: Builder) -> (u64, usize) {
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let points = (0..1024)
|
||||
.into_iter()
|
||||
|
@ -35,8 +46,6 @@ fn randomized() {
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
let query = Point(rng.gen(), rng.gen());
|
||||
println!("query: {:?}", query);
|
||||
|
||||
let mut nearest = Vec::with_capacity(256);
|
||||
for (i, p) in points.iter().enumerate() {
|
||||
nearest.push((OrderedFloat::from(query.distance(p)), i));
|
||||
|
@ -46,10 +55,7 @@ fn randomized() {
|
|||
}
|
||||
}
|
||||
|
||||
let (hnsw, pids) = Hnsw::<Point>::builder()
|
||||
.seed(seed)
|
||||
.select_heuristic(Default::default())
|
||||
.build(&points);
|
||||
let (hnsw, pids) = builder.seed(seed).build(&points);
|
||||
let mut search = Search::default();
|
||||
let mut results = vec![PointId::default(); 100];
|
||||
let found = hnsw.search(&query, &mut results, &mut search);
|
||||
|
@ -62,10 +68,7 @@ fn randomized() {
|
|||
.map(|(_, i)| pids[*i])
|
||||
.collect::<HashSet<_>>();
|
||||
let found = results.into_iter().take(found).collect::<HashSet<_>>();
|
||||
|
||||
let recall = forced.intersection(&found).count();
|
||||
println!("{} matched", recall);
|
||||
assert!(recall > 90);
|
||||
(seed, forced.intersection(&found).count())
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
|
|
Loading…
Reference in New Issue