Introduce PointStorage and support variable length
This commit is contained in:
parent
d5f7b47b15
commit
ebf9e453b6
|
@ -18,11 +18,11 @@ crate-type = ["cdylib", "lib"]
|
|||
name = "instant_distance"
|
||||
|
||||
[dependencies]
|
||||
aligned-vec = { version = "0.5.0", features = ["serde"] }
|
||||
bincode = "1.3.1"
|
||||
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
|
||||
pyo3 = { version = "0.18.0", features = ["extension-module"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde-big-array = "0.5.0"
|
||||
|
||||
[dev-dependencies]
|
||||
bencher = "0.1.5"
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use aligned_vec::avec;
|
||||
use bencher::{benchmark_group, benchmark_main, Bencher};
|
||||
|
||||
use instant_distance::{Builder, Metric, Search};
|
||||
use instant_distance_py::{EuclidMetric, FloatArray};
|
||||
use instant_distance_py::{EuclidMetric, PointStorage};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
benchmark_main!(benches);
|
||||
|
@ -9,8 +10,8 @@ benchmark_group!(benches, distance, build, query);
|
|||
|
||||
fn distance(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let point_a = FloatArray([rng.gen(); 300]);
|
||||
let point_b = FloatArray([rng.gen(); 300]);
|
||||
let point_a = avec![rng.gen(); 304];
|
||||
let point_b = avec![rng.gen(); 304];
|
||||
|
||||
bench.iter(|| EuclidMetric::distance(&point_a, &point_b));
|
||||
}
|
||||
|
@ -18,25 +19,25 @@ fn distance(bench: &mut Bencher) {
|
|||
fn build(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let points = (0..1024)
|
||||
.map(|_| FloatArray([rng.gen(); 300]))
|
||||
.map(|_| vec![rng.gen(); 304])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
bench.iter(|| {
|
||||
Builder::default()
|
||||
.seed(SEED)
|
||||
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points.clone())
|
||||
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points.clone())
|
||||
});
|
||||
}
|
||||
|
||||
fn query(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let points = (0..1024)
|
||||
.map(|_| FloatArray([rng.gen(); 300]))
|
||||
.map(|_| vec![rng.gen(); 304])
|
||||
.collect::<Vec<_>>();
|
||||
let (hnsw, _) = Builder::default()
|
||||
.seed(SEED)
|
||||
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points);
|
||||
let point = FloatArray([rng.gen(); 300]);
|
||||
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points);
|
||||
let point = avec![rng.gen(); 304];
|
||||
|
||||
bench.iter(|| {
|
||||
let mut search = Search::default();
|
||||
|
|
|
@ -4,16 +4,17 @@
|
|||
use std::convert::TryFrom;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::iter::FromIterator;
|
||||
use std::iter::{repeat, FromIterator};
|
||||
use std::ops::Index;
|
||||
|
||||
use instant_distance::Metric;
|
||||
use aligned_vec::{AVec, ConstAlign};
|
||||
use instant_distance::{Len, Metric, PointId};
|
||||
use pyo3::conversion::IntoPy;
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::types::{PyList, PyModule, PyString};
|
||||
use pyo3::{pyclass, pymethods, pymodule};
|
||||
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_big_array::BigArray;
|
||||
|
||||
#[pymodule]
|
||||
#[pyo3(name = "instant_distance")]
|
||||
|
@ -29,7 +30,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
|
|||
|
||||
#[pyclass]
|
||||
struct HnswMap {
|
||||
inner: instant_distance::HnswMap<FloatArray, EuclidMetric, MapValue, Vec<FloatArray>>,
|
||||
inner: instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -39,7 +40,12 @@ impl HnswMap {
|
|||
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
|
||||
let points = points
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.map(|v| {
|
||||
v.iter()?
|
||||
.into_iter()
|
||||
.map(|x| x?.extract())
|
||||
.collect::<Result<Vec<_>, PyErr>>()
|
||||
})
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let values = values
|
||||
|
@ -47,16 +53,21 @@ impl HnswMap {
|
|||
.map(MapValue::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||
let hsnw_map = instant_distance::Builder::from(config)
|
||||
.build::<Vec<_>, [f32], EuclidMetric, MapValue, PointStorage>(points, values);
|
||||
Ok(Self { inner: hsnw_map })
|
||||
}
|
||||
|
||||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _, _>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
let hnsw_map = bincode::deserialize_from::<
|
||||
_,
|
||||
instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>,
|
||||
>(BufReader::with_capacity(
|
||||
32 * 1024 * 1024,
|
||||
File::open(fname)?,
|
||||
))
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
Ok(Self { inner: hnsw_map })
|
||||
}
|
||||
|
@ -77,7 +88,7 @@ impl HnswMap {
|
|||
///
|
||||
/// For best performance, reusing `Search` objects is recommended.
|
||||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let point = try_avec_from_py(point)?;
|
||||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
|
||||
Ok(())
|
||||
|
@ -90,7 +101,7 @@ impl HnswMap {
|
|||
/// with a squared Euclidean distance metric.
|
||||
#[pyclass]
|
||||
struct Hnsw {
|
||||
inner: instant_distance::Hnsw<FloatArray, EuclidMetric, Vec<FloatArray>>,
|
||||
inner: instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -100,10 +111,16 @@ impl Hnsw {
|
|||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
||||
let points = input
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.map(|v| {
|
||||
v.iter()?
|
||||
.into_iter()
|
||||
.map(|x| x?.extract())
|
||||
.collect::<Result<Vec<_>, PyErr>>()
|
||||
})
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
||||
let (inner, ids) = instant_distance::Builder::from(config)
|
||||
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points);
|
||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||
Ok((Self { inner }, ids))
|
||||
}
|
||||
|
@ -111,9 +128,13 @@ impl Hnsw {
|
|||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _, _>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
let hnsw = bincode::deserialize_from::<
|
||||
_,
|
||||
instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>,
|
||||
>(BufReader::with_capacity(
|
||||
32 * 1024 * 1024,
|
||||
File::open(fname)?,
|
||||
))
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
Ok(Self { inner: hnsw })
|
||||
}
|
||||
|
@ -134,7 +155,7 @@ impl Hnsw {
|
|||
///
|
||||
/// For best performance, reusing `Search` objects is recommended.
|
||||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let point = try_avec_from_py(point)?;
|
||||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
|
||||
Ok(())
|
||||
|
@ -144,7 +165,7 @@ impl Hnsw {
|
|||
/// Search buffer and result set
|
||||
#[pyclass]
|
||||
struct Search {
|
||||
inner: instant_distance::Search<FloatArray, EuclidMetric>,
|
||||
inner: instant_distance::Search<[f32], EuclidMetric>,
|
||||
cur: Option<(HnswType, usize)>,
|
||||
}
|
||||
|
||||
|
@ -345,42 +366,36 @@ impl Neighbor {
|
|||
}
|
||||
}
|
||||
|
||||
#[repr(align(32))]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct FloatArray(#[serde(with = "BigArray")] pub [f32; DIMENSIONS]);
|
||||
|
||||
impl TryFrom<&PyAny> for FloatArray {
|
||||
type Error = PyErr;
|
||||
|
||||
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
|
||||
let mut new = FloatArray([0.0; DIMENSIONS]);
|
||||
for (i, val) in value.iter()?.enumerate() {
|
||||
match i >= DIMENSIONS {
|
||||
true => return Err(PyTypeError::new_err("point array too long")),
|
||||
false => new.0[i] = val?.extract::<f32>()?,
|
||||
}
|
||||
}
|
||||
Ok(new)
|
||||
fn try_avec_from_py(value: &PyAny) -> Result<AVec<f32, ConstAlign<ALIGNMENT>>, PyErr> {
|
||||
let mut new = AVec::new(ALIGNMENT);
|
||||
for val in value.iter()? {
|
||||
new.push(val?.extract::<f32>()?);
|
||||
}
|
||||
for _ in 0..PointStorage::padding(new.len()) {
|
||||
new.push(0.0);
|
||||
}
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Deserialize, Serialize)]
|
||||
pub struct EuclidMetric;
|
||||
|
||||
impl Metric<FloatArray> for EuclidMetric {
|
||||
fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
|
||||
impl Metric<[f32]> for EuclidMetric {
|
||||
fn distance(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
debug_assert_eq!(lhs.len(), rhs.len());
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
use std::arch::x86_64::{
|
||||
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
|
||||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
|
||||
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
|
||||
_mm_movehl_ps, _mm_shuffle_ps,
|
||||
};
|
||||
debug_assert_eq!(lhs.0.len() % 8, 4);
|
||||
debug_assert_eq!(lhs.len() % 8, 0);
|
||||
|
||||
unsafe {
|
||||
let mut acc_8x = _mm256_setzero_ps();
|
||||
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
|
||||
for (lh_slice, rh_slice) in lhs.chunks_exact(8).zip(rhs.chunks_exact(8)) {
|
||||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
|
||||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
|
||||
let diff = _mm256_sub_ps(lh_8x, rh_8x);
|
||||
|
@ -391,11 +406,6 @@ impl Metric<FloatArray> for EuclidMetric {
|
|||
let right = _mm256_castps256_ps128(acc_8x); // lower half
|
||||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
|
||||
|
||||
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let diff = _mm_sub_ps(lh_4x, rh_4x);
|
||||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
|
||||
|
||||
let lower = _mm_movehl_ps(acc_4x, acc_4x);
|
||||
acc_4x = _mm_add_ps(acc_4x, lower);
|
||||
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
|
||||
|
@ -412,6 +422,114 @@ impl Metric<FloatArray> for EuclidMetric {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct PointStorage {
|
||||
point_len: usize,
|
||||
points_data: AVec<f32>,
|
||||
}
|
||||
|
||||
impl PointStorage {
|
||||
const fn padding(len: usize) -> usize {
|
||||
let floats_per_alignment = ALIGNMENT / std::mem::size_of::<f32>();
|
||||
match len % floats_per_alignment {
|
||||
0 => 0,
|
||||
floats_over_alignment => floats_per_alignment - floats_over_alignment,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
|
||||
self.points_data.chunks_exact(self.point_len)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PointStorage {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
point_len: 1,
|
||||
points_data: AVec::new(ALIGNMENT),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Index<usize> for PointStorage {
|
||||
type Output = [f32];
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
let raw_idx = index * self.point_len;
|
||||
&self.points_data[raw_idx..(raw_idx + self.point_len)]
|
||||
}
|
||||
}
|
||||
|
||||
impl Index<PointId> for PointStorage {
|
||||
type Output = [f32];
|
||||
|
||||
fn index(&self, index: PointId) -> &Self::Output {
|
||||
self.index(index.into_inner() as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Vec<f32>>> for PointStorage {
|
||||
fn from(value: Vec<Vec<f32>>) -> Self {
|
||||
if let Some(point) = value.first() {
|
||||
let point_len = point.len();
|
||||
let padding = PointStorage::padding(point_len);
|
||||
let mut points_data =
|
||||
AVec::with_capacity(ALIGNMENT, value.len() * (point_len + padding));
|
||||
for point in value {
|
||||
// all points should have the same length
|
||||
debug_assert_eq!(point.len(), point_len);
|
||||
for v in point.into_iter().chain(repeat(0.0).take(padding)) {
|
||||
points_data.push(v);
|
||||
}
|
||||
}
|
||||
Self {
|
||||
point_len: point_len + padding,
|
||||
points_data,
|
||||
}
|
||||
} else {
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Len for PointStorage {
|
||||
fn len(&self) -> usize {
|
||||
self.points_data.len() / self.point_len
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> IntoIterator for &'a PointStorage {
|
||||
type Item = &'a [f32];
|
||||
|
||||
type IntoIter = PointStorageIterator<'a>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
PointStorageIterator {
|
||||
storage: self,
|
||||
next_idx: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PointStorageIterator<'a> {
|
||||
storage: &'a PointStorage,
|
||||
next_idx: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for PointStorageIterator<'a> {
|
||||
type Item = &'a [f32];
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.next_idx < self.storage.len() {
|
||||
let result = &self.storage[self.next_idx];
|
||||
self.next_idx += 1;
|
||||
Some(result)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
enum MapValue {
|
||||
String(String),
|
||||
|
@ -433,4 +551,4 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
|
|||
}
|
||||
}
|
||||
|
||||
const DIMENSIONS: usize = 300;
|
||||
const ALIGNMENT: usize = 32;
|
||||
|
|
Loading…
Reference in New Issue