py: upgrade pyo3 to 0.21

This commit is contained in:
Dirkjan Ochtman 2024-04-02 09:29:43 +02:00
parent 778ecf98fe
commit 60209e689f
2 changed files with 34 additions and 20 deletions

View File

@ -17,6 +17,6 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
bincode = "1.3.1" bincode = "1.3.1"
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] } instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.20", features = ["extension-module"] } pyo3 = { version = "0.21", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-big-array = "0.5.0" serde-big-array = "0.5.0"

View File

@ -10,15 +10,15 @@ use std::iter::FromIterator;
use instant_distance::Point; use instant_distance::Point;
use pyo3::conversion::IntoPy; use pyo3::conversion::IntoPy;
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::types::{PyList, PyModule, PyString}; use pyo3::types::{PyAnyMethods, PyList, PyListMethods, PyModule, PyString};
use pyo3::{pyclass, pymethods, pymodule}; use pyo3::{pyclass, pymethods, pymodule, Bound};
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python}; use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_big_array::BigArray; use serde_big_array::BigArray;
#[pymodule] #[pymodule]
#[pyo3(name = "instant_distance")] #[pyo3(name = "instant_distance")]
fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> { fn instant_distance_py(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Neighbor>()?; m.add_class::<Neighbor>()?;
m.add_class::<Heuristic>()?; m.add_class::<Heuristic>()?;
m.add_class::<Config>()?; m.add_class::<Config>()?;
@ -37,10 +37,14 @@ struct HnswMap {
impl HnswMap { impl HnswMap {
/// Build the index /// Build the index
#[staticmethod] #[staticmethod]
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> { fn build(
points: &Bound<'_, PyList>,
values: &Bound<'_, PyList>,
config: &Config,
) -> PyResult<Self> {
let points = points let points = points
.into_iter() .iter()
.map(FloatArray::try_from) .map(|array| FloatArray::try_from(&array))
.collect::<Result<Vec<_>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let values = values let values = values
@ -78,7 +82,12 @@ impl HnswMap {
/// to the `ef_search` parameter set in the index's `config`. /// to the `ef_search` parameter set in the index's `config`.
/// ///
/// For best performance, reusing `Search` objects is recommended. /// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { fn search(
slf: Py<Self>,
point: &Bound<'_, PyAny>,
search: &mut Search,
py: Python<'_>,
) -> PyResult<()> {
let point = FloatArray::try_from(point)?; let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0)); search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
@ -99,10 +108,10 @@ struct Hnsw {
impl Hnsw { impl Hnsw {
/// Build the index /// Build the index
#[staticmethod] #[staticmethod]
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> { fn build(input: &Bound<'_, PyList>, config: &Config) -> PyResult<(Self, Vec<u32>)> {
let points = input let points = input
.into_iter() .iter()
.map(FloatArray::try_from) .map(|array| FloatArray::try_from(&array))
.collect::<Result<Vec<_>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
@ -135,7 +144,12 @@ impl Hnsw {
/// to the `ef_search` parameter set in the index's `config`. /// to the `ef_search` parameter set in the index's `config`.
/// ///
/// For best performance, reusing `Search` objects is recommended. /// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { fn search(
slf: Py<Self>,
point: &Bound<'_, PyAny>,
search: &mut Search,
py: Python<'_>,
) -> PyResult<()> {
let point = FloatArray::try_from(point)?; let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0)); search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
@ -175,7 +189,7 @@ impl Search {
let py = slf.py(); let py = slf.py();
let neighbor = match &index { let neighbor = match &index {
HnswType::Hnsw(hnsw) => { HnswType::Hnsw(hnsw) => {
let hnsw = hnsw.as_ref(py).borrow(); let hnsw = hnsw.bind(py).borrow();
let item = hnsw.inner.get(idx, &slf.inner); let item = hnsw.inner.get(idx, &slf.inner);
item.map(|item| Neighbor { item.map(|item| Neighbor {
distance: item.distance, distance: item.distance,
@ -184,7 +198,7 @@ impl Search {
}) })
} }
HnswType::Map(map) => { HnswType::Map(map) => {
let map = map.as_ref(py).borrow(); let map = map.bind(py).borrow();
let item = map.inner.get(idx, &slf.inner); let item = map.inner.get(idx, &slf.inner);
item.map(|item| Neighbor { item.map(|item| Neighbor {
distance: item.distance, distance: item.distance,
@ -337,7 +351,7 @@ impl Neighbor {
"instant_distance.Neighbor(distance={}, pid={}, value={})", "instant_distance.Neighbor(distance={}, pid={}, value={})",
self.distance, self.distance,
self.pid, self.pid,
Python::with_gil(|py| self.value.as_ref(py).repr().map(|s| s.to_string()))?, Python::with_gil(|py| self.value.bind(py).repr().map(|s| s.to_string()))?,
)), )),
true => Ok(format!( true => Ok(format!(
"instant_distance.Item(distance={}, pid={})", "instant_distance.Item(distance={}, pid={})",
@ -351,10 +365,10 @@ impl Neighbor {
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
impl TryFrom<&PyAny> for FloatArray { impl TryFrom<&Bound<'_, PyAny>> for FloatArray {
type Error = PyErr; type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> { fn try_from(value: &Bound<'_, PyAny>) -> Result<Self, Self::Error> {
let mut new = FloatArray([0.0; DIMENSIONS]); let mut new = FloatArray([0.0; DIMENSIONS]);
for (i, val) in value.iter()?.enumerate() { for (i, val) in value.iter()?.enumerate() {
match i >= DIMENSIONS { match i >= DIMENSIONS {
@ -416,10 +430,10 @@ enum MapValue {
String(String), String(String),
} }
impl TryFrom<&PyAny> for MapValue { impl TryFrom<Bound<'_, PyAny>> for MapValue {
type Error = PyErr; type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> { fn try_from(value: Bound<'_, PyAny>) -> Result<Self, Self::Error> {
Ok(MapValue::String(value.extract::<String>()?)) Ok(MapValue::String(value.extract::<String>()?))
} }
} }
@ -427,7 +441,7 @@ impl TryFrom<&PyAny> for MapValue {
impl IntoPy<Py<PyAny>> for &'_ MapValue { impl IntoPy<Py<PyAny>> for &'_ MapValue {
fn into_py(self, py: Python<'_>) -> Py<PyAny> { fn into_py(self, py: Python<'_>) -> Py<PyAny> {
match self { match self {
MapValue::String(s) => PyString::new(py, s).into(), MapValue::String(s) => PyString::new_bound(py, s).into(),
} }
} }
} }