py: upgrade to pyo3 0.21

This commit is contained in:
Dirkjan Ochtman 2024-04-02 09:21:30 +02:00
parent 29cafd0db1
commit ab12c4e0c1
No known key found for this signature in database
2 changed files with 22 additions and 14 deletions

View File

@ -18,5 +18,5 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
bincode = "1.3.2" bincode = "1.3.2"
instant-segment = { version = "0.11", path = "../instant-segment", features = ["with-serde"] } instant-segment = { version = "0.11", path = "../instant-segment", features = ["with-serde"] }
pyo3 = { version = "0.20", features = ["extension-module"] } pyo3 = { version = "0.21", features = ["extension-module"] }
smartstring = "1" smartstring = "1"

View File

@ -5,14 +5,15 @@ use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use pyo3::exceptions::PyValueError; use pyo3::exceptions::PyValueError;
use pyo3::types::{PyIterator, PyModule}; use pyo3::pybacked::PyBackedStr;
use pyo3::{pyclass, pymethods, pymodule}; use pyo3::types::{PyAnyMethods, PyIterator, PyModule};
use pyo3::{pyclass, pymethods, pymodule, Bound};
use pyo3::{PyErr, PyRef, PyRefMut, PyResult, Python}; use pyo3::{PyErr, PyRef, PyRefMut, PyResult, Python};
use smartstring::alias::String as SmartString; use smartstring::alias::String as SmartString;
#[pymodule] #[pymodule]
#[pyo3(name = "instant_segment")] #[pyo3(name = "instant_segment")]
fn instant_segment_py(_: Python, m: &PyModule) -> PyResult<()> { fn instant_segment_py(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Search>()?; m.add_class::<Search>()?;
m.add_class::<Segmenter>()?; m.add_class::<Segmenter>()?;
Ok(()) Ok(())
@ -31,23 +32,29 @@ impl Segmenter {
/// The `unigrams` iterator should yield `(str, float)` items, while the `bigrams` /// The `unigrams` iterator should yield `(str, float)` items, while the `bigrams`
/// iterator should yield `((str, str), float)` items. /// iterator should yield `((str, str), float)` items.
#[new] #[new]
fn new(unigrams: &PyIterator, bigrams: &PyIterator) -> PyResult<Self> { fn new(unigrams: &Bound<'_, PyIterator>, bigrams: &Bound<'_, PyIterator>) -> PyResult<Self> {
let unigrams = unigrams let unigrams = unigrams
.map(|item| { .iter()?
let item = item?; .map(|result| {
let key = item.get_item(0)?.extract::<&str>()?; let item = result?;
let val = item.get_item(1)?.extract::<f64>()?; let key = item.get_item(0)?;
let key = key.extract::<&str>()?;
let val = item.get_item(1)?;
let val = val.extract::<f64>()?;
Ok((SmartString::from(key), val)) Ok((SmartString::from(key), val))
}) })
.collect::<Result<Vec<_>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let bigrams = bigrams let bigrams = bigrams
.iter()?
.map(|item| { .map(|item| {
let item = item?; let item = item?;
let key = item.get_item(0)?; let key = item.get_item(0)?;
let first = key.get_item(0)?.extract::<&str>()?; let first = key.get_item(0)?;
let second = key.get_item(1)?.extract::<&str>()?; let first = first.extract::<&str>()?;
let second = key.get_item(1)?;
let second = second.extract::<&str>()?;
let val = item.get_item(1)?.extract::<f64>()?; let val = item.get_item(1)?.extract::<f64>()?;
Ok(((SmartString::from(first), SmartString::from(second)), val)) Ok(((SmartString::from(first), SmartString::from(second)), val))
@ -99,11 +106,12 @@ impl Segmenter {
/// ///
/// Returns the relative probability for the given sentence in the corpus represented by /// Returns the relative probability for the given sentence in the corpus represented by
/// this `Segmenter`. Will return `None` iff given an empty iterator argument. /// this `Segmenter`. Will return `None` iff given an empty iterator argument.
fn score_sentence(&self, words: &PyIterator) -> PyResult<Option<f64>> { fn score_sentence(&self, words: &Bound<'_, PyIterator>) -> PyResult<Option<f64>> {
let words = words let words = words
.map(|s| s?.extract::<&str>()) .iter()?
.map(|result| result?.extract::<PyBackedStr>())
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(self.inner.score_sentence(words.into_iter())) Ok(self.inner.score_sentence(words.iter().map(|s| &**s)))
} }
} }