mirror of
https://github.com/instant-labs/instant-distance.git
synced 2025-02-17 21:32:02 +00:00
Introduce Metric trait
This commit is contained in:
parent
11b5626775
commit
d5f7b47b15
@ -1,7 +1,7 @@
|
||||
use bencher::{benchmark_group, benchmark_main, Bencher};
|
||||
|
||||
use instant_distance::{Builder, Point, Search};
|
||||
use instant_distance_py::FloatArray;
|
||||
use instant_distance::{Builder, Metric, Search};
|
||||
use instant_distance_py::{EuclidMetric, FloatArray};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
benchmark_main!(benches);
|
||||
@ -12,7 +12,7 @@ fn distance(bench: &mut Bencher) {
|
||||
let point_a = FloatArray([rng.gen(); 300]);
|
||||
let point_b = FloatArray([rng.gen(); 300]);
|
||||
|
||||
bench.iter(|| point_a.distance(&point_b));
|
||||
bench.iter(|| EuclidMetric::distance(&point_a, &point_b));
|
||||
}
|
||||
|
||||
fn build(bench: &mut Bencher) {
|
||||
@ -24,7 +24,7 @@ fn build(bench: &mut Bencher) {
|
||||
bench.iter(|| {
|
||||
Builder::default()
|
||||
.seed(SEED)
|
||||
.build_hnsw::<_, _, Vec<FloatArray>>(points.clone())
|
||||
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points.clone())
|
||||
});
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ fn query(bench: &mut Bencher) {
|
||||
.collect::<Vec<_>>();
|
||||
let (hnsw, _) = Builder::default()
|
||||
.seed(SEED)
|
||||
.build_hnsw::<_, _, Vec<FloatArray>>(points);
|
||||
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points);
|
||||
let point = FloatArray([rng.gen(); 300]);
|
||||
|
||||
bench.iter(|| {
|
||||
|
@ -6,7 +6,7 @@ use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::iter::FromIterator;
|
||||
|
||||
use instant_distance::Point;
|
||||
use instant_distance::Metric;
|
||||
use pyo3::conversion::IntoPy;
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::types::{PyList, PyModule, PyString};
|
||||
@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
|
||||
|
||||
#[pyclass]
|
||||
struct HnswMap {
|
||||
inner: instant_distance::HnswMap<FloatArray, MapValue, Vec<FloatArray>>,
|
||||
inner: instant_distance::HnswMap<FloatArray, EuclidMetric, MapValue, Vec<FloatArray>>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@ -54,7 +54,7 @@ impl HnswMap {
|
||||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _>>(
|
||||
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _, _>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
@ -90,7 +90,7 @@ impl HnswMap {
|
||||
/// with a squared Euclidean distance metric.
|
||||
#[pyclass]
|
||||
struct Hnsw {
|
||||
inner: instant_distance::Hnsw<FloatArray, Vec<FloatArray>>,
|
||||
inner: instant_distance::Hnsw<FloatArray, EuclidMetric, Vec<FloatArray>>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@ -111,7 +111,7 @@ impl Hnsw {
|
||||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _>>(
|
||||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _, _>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
@ -144,7 +144,7 @@ impl Hnsw {
|
||||
/// Search buffer and result set
|
||||
#[pyclass]
|
||||
struct Search {
|
||||
inner: instant_distance::Search<FloatArray>,
|
||||
inner: instant_distance::Search<FloatArray, EuclidMetric>,
|
||||
cur: Option<(HnswType, usize)>,
|
||||
}
|
||||
|
||||
@ -364,8 +364,11 @@ impl TryFrom<&PyAny> for FloatArray {
|
||||
}
|
||||
}
|
||||
|
||||
impl Point for FloatArray {
|
||||
fn distance(&self, rhs: &Self) -> f32 {
|
||||
#[derive(Clone, Copy, Deserialize, Serialize)]
|
||||
pub struct EuclidMetric;
|
||||
|
||||
impl Metric<FloatArray> for EuclidMetric {
|
||||
fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
use std::arch::x86_64::{
|
||||
@ -373,11 +376,11 @@ impl Point for FloatArray {
|
||||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
|
||||
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
|
||||
};
|
||||
debug_assert_eq!(self.0.len() % 8, 4);
|
||||
debug_assert_eq!(lhs.0.len() % 8, 4);
|
||||
|
||||
unsafe {
|
||||
let mut acc_8x = _mm256_setzero_ps();
|
||||
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
|
||||
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
|
||||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
|
||||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
|
||||
let diff = _mm256_sub_ps(lh_8x, rh_8x);
|
||||
@ -388,7 +391,7 @@ impl Point for FloatArray {
|
||||
let right = _mm256_castps256_ps128(acc_8x); // lower half
|
||||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
|
||||
|
||||
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
|
||||
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let diff = _mm_sub_ps(lh_4x, rh_4x);
|
||||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
|
||||
@ -401,7 +404,7 @@ impl Point for FloatArray {
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
self.0
|
||||
lhs.0
|
||||
.iter()
|
||||
.zip(rhs.0.iter())
|
||||
.map(|(&a, &b)| (a - b).powi(2))
|
||||
|
@ -2,7 +2,7 @@ use bencher::{benchmark_group, benchmark_main, Bencher};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::Builder;
|
||||
use instant_distance::{Builder, Metric};
|
||||
|
||||
benchmark_main!(benches);
|
||||
benchmark_group!(benches, build_heuristic);
|
||||
@ -11,13 +11,13 @@ fn build_heuristic(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let points = (0..1024)
|
||||
.into_iter()
|
||||
.map(|_| Point(rng.gen(), rng.gen()))
|
||||
.map(|_| [rng.gen(), rng.gen()])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
bench.iter(|| {
|
||||
Builder::default()
|
||||
.seed(SEED)
|
||||
.build_hnsw::<Point, Point, Vec<Point>>(points.clone())
|
||||
.build_hnsw::<[f32; 2], [f32; 2], EuclidMetric, Vec<[f32; 2]>>(points.clone())
|
||||
})
|
||||
}
|
||||
|
||||
@ -51,12 +51,15 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||
}
|
||||
*/
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Point(f32, f32);
|
||||
struct EuclidMetric;
|
||||
|
||||
impl instant_distance::Point for Point {
|
||||
fn distance(&self, other: &Self) -> f32 {
|
||||
impl Metric<[f32; 2]> for EuclidMetric {
|
||||
fn distance(a: &[f32; 2], b: &[f32; 2]) -> f32 {
|
||||
// Euclidean distance metric
|
||||
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(&a, &b)| (a - b).powi(2))
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
use instant_distance::{Builder, Search};
|
||||
use instant_distance::{Builder, Metric, Search};
|
||||
|
||||
fn main() {
|
||||
let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)];
|
||||
let values = vec!["red", "green", "blue"];
|
||||
|
||||
let map = Builder::default().build::<Point, Point, &str, Vec<Point>>(points, values);
|
||||
let map =
|
||||
Builder::default().build::<Point, Point, EuclidMetric, &str, Vec<Point>>(points, values);
|
||||
let mut search = Search::default();
|
||||
|
||||
let burnt_orange = Point(204, 85, 0);
|
||||
@ -17,10 +18,11 @@ fn main() {
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Point(isize, isize, isize);
|
||||
|
||||
impl instant_distance::Point for Point {
|
||||
fn distance(&self, other: &Self) -> f32 {
|
||||
struct EuclidMetric;
|
||||
|
||||
impl Metric<Point> for EuclidMetric {
|
||||
fn distance(a: &Point, b: &Point) -> f32 {
|
||||
// Euclidean distance metric
|
||||
(((self.0 - other.0).pow(2) + (self.1 - other.1).pow(2) + (self.2 - other.2).pow(2)) as f32)
|
||||
.sqrt()
|
||||
(((a.0 - b.0).pow(2) + (a.1 - b.1).pow(2) + (a.2 - b.2).pow(2)) as f32).sqrt()
|
||||
}
|
||||
}
|
||||
|
@ -77,9 +77,10 @@ impl Builder {
|
||||
}
|
||||
|
||||
/// Build an `HnswMap` with the given sets of points and values
|
||||
pub fn build<T, P, V, S>(self, points: Vec<T>, values: Vec<V>) -> HnswMap<P, V, S>
|
||||
pub fn build<T, P, M, V, S>(self, points: Vec<T>, values: Vec<V>) -> HnswMap<P, M, V, S>
|
||||
where
|
||||
P: Point,
|
||||
P: ?Sized + Send + Sync,
|
||||
M: Metric<P>,
|
||||
V: Clone,
|
||||
S: Default + From<Vec<T>> + Len + Index<PointId, Output = P> + Sync,
|
||||
for<'a> &'a S: IntoIterator<Item = &'a P>,
|
||||
@ -88,9 +89,10 @@ impl Builder {
|
||||
}
|
||||
|
||||
/// Build the `Hnsw` with the given set of points
|
||||
pub fn build_hnsw<T, P, S>(self, points: Vec<T>) -> (Hnsw<P, S>, Vec<PointId>)
|
||||
pub fn build_hnsw<T, P, M: Metric<P>, S>(self, points: Vec<T>) -> (Hnsw<P, M, S>, Vec<PointId>)
|
||||
where
|
||||
P: Point,
|
||||
P: ?Sized + Send + Sync,
|
||||
M: Metric<P>,
|
||||
S: Default + From<Vec<T>> + Len + Index<PointId, Output = P> + Sync,
|
||||
for<'a> &'a S: IntoIterator<Item = &'a P>,
|
||||
{
|
||||
@ -141,18 +143,19 @@ impl Default for Heuristic {
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct HnswMap<P, V, S> {
|
||||
pub struct HnswMap<P: ?Sized, M, V, S> {
|
||||
#[cfg_attr(
|
||||
feature = "serde",
|
||||
serde(bound(deserialize = "Hnsw<P, S>: Deserialize<'de>"))
|
||||
serde(bound(deserialize = "Hnsw<P, M, S>: Deserialize<'de>"))
|
||||
)]
|
||||
hnsw: Hnsw<P, S>,
|
||||
hnsw: Hnsw<P, M, S>,
|
||||
pub values: Vec<V>,
|
||||
}
|
||||
|
||||
impl<P, V, S> HnswMap<P, V, S>
|
||||
impl<P, M, V, S> HnswMap<P, M, V, S>
|
||||
where
|
||||
P: Point,
|
||||
P: ?Sized + Send + Sync,
|
||||
M: Metric<P>,
|
||||
V: Clone,
|
||||
S: Default + Len + Index<PointId, Output = P> + Sync,
|
||||
for<'a> &'a S: IntoIterator<Item = &'a P>,
|
||||
@ -176,7 +179,7 @@ where
|
||||
pub fn search<'a>(
|
||||
&'a self,
|
||||
point: &P,
|
||||
search: &'a mut Search<P>,
|
||||
search: &'a mut Search<P, M>,
|
||||
) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a {
|
||||
self.hnsw
|
||||
.search(point, search)
|
||||
@ -189,20 +192,20 @@ where
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn get(&self, i: usize, search: &Search<P>) -> Option<MapItem<'_, P, V>> {
|
||||
pub fn get(&self, i: usize, search: &Search<P, M>) -> Option<MapItem<'_, P, V>> {
|
||||
Some(MapItem::from(self.hnsw.get(i, search)?, self))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MapItem<'a, P, V> {
|
||||
pub struct MapItem<'a, P: ?Sized, V> {
|
||||
pub distance: f32,
|
||||
pub pid: PointId,
|
||||
pub point: &'a P,
|
||||
pub value: &'a V,
|
||||
}
|
||||
|
||||
impl<'a, P, V> MapItem<'a, P, V> {
|
||||
fn from<S>(item: Item<'a, P>, map: &'a HnswMap<P, V, S>) -> Self {
|
||||
impl<'a, P: ?Sized, V> MapItem<'a, P, V> {
|
||||
fn from<M, S>(item: Item<'a, P>, map: &'a HnswMap<P, M, V, S>) -> Self {
|
||||
MapItem {
|
||||
distance: item.distance,
|
||||
pid: item.pid,
|
||||
@ -213,17 +216,18 @@ impl<'a, P, V> MapItem<'a, P, V> {
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct Hnsw<P, S> {
|
||||
pub struct Hnsw<P: ?Sized, M, S> {
|
||||
ef_search: usize,
|
||||
points: S,
|
||||
zero: Vec<ZeroNode>,
|
||||
layers: Vec<Vec<UpperNode>>,
|
||||
phantom: PhantomData<P>,
|
||||
phantom: PhantomData<(Box<P>, M)>,
|
||||
}
|
||||
|
||||
impl<P, S> Hnsw<P, S>
|
||||
impl<P, M, S> Hnsw<P, M, S>
|
||||
where
|
||||
P: Point,
|
||||
P: ?Sized + Send + Sync,
|
||||
M: Metric<P>,
|
||||
S: Default + Len + Index<PointId, Output = P> + Sync,
|
||||
for<'a> &'a S: IntoIterator<Item = &'a P>,
|
||||
{
|
||||
@ -331,6 +335,7 @@ where
|
||||
points: &points,
|
||||
heuristic,
|
||||
ef_construction,
|
||||
phantom: PhantomData::<M>,
|
||||
#[cfg(feature = "indicatif")]
|
||||
progress,
|
||||
#[cfg(feature = "indicatif")]
|
||||
@ -389,7 +394,7 @@ where
|
||||
pub fn search<'a, 'b: 'a>(
|
||||
&'b self,
|
||||
point: &P,
|
||||
search: &'a mut Search<P>,
|
||||
search: &'a mut Search<P, M>,
|
||||
) -> impl Iterator<Item = Item<'b, P>> + ExactSizeIterator + 'a {
|
||||
search.reset();
|
||||
let map = move |candidate| Item::new(candidate, self);
|
||||
@ -428,19 +433,19 @@ where
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn get(&self, i: usize, search: &Search<P>) -> Option<Item<'_, P>> {
|
||||
pub fn get(&self, i: usize, search: &Search<P, M>) -> Option<Item<'_, P>> {
|
||||
Some(Item::new(search.nearest.get(i).copied()?, self))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Item<'a, P> {
|
||||
pub struct Item<'a, P: ?Sized> {
|
||||
pub distance: f32,
|
||||
pub pid: PointId,
|
||||
pub point: &'a P,
|
||||
}
|
||||
|
||||
impl<'a, P> Item<'a, P> {
|
||||
fn new<S>(candidate: Candidate, hnsw: &'a Hnsw<P, S>) -> Self
|
||||
impl<'a, P: ?Sized> Item<'a, P> {
|
||||
fn new<M, S>(candidate: Candidate, hnsw: &'a Hnsw<P, M, S>) -> Self
|
||||
where
|
||||
S: Index<PointId, Output = P>,
|
||||
{
|
||||
@ -452,23 +457,24 @@ impl<'a, P> Item<'a, P> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Construction<'a, P: Point, S> {
|
||||
struct Construction<'a, P: ?Sized, M: Metric<P>, S> {
|
||||
zero: &'a [RwLock<ZeroNode>],
|
||||
pool: SearchPool<P>,
|
||||
pool: SearchPool<P, M>,
|
||||
top: LayerId,
|
||||
points: &'a S,
|
||||
heuristic: Option<Heuristic>,
|
||||
ef_construction: usize,
|
||||
phantom: PhantomData<M>,
|
||||
#[cfg(feature = "indicatif")]
|
||||
progress: Option<ProgressBar>,
|
||||
#[cfg(feature = "indicatif")]
|
||||
done: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<'a, P, S> Construction<'a, P, S>
|
||||
impl<'a, P: ?Sized, M, S> Construction<'a, P, M, S>
|
||||
where
|
||||
P: Point,
|
||||
S: Index<PointId, Output = P>,
|
||||
M: Metric<P>,
|
||||
{
|
||||
/// Insert new node in the zero layer
|
||||
///
|
||||
@ -551,7 +557,7 @@ where
|
||||
_ => return Ordering::Greater,
|
||||
};
|
||||
|
||||
distance.cmp(&old.distance(&self.points[third]).into())
|
||||
distance.cmp(&M::distance(old, &self.points[third]).into())
|
||||
})
|
||||
.unwrap_or_else(|e| e);
|
||||
|
||||
@ -572,14 +578,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
type SearchPoolItem<P> = (Search<P>, Search<P>);
|
||||
type SearchPoolItem<P, M> = (Search<P, M>, Search<P, M>);
|
||||
|
||||
struct SearchPool<P: Point> {
|
||||
pool: Mutex<Vec<SearchPoolItem<P>>>,
|
||||
struct SearchPool<P: ?Sized, M: Metric<P>> {
|
||||
pool: Mutex<Vec<SearchPoolItem<P, M>>>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl<P: Point> SearchPool<P> {
|
||||
impl<P: ?Sized, M: Metric<P>> SearchPool<P, M> {
|
||||
fn new(len: usize) -> Self {
|
||||
Self {
|
||||
pool: Mutex::new(Vec::new()),
|
||||
@ -587,14 +593,14 @@ impl<P: Point> SearchPool<P> {
|
||||
}
|
||||
}
|
||||
|
||||
fn pop(&self) -> SearchPoolItem<P> {
|
||||
fn pop(&self) -> SearchPoolItem<P, M> {
|
||||
match self.pool.lock().pop() {
|
||||
Some(res) => res,
|
||||
None => (Search::new(self.len), Search::new(self.len)),
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&self, item: SearchPoolItem<P>) {
|
||||
fn push(&self, item: SearchPoolItem<P, M>) {
|
||||
self.pool.lock().push(item);
|
||||
}
|
||||
}
|
||||
@ -603,7 +609,7 @@ impl<P: Point> SearchPool<P> {
|
||||
///
|
||||
/// In particular, this contains most of the state used in algorithm 2. The structure is
|
||||
/// initialized by using `push()` to add the initial enter points.
|
||||
pub struct Search<P: Point> {
|
||||
pub struct Search<P: ?Sized, M: Metric<P>> {
|
||||
/// Nodes visited so far (`v` in the paper)
|
||||
visited: Visited,
|
||||
/// Candidates for further inspection (`C` in the paper)
|
||||
@ -618,10 +624,10 @@ pub struct Search<P: Point> {
|
||||
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
|
||||
ef: usize,
|
||||
/// PhantomData to bind the Metric parameter
|
||||
phantom: PhantomData<P>,
|
||||
phantom: PhantomData<(Box<P>, M)>,
|
||||
}
|
||||
|
||||
impl<P: Point> Search<P> {
|
||||
impl<P: ?Sized, M: Metric<P>> Search<P, M> {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
visited: Visited::with_capacity(capacity),
|
||||
@ -706,7 +712,7 @@ impl<P: Point> Search<P> {
|
||||
}
|
||||
|
||||
let other = &points[hop];
|
||||
let distance = OrderedFloat::from(point.distance(other));
|
||||
let distance = OrderedFloat::from(M::distance(point, other));
|
||||
let new = Candidate { distance, pid: hop };
|
||||
self.working.push(new);
|
||||
}
|
||||
@ -728,7 +734,8 @@ impl<P: Point> Search<P> {
|
||||
// are to the query point, to facilitate bridging between clustered points.
|
||||
let candidate_point = &points[candidate.pid];
|
||||
let nearest = !self.nearest.iter().any(|result| {
|
||||
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
|
||||
let distance =
|
||||
OrderedFloat::from(M::distance(candidate_point, &points[result.pid]));
|
||||
distance < candidate.distance
|
||||
});
|
||||
|
||||
@ -761,7 +768,7 @@ impl<P: Point> Search<P> {
|
||||
}
|
||||
|
||||
let other = &points[pid];
|
||||
let distance = OrderedFloat::from(point.distance(other));
|
||||
let distance = OrderedFloat::from(M::distance(point, other));
|
||||
let new = Candidate { distance, pid };
|
||||
let idx = match self.nearest.binary_search(&new) {
|
||||
Err(idx) if idx < self.ef => idx,
|
||||
@ -819,7 +826,7 @@ impl<P: Point> Search<P> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: Point> Default for Search<P> {
|
||||
impl<P: ?Sized, M: Metric<P>> Default for Search<P, M> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
visited: Visited::with_capacity(0),
|
||||
@ -833,6 +840,10 @@ impl<P: Point> Default for Search<P> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Metric<P: ?Sized>: Send + Sync {
|
||||
fn distance(a: &P, b: &P) -> f32;
|
||||
}
|
||||
|
||||
pub trait Len {
|
||||
fn len(&self) -> usize;
|
||||
|
||||
@ -855,10 +866,6 @@ impl<P> Index<PointId> for Vec<P> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Point: Clone + Send + Sync {
|
||||
fn distance(&self, other: &Self) -> f32;
|
||||
}
|
||||
|
||||
/// The parameter `M` from the paper
|
||||
///
|
||||
/// This should become a generic argument to `Hnsw` when possible.
|
||||
|
@ -266,8 +266,9 @@ impl Default for PointId {
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, S> Index<PointId> for Hnsw<P, S>
|
||||
impl<P, M, S> Index<PointId> for Hnsw<P, M, S>
|
||||
where
|
||||
P: ?Sized,
|
||||
S: Index<PointId, Output = P>,
|
||||
{
|
||||
type Output = P;
|
||||
|
@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
|
||||
use rand::rngs::{StdRng, ThreadRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::{Builder, Point as _, Search};
|
||||
use instant_distance::{Builder, Metric, Search};
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp, clippy::approx_constant)]
|
||||
@ -19,7 +19,7 @@ fn map() {
|
||||
println!("map (seed = {seed})");
|
||||
let map = Builder::default()
|
||||
.seed(seed)
|
||||
.build::<Point, Point, &str, Vec<Point>>(points, values);
|
||||
.build::<Point, Point, EuclidMetric, &str, Vec<Point>>(points, values);
|
||||
let mut search = Search::default();
|
||||
|
||||
for (i, item) in map.search(&Point(2.0, 2.0), &mut search).enumerate() {
|
||||
@ -66,14 +66,16 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||
let query = Point(rng.gen(), rng.gen());
|
||||
let mut nearest = Vec::with_capacity(256);
|
||||
for (i, p) in points.iter().enumerate() {
|
||||
nearest.push((OrderedFloat::from(query.distance(p)), i));
|
||||
nearest.push((OrderedFloat::from(EuclidMetric::distance(&query, p)), i));
|
||||
if nearest.len() >= 200 {
|
||||
nearest.sort_unstable();
|
||||
nearest.truncate(100);
|
||||
}
|
||||
}
|
||||
|
||||
let (hnsw, pids) = builder.seed(seed).build_hnsw::<_, _, Vec<Point>>(points);
|
||||
let (hnsw, pids) = builder
|
||||
.seed(seed)
|
||||
.build_hnsw::<_, _, EuclidMetric, Vec<Point>>(points);
|
||||
let mut search = Search::default();
|
||||
let results = hnsw.search(&query, &mut search);
|
||||
assert!(results.len() >= 100);
|
||||
@ -94,9 +96,11 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Point(f32, f32);
|
||||
|
||||
impl instant_distance::Point for Point {
|
||||
fn distance(&self, other: &Self) -> f32 {
|
||||
struct EuclidMetric;
|
||||
|
||||
impl Metric<Point> for EuclidMetric {
|
||||
fn distance(a: &Point, b: &Point) -> f32 {
|
||||
// Euclidean distance metric
|
||||
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
|
||||
((a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)).sqrt()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user