Introduce Metric trait

This commit is contained in:
Kuba Jaroszewski 2023-03-27 22:25:42 +02:00
parent 11b5626775
commit d5f7b47b15
7 changed files with 104 additions and 84 deletions

View File

@ -1,7 +1,7 @@
use bencher::{benchmark_group, benchmark_main, Bencher};
use instant_distance::{Builder, Point, Search};
use instant_distance_py::FloatArray;
use instant_distance::{Builder, Metric, Search};
use instant_distance_py::{EuclidMetric, FloatArray};
use rand::{rngs::StdRng, Rng, SeedableRng};
benchmark_main!(benches);
@ -12,7 +12,7 @@ fn distance(bench: &mut Bencher) {
let point_a = FloatArray([rng.gen(); 300]);
let point_b = FloatArray([rng.gen(); 300]);
bench.iter(|| point_a.distance(&point_b));
bench.iter(|| EuclidMetric::distance(&point_a, &point_b));
}
fn build(bench: &mut Bencher) {
@ -24,7 +24,7 @@ fn build(bench: &mut Bencher) {
bench.iter(|| {
Builder::default()
.seed(SEED)
.build_hnsw::<_, _, Vec<FloatArray>>(points.clone())
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points.clone())
});
}
@ -35,7 +35,7 @@ fn query(bench: &mut Bencher) {
.collect::<Vec<_>>();
let (hnsw, _) = Builder::default()
.seed(SEED)
.build_hnsw::<_, _, Vec<FloatArray>>(points);
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points);
let point = FloatArray([rng.gen(); 300]);
bench.iter(|| {

View File

@ -6,7 +6,7 @@ use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::iter::FromIterator;
use instant_distance::Point;
use instant_distance::Metric;
use pyo3::conversion::IntoPy;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::types::{PyList, PyModule, PyString};
@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
#[pyclass]
struct HnswMap {
inner: instant_distance::HnswMap<FloatArray, MapValue, Vec<FloatArray>>,
inner: instant_distance::HnswMap<FloatArray, EuclidMetric, MapValue, Vec<FloatArray>>,
}
#[pymethods]
@ -54,7 +54,7 @@ impl HnswMap {
/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _>>(
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _, _>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@ -90,7 +90,7 @@ impl HnswMap {
/// with a squared Euclidean distance metric.
#[pyclass]
struct Hnsw {
inner: instant_distance::Hnsw<FloatArray, Vec<FloatArray>>,
inner: instant_distance::Hnsw<FloatArray, EuclidMetric, Vec<FloatArray>>,
}
#[pymethods]
@ -111,7 +111,7 @@ impl Hnsw {
/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _>>(
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _, _>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@ -144,7 +144,7 @@ impl Hnsw {
/// Search buffer and result set
#[pyclass]
struct Search {
inner: instant_distance::Search<FloatArray>,
inner: instant_distance::Search<FloatArray, EuclidMetric>,
cur: Option<(HnswType, usize)>,
}
@ -364,8 +364,11 @@ impl TryFrom<&PyAny> for FloatArray {
}
}
impl Point for FloatArray {
fn distance(&self, rhs: &Self) -> f32 {
/// Zero-sized marker type that selects the distance function used for `FloatArray` points.
///
/// It carries no data; the actual computation lives in its `Metric<FloatArray>`
/// implementation, which is invoked as `EuclidMetric::distance(&a, &b)`.
#[derive(Clone, Copy, Deserialize, Serialize)]
pub struct EuclidMetric;
impl Metric<FloatArray> for EuclidMetric {
fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::{
@ -373,11 +376,11 @@ impl Point for FloatArray {
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
};
debug_assert_eq!(self.0.len() % 8, 4);
debug_assert_eq!(lhs.0.len() % 8, 4);
unsafe {
let mut acc_8x = _mm256_setzero_ps();
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
let diff = _mm256_sub_ps(lh_8x, rh_8x);
@ -388,7 +391,7 @@ impl Point for FloatArray {
let right = _mm256_castps256_ps128(acc_8x); // lower half
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
let diff = _mm_sub_ps(lh_4x, rh_4x);
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
@ -401,7 +404,7 @@ impl Point for FloatArray {
}
}
#[cfg(not(target_arch = "x86_64"))]
self.0
lhs.0
.iter()
.zip(rhs.0.iter())
.map(|(&a, &b)| (a - b).powi(2))

View File

@ -2,7 +2,7 @@ use bencher::{benchmark_group, benchmark_main, Bencher};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use instant_distance::Builder;
use instant_distance::{Builder, Metric};
benchmark_main!(benches);
benchmark_group!(benches, build_heuristic);
@ -11,13 +11,13 @@ fn build_heuristic(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let points = (0..1024)
.into_iter()
.map(|_| Point(rng.gen(), rng.gen()))
.map(|_| [rng.gen(), rng.gen()])
.collect::<Vec<_>>();
bench.iter(|| {
Builder::default()
.seed(SEED)
.build_hnsw::<Point, Point, Vec<Point>>(points.clone())
.build_hnsw::<[f32; 2], [f32; 2], EuclidMetric, Vec<[f32; 2]>>(points.clone())
})
}
@ -51,12 +51,15 @@ fn randomized(builder: Builder) -> (u64, usize) {
}
*/
#[derive(Clone, Copy, Debug)]
struct Point(f32, f32);
struct EuclidMetric;
impl instant_distance::Point for Point {
fn distance(&self, other: &Self) -> f32 {
impl Metric<[f32; 2]> for EuclidMetric {
fn distance(a: &[f32; 2], b: &[f32; 2]) -> f32 {
// Euclidean distance metric
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
a.iter()
.zip(b.iter())
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f32>()
.sqrt()
}
}

View File

@ -1,10 +1,11 @@
use instant_distance::{Builder, Search};
use instant_distance::{Builder, Metric, Search};
fn main() {
let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)];
let values = vec!["red", "green", "blue"];
let map = Builder::default().build::<Point, Point, &str, Vec<Point>>(points, values);
let map =
Builder::default().build::<Point, Point, EuclidMetric, &str, Vec<Point>>(points, values);
let mut search = Search::default();
let burnt_orange = Point(204, 85, 0);
@ -17,10 +18,11 @@ fn main() {
#[derive(Clone, Copy, Debug)]
struct Point(isize, isize, isize);
impl instant_distance::Point for Point {
fn distance(&self, other: &Self) -> f32 {
struct EuclidMetric;
impl Metric<Point> for EuclidMetric {
fn distance(a: &Point, b: &Point) -> f32 {
// Euclidean distance metric
(((self.0 - other.0).pow(2) + (self.1 - other.1).pow(2) + (self.2 - other.2).pow(2)) as f32)
.sqrt()
(((a.0 - b.0).pow(2) + (a.1 - b.1).pow(2) + (a.2 - b.2).pow(2)) as f32).sqrt()
}
}

View File

@ -77,9 +77,10 @@ impl Builder {
}
/// Build an `HnswMap` with the given sets of points and values
pub fn build<T, P, V, S>(self, points: Vec<T>, values: Vec<V>) -> HnswMap<P, V, S>
pub fn build<T, P, M, V, S>(self, points: Vec<T>, values: Vec<V>) -> HnswMap<P, M, V, S>
where
P: Point,
P: ?Sized + Send + Sync,
M: Metric<P>,
V: Clone,
S: Default + From<Vec<T>> + Len + Index<PointId, Output = P> + Sync,
for<'a> &'a S: IntoIterator<Item = &'a P>,
@ -88,9 +89,10 @@ impl Builder {
}
/// Build the `Hnsw` with the given set of points
pub fn build_hnsw<T, P, S>(self, points: Vec<T>) -> (Hnsw<P, S>, Vec<PointId>)
pub fn build_hnsw<T, P, M: Metric<P>, S>(self, points: Vec<T>) -> (Hnsw<P, M, S>, Vec<PointId>)
where
P: Point,
P: ?Sized + Send + Sync,
M: Metric<P>,
S: Default + From<Vec<T>> + Len + Index<PointId, Output = P> + Sync,
for<'a> &'a S: IntoIterator<Item = &'a P>,
{
@ -141,18 +143,19 @@ impl Default for Heuristic {
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct HnswMap<P, V, S> {
pub struct HnswMap<P: ?Sized, M, V, S> {
#[cfg_attr(
feature = "serde",
serde(bound(deserialize = "Hnsw<P, S>: Deserialize<'de>"))
serde(bound(deserialize = "Hnsw<P, M, S>: Deserialize<'de>"))
)]
hnsw: Hnsw<P, S>,
hnsw: Hnsw<P, M, S>,
pub values: Vec<V>,
}
impl<P, V, S> HnswMap<P, V, S>
impl<P, M, V, S> HnswMap<P, M, V, S>
where
P: Point,
P: ?Sized + Send + Sync,
M: Metric<P>,
V: Clone,
S: Default + Len + Index<PointId, Output = P> + Sync,
for<'a> &'a S: IntoIterator<Item = &'a P>,
@ -176,7 +179,7 @@ where
pub fn search<'a>(
&'a self,
point: &P,
search: &'a mut Search<P>,
search: &'a mut Search<P, M>,
) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a {
self.hnsw
.search(point, search)
@ -189,20 +192,20 @@ where
}
#[doc(hidden)]
pub fn get(&self, i: usize, search: &Search<P>) -> Option<MapItem<'_, P, V>> {
pub fn get(&self, i: usize, search: &Search<P, M>) -> Option<MapItem<'_, P, V>> {
Some(MapItem::from(self.hnsw.get(i, search)?, self))
}
}
pub struct MapItem<'a, P, V> {
pub struct MapItem<'a, P: ?Sized, V> {
pub distance: f32,
pub pid: PointId,
pub point: &'a P,
pub value: &'a V,
}
impl<'a, P, V> MapItem<'a, P, V> {
fn from<S>(item: Item<'a, P>, map: &'a HnswMap<P, V, S>) -> Self {
impl<'a, P: ?Sized, V> MapItem<'a, P, V> {
fn from<M, S>(item: Item<'a, P>, map: &'a HnswMap<P, M, V, S>) -> Self {
MapItem {
distance: item.distance,
pid: item.pid,
@ -213,17 +216,18 @@ impl<'a, P, V> MapItem<'a, P, V> {
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P, S> {
pub struct Hnsw<P: ?Sized, M, S> {
ef_search: usize,
points: S,
zero: Vec<ZeroNode>,
layers: Vec<Vec<UpperNode>>,
phantom: PhantomData<P>,
phantom: PhantomData<(Box<P>, M)>,
}
impl<P, S> Hnsw<P, S>
impl<P, M, S> Hnsw<P, M, S>
where
P: Point,
P: ?Sized + Send + Sync,
M: Metric<P>,
S: Default + Len + Index<PointId, Output = P> + Sync,
for<'a> &'a S: IntoIterator<Item = &'a P>,
{
@ -331,6 +335,7 @@ where
points: &points,
heuristic,
ef_construction,
phantom: PhantomData::<M>,
#[cfg(feature = "indicatif")]
progress,
#[cfg(feature = "indicatif")]
@ -389,7 +394,7 @@ where
pub fn search<'a, 'b: 'a>(
&'b self,
point: &P,
search: &'a mut Search<P>,
search: &'a mut Search<P, M>,
) -> impl Iterator<Item = Item<'b, P>> + ExactSizeIterator + 'a {
search.reset();
let map = move |candidate| Item::new(candidate, self);
@ -428,19 +433,19 @@ where
}
#[doc(hidden)]
pub fn get(&self, i: usize, search: &Search<P>) -> Option<Item<'_, P>> {
pub fn get(&self, i: usize, search: &Search<P, M>) -> Option<Item<'_, P>> {
Some(Item::new(search.nearest.get(i).copied()?, self))
}
}
pub struct Item<'a, P> {
pub struct Item<'a, P: ?Sized> {
pub distance: f32,
pub pid: PointId,
pub point: &'a P,
}
impl<'a, P> Item<'a, P> {
fn new<S>(candidate: Candidate, hnsw: &'a Hnsw<P, S>) -> Self
impl<'a, P: ?Sized> Item<'a, P> {
fn new<M, S>(candidate: Candidate, hnsw: &'a Hnsw<P, M, S>) -> Self
where
S: Index<PointId, Output = P>,
{
@ -452,23 +457,24 @@ impl<'a, P> Item<'a, P> {
}
}
struct Construction<'a, P: Point, S> {
struct Construction<'a, P: ?Sized, M: Metric<P>, S> {
zero: &'a [RwLock<ZeroNode>],
pool: SearchPool<P>,
pool: SearchPool<P, M>,
top: LayerId,
points: &'a S,
heuristic: Option<Heuristic>,
ef_construction: usize,
phantom: PhantomData<M>,
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
#[cfg(feature = "indicatif")]
done: AtomicUsize,
}
impl<'a, P, S> Construction<'a, P, S>
impl<'a, P: ?Sized, M, S> Construction<'a, P, M, S>
where
P: Point,
S: Index<PointId, Output = P>,
M: Metric<P>,
{
/// Insert new node in the zero layer
///
@ -551,7 +557,7 @@ where
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&self.points[third]).into())
distance.cmp(&M::distance(old, &self.points[third]).into())
})
.unwrap_or_else(|e| e);
@ -572,14 +578,14 @@ where
}
}
type SearchPoolItem<P> = (Search<P>, Search<P>);
type SearchPoolItem<P, M> = (Search<P, M>, Search<P, M>);
struct SearchPool<P: Point> {
pool: Mutex<Vec<SearchPoolItem<P>>>,
struct SearchPool<P: ?Sized, M: Metric<P>> {
pool: Mutex<Vec<SearchPoolItem<P, M>>>,
len: usize,
}
impl<P: Point> SearchPool<P> {
impl<P: ?Sized, M: Metric<P>> SearchPool<P, M> {
fn new(len: usize) -> Self {
Self {
pool: Mutex::new(Vec::new()),
@ -587,14 +593,14 @@ impl<P: Point> SearchPool<P> {
}
}
fn pop(&self) -> SearchPoolItem<P> {
fn pop(&self) -> SearchPoolItem<P, M> {
match self.pool.lock().pop() {
Some(res) => res,
None => (Search::new(self.len), Search::new(self.len)),
}
}
fn push(&self, item: SearchPoolItem<P>) {
fn push(&self, item: SearchPoolItem<P, M>) {
self.pool.lock().push(item);
}
}
@ -603,7 +609,7 @@ impl<P: Point> SearchPool<P> {
///
/// In particular, this contains most of the state used in algorithm 2. The structure is
/// initialized by using `push()` to add the initial enter points.
pub struct Search<P: Point> {
pub struct Search<P: ?Sized, M: Metric<P>> {
/// Nodes visited so far (`v` in the paper)
visited: Visited,
/// Candidates for further inspection (`C` in the paper)
@ -618,10 +624,10 @@ pub struct Search<P: Point> {
/// Maximum number of nearest neighbors to retain (`ef` in the paper)
ef: usize,
/// PhantomData to bind the otherwise-unused point (`P`) and metric (`M`) type parameters
phantom: PhantomData<P>,
phantom: PhantomData<(Box<P>, M)>,
}
impl<P: Point> Search<P> {
impl<P: ?Sized, M: Metric<P>> Search<P, M> {
fn new(capacity: usize) -> Self {
Self {
visited: Visited::with_capacity(capacity),
@ -706,7 +712,7 @@ impl<P: Point> Search<P> {
}
let other = &points[hop];
let distance = OrderedFloat::from(point.distance(other));
let distance = OrderedFloat::from(M::distance(point, other));
let new = Candidate { distance, pid: hop };
self.working.push(new);
}
@ -728,7 +734,8 @@ impl<P: Point> Search<P> {
// are to the query point, to facilitate bridging between clustered points.
let candidate_point = &points[candidate.pid];
let nearest = !self.nearest.iter().any(|result| {
let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid]));
let distance =
OrderedFloat::from(M::distance(candidate_point, &points[result.pid]));
distance < candidate.distance
});
@ -761,7 +768,7 @@ impl<P: Point> Search<P> {
}
let other = &points[pid];
let distance = OrderedFloat::from(point.distance(other));
let distance = OrderedFloat::from(M::distance(point, other));
let new = Candidate { distance, pid };
let idx = match self.nearest.binary_search(&new) {
Err(idx) if idx < self.ef => idx,
@ -819,7 +826,7 @@ impl<P: Point> Search<P> {
}
}
impl<P: Point> Default for Search<P> {
impl<P: ?Sized, M: Metric<P>> Default for Search<P, M> {
fn default() -> Self {
Self {
visited: Visited::with_capacity(0),
@ -833,6 +840,10 @@ impl<P: Point> Default for Search<P> {
}
}
/// A distance function over points of type `P`.
///
/// The metric is a type-level choice rather than a method on the point: `distance` is an
/// associated function without a `self` receiver, so it is called as `M::distance(a, b)`.
/// `P` may be unsized, and implementors must be `Send + Sync`.
pub trait Metric<P: ?Sized>: Send + Sync {
/// Returns the distance between `a` and `b`.
fn distance(a: &P, b: &P) -> f32;
}
pub trait Len {
fn len(&self) -> usize;
@ -855,10 +866,6 @@ impl<P> Index<PointId> for Vec<P> {
}
}
pub trait Point: Clone + Send + Sync {
fn distance(&self, other: &Self) -> f32;
}
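/// The parameter `M` from the paper
///
/// This should become a generic argument to `Hnsw` when possible. Note that it is distinct
/// from the `M: Metric<P>` type parameter of `Hnsw`, which selects the distance function.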

View File

@ -266,8 +266,9 @@ impl Default for PointId {
}
}
impl<P, S> Index<PointId> for Hnsw<P, S>
impl<P, M, S> Index<PointId> for Hnsw<P, M, S>
where
P: ?Sized,
S: Index<PointId, Output = P>,
{
type Output = P;

View File

@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng};
use instant_distance::{Builder, Point as _, Search};
use instant_distance::{Builder, Metric, Search};
#[test]
#[allow(clippy::float_cmp, clippy::approx_constant)]
@ -19,7 +19,7 @@ fn map() {
println!("map (seed = {seed})");
let map = Builder::default()
.seed(seed)
.build::<Point, Point, &str, Vec<Point>>(points, values);
.build::<Point, Point, EuclidMetric, &str, Vec<Point>>(points, values);
let mut search = Search::default();
for (i, item) in map.search(&Point(2.0, 2.0), &mut search).enumerate() {
@ -66,14 +66,16 @@ fn randomized(builder: Builder) -> (u64, usize) {
let query = Point(rng.gen(), rng.gen());
let mut nearest = Vec::with_capacity(256);
for (i, p) in points.iter().enumerate() {
nearest.push((OrderedFloat::from(query.distance(p)), i));
nearest.push((OrderedFloat::from(EuclidMetric::distance(&query, p)), i));
if nearest.len() >= 200 {
nearest.sort_unstable();
nearest.truncate(100);
}
}
let (hnsw, pids) = builder.seed(seed).build_hnsw::<_, _, Vec<Point>>(points);
let (hnsw, pids) = builder
.seed(seed)
.build_hnsw::<_, _, EuclidMetric, Vec<Point>>(points);
let mut search = Search::default();
let results = hnsw.search(&query, &mut search);
assert!(results.len() >= 100);
@ -94,9 +96,11 @@ fn randomized(builder: Builder) -> (u64, usize) {
#[derive(Clone, Copy, Debug)]
struct Point(f32, f32);
impl instant_distance::Point for Point {
fn distance(&self, other: &Self) -> f32 {
struct EuclidMetric;
impl Metric<Point> for EuclidMetric {
fn distance(a: &Point, b: &Point) -> f32 {
// Euclidean distance metric
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
((a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)).sqrt()
}
}