Initial version of Python bindings

This commit is contained in:
Dirkjan Ochtman 2021-03-16 17:07:18 +01:00
parent 1e15c683ef
commit 3d2ebb9e15
4 changed files with 290 additions and 1 deletions

5
.cargo/config.toml Normal file
View File

@ -0,0 +1,5 @@
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

View File

@ -1,5 +1,5 @@
[workspace]
members = ["instant-distance"]
members = ["instant-distance", "instant-distance-py"]
[profile.bench]
debug = true

View File

@ -0,0 +1,16 @@
[package]
name = "instant-distance-py"
version = "0.1.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018"
license = "MIT OR Apache-2.0"
[lib]
name = "instant_distance"
crate-type = ["cdylib"]
[dependencies]
instant-distance = { version = "0.2", path = "../instant-distance" }
pyo3 = { version = "0.13.2", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] }
serde-big-array = "0.3.2"

View File

@ -0,0 +1,268 @@
#![allow(clippy::from_iter_instead_of_collect)]
use std::convert::TryFrom;
use std::iter::FromIterator;
use instant_distance::Point;
use pyo3::exceptions::PyTypeError;
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize};
use serde_big_array::big_array;
#[pymodule]
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PointId>()?;
m.add_class::<Heuristic>()?;
m.add_class::<Config>()?;
m.add_class::<Search>()?;
m.add_class::<Hnsw>()?;
Ok(())
}
#[pyclass]
struct Hnsw {
inner: instant_distance::Hnsw<FloatArray>,
}
#[pymethods]
impl Hnsw {
#[staticmethod]
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<PointId>)> {
let points = input
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let (inner, ids) = instant_distance::Builder::from(config).build(&points);
let ids = Vec::from_iter(ids.into_iter().map(PointId::from));
Ok((Self { inner }, ids))
}
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
let point = FloatArray::try_from(point)?;
let _ = self.inner.search(&point, &mut search.inner);
search.cur = Some(0);
Ok(())
}
}
#[pyclass]
struct Search {
inner: instant_distance::Search,
cur: Option<usize>,
}
#[pymethods]
impl Search {
#[new]
fn new() -> Self {
Self {
inner: instant_distance::Search::default(),
cur: None,
}
}
}
#[pyproto]
impl PyIterProtocol for Search {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PointId> {
let idx = match &slf.cur {
Some(idx) => *idx,
None => return None,
};
let pid = match slf.inner.get(idx) {
Some(pid) => pid,
None => {
slf.cur = None;
return None;
}
};
slf.cur = Some(idx + 1);
Some(PointId::from(pid))
}
}
#[pyclass]
#[derive(Copy, Clone, Default)]
struct Config {
#[pyo3(get, set)]
ef_search: usize,
#[pyo3(get, set)]
ef_construction: usize,
#[pyo3(get, set)]
ml: f32,
#[pyo3(get, set)]
seed: u64,
#[pyo3(get, set)]
heuristic: Option<Heuristic>,
}
#[pymethods]
impl Config {
#[new]
fn new() -> Self {
let builder = instant_distance::Builder::default();
let (ef_search, ef_construction, ml, seed) = builder.into_parts();
let heuristic = Some(Heuristic::default());
Self {
ef_search,
ef_construction,
ml,
seed,
heuristic,
}
}
}
impl From<&Config> for instant_distance::Builder {
fn from(py: &Config) -> Self {
let Config {
ef_search,
ef_construction,
ml,
seed,
heuristic,
} = *py;
Self::default()
.ef_search(ef_search)
.ef_construction(ef_construction)
.ml(ml)
.seed(seed)
.select_heuristic(heuristic.map(|h| h.into()))
}
}
#[pyclass]
#[derive(Copy, Clone)]
struct Heuristic {
#[pyo3(get, set)]
extend_candidates: bool,
#[pyo3(get, set)]
keep_pruned: bool,
}
#[pymethods]
impl Heuristic {
#[new]
fn new() -> Self {
let default = instant_distance::Heuristic::default();
let instant_distance::Heuristic {
extend_candidates,
keep_pruned,
} = default;
Self {
extend_candidates,
keep_pruned,
}
}
}
impl Default for Heuristic {
fn default() -> Self {
Self {
extend_candidates: false,
keep_pruned: true,
}
}
}
impl From<Heuristic> for instant_distance::Heuristic {
fn from(py: Heuristic) -> Self {
let Heuristic {
extend_candidates,
keep_pruned,
} = py;
Self {
extend_candidates,
keep_pruned,
}
}
}
#[pyclass]
struct PointId {
inner: instant_distance::PointId,
}
impl From<instant_distance::PointId> for PointId {
fn from(inner: instant_distance::PointId) -> Self {
Self { inner }
}
}
#[pyproto]
impl<'p> PyObjectProtocol<'p> for PointId {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("{:?}", self.inner))
}
fn __hash__(&'p self) -> PyResult<u32> {
Ok(self.inner.into_inner())
}
}
#[repr(align(32))]
#[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
impl TryFrom<&PyAny> for FloatArray {
type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
let mut new = FloatArray([0.0; DIMENSIONS]);
for (i, val) in value.iter()?.enumerate() {
match i >= DIMENSIONS {
true => return Err(PyTypeError::new_err("point array too long")),
false => new.0[i] = val?.extract::<f32>()?,
}
}
Ok(new)
}
}
big_array! { BigArray; DIMENSIONS }
impl Point for FloatArray {
fn distance(&self, rhs: &Self) -> f32 {
use std::arch::x86_64::{
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
};
debug_assert_eq!(self.0.len() % 8, 4);
unsafe {
let mut acc_8x = _mm256_setzero_ps();
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
let diff = _mm256_sub_ps(lh_8x, rh_8x);
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
}
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
let right = _mm256_castps256_ps128(acc_8x); // lower half
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
let diff = _mm_sub_ps(lh_4x, rh_4x);
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
let lower = _mm_movehl_ps(acc_4x, acc_4x);
acc_4x = _mm_add_ps(acc_4x, lower);
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
acc_4x = _mm_add_ss(acc_4x, upper);
_mm_cvtss_f32(acc_4x)
}
}
}
const DIMENSIONS: usize = 300;