First cut of Python bindings update
This commit is contained in:
parent
99a473c298
commit
1d27340883
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.formatting.provider": "black"
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<Config>()?;
|
m.add_class::<Config>()?;
|
||||||
m.add_class::<Search>()?;
|
m.add_class::<Search>()?;
|
||||||
m.add_class::<Hnsw>()?;
|
m.add_class::<Hnsw>()?;
|
||||||
|
m.add_class::<HnswMap>()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,6 +80,63 @@ impl Hnsw {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
struct HnswMap {
|
||||||
|
inner: instant_distance::HnswMap<FloatArray, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl HnswMap {
|
||||||
|
#[getter]
|
||||||
|
fn values(&self) -> PyResult<Vec<String>> {
|
||||||
|
Ok(self.inner.values.clone())
|
||||||
|
}
|
||||||
|
/// Build the index
|
||||||
|
#[staticmethod]
|
||||||
|
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
||||||
|
let points = points
|
||||||
|
.into_iter()
|
||||||
|
.map(FloatArray::try_from)
|
||||||
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
|
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||||
|
Ok(Self { inner: hsnw_map })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load an index from the given file name
|
||||||
|
#[staticmethod]
|
||||||
|
fn load(fname: &str) -> PyResult<Self> {
|
||||||
|
let hnsw_map =
|
||||||
|
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
|
||||||
|
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||||
|
)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||||
|
Ok(Self { inner: hnsw_map })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dump the index to the given file name
|
||||||
|
fn dump(&self, fname: &str) -> PyResult<()> {
|
||||||
|
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
|
||||||
|
bincode::serialize_into(f, &self.inner)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search the index for points neighboring the given point
|
||||||
|
///
|
||||||
|
/// The `search` object contains buffers used for searching. When the search completes,
|
||||||
|
/// iterate over the `Search` to get the results. The number of results should be equal
|
||||||
|
/// to the `ef_search` parameter set in the index's `config`.
|
||||||
|
///
|
||||||
|
/// For best performance, reusing `Search` objects is recommended.
|
||||||
|
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
||||||
|
let point = FloatArray::try_from(point)?;
|
||||||
|
let _ = self.inner.search(&point, &mut search.inner);
|
||||||
|
search.cur = Some(0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Search buffer and result set
|
/// Search buffer and result set
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct Search {
|
struct Search {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import instant_distance, random
|
import instant_distance, random
|
||||||
|
|
||||||
def main():
|
|
||||||
|
def test_hsnw():
|
||||||
points = [[random.random() for _ in range(300)] for _ in range(1024)]
|
points = [[random.random() for _ in range(300)] for _ in range(1024)]
|
||||||
config = instant_distance.Config()
|
config = instant_distance.Config()
|
||||||
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
|
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
|
||||||
|
@ -10,5 +11,31 @@ def main():
|
||||||
for candidate in search:
|
for candidate in search:
|
||||||
print(candidate)
|
print(candidate)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
def test_hsnw_map():
    """Build an `HnswMap` over random embeddings and check a self-lookup.

    1024 random 300-dimensional embeddings are each paired with one word from
    the system dictionary. Searching for embedding number 123 must return, as
    its closest hit, the word that was paired with that embedding.
    """
    the_chosen_one = 123

    embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
    with open("/usr/share/dict/words", "r") as f:  # *nix only
        # Exactly one word per embedding, so that values[i] labels embeddings[i].
        values = f.read().splitlines()[: len(embeddings)]
    assert len(values) == len(embeddings), "word list is shorter than the embeddings"

    config = instant_distance.Config()
    hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)

    search = instant_distance.Search()
    hnsw_map.search(embeddings[the_chosen_one], search)

    # The first candidate yielded by the search is taken as the nearest neighbour.
    closest_pid = list(search)[0].pid
    approx_nearest = hnsw_map.values[closest_pid]
    actual_word = values[the_chosen_one]

    print("pid:\t\t", closest_pid)
    print("approx word:\t", approx_nearest)
    print("actual word:\t", actual_word)

    assert approx_nearest == actual_word
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Executed as a script: run both binding checks in order.
    test_hsnw()
    test_hsnw_map()
|
||||||
|
|
|
@ -130,7 +130,7 @@ impl Default for Heuristic {
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
pub struct HnswMap<P, V> {
|
pub struct HnswMap<P, V> {
|
||||||
hnsw: Hnsw<P>,
|
hnsw: Hnsw<P>,
|
||||||
values: Vec<V>,
|
pub values: Vec<V>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P, V> HnswMap<P, V>
|
impl<P, V> HnswMap<P, V>
|
||||||
|
|
Loading…
Reference in New Issue