Skip to content

Commit

Permalink
Update to pyo3 0.21
Browse files Browse the repository at this point in the history
  • Loading branch information
Dr-Emann committed Apr 7, 2024
1 parent 4ae3eae commit 0cd73f1
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 61 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "rbloom"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = [
pyo3 = { version = "0.21", features = [
"extension-module",
"abi3-py37",
] } # stable ABI with minimum Python version 3.7
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.14,<0.15"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"

[project]
Expand Down
134 changes: 75 additions & 59 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use bitline::BitLine;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::sync::GILOnceCell;
use pyo3::types::PyType;
use pyo3::{basic::CompareOp, prelude::*, types::PyBytes, types::PyTuple};
use std::fs::File;
Expand All @@ -12,7 +13,7 @@ struct Bloom {
filter: BitLine,
k: u64, // Number of hash functions (implemented via a LCG that uses
// the original hash as a seed)
hash_func: Option<PyObject>,
hash_func: Option<Py<PyAny>>,
}

#[pymethods]
Expand All @@ -21,14 +22,9 @@ impl Bloom {
fn new(
expected_items: u64,
false_positive_rate: f64,
hash_func: Option<&PyAny>,
hash_func: Option<Bound<'_, PyAny>>,
) -> PyResult<Self> {
// Check the inputs
if let Some(hash_func) = hash_func {
if !hash_func.is_callable() {
return Err(PyTypeError::new_err("hash_func must be callable"));
}
}
if false_positive_rate <= 0.0 || false_positive_rate >= 1.0 {
return Err(PyValueError::new_err(
"false_positive_rate must be between 0 and 1",
Expand All @@ -39,19 +35,21 @@ impl Bloom {
"expected_items must be greater than 0",
));
}
let hash_func = match hash_func {
Some(hash_func) if !hash_func.is(builtin_hash_func(hash_func.py())?) => {
if !hash_func.is_callable() {
return Err(PyTypeError::new_err("hash_func must be callable"));
}
Some(hash_func.unbind())
}
_ => None,
};

// Calculate the parameters for the filter
let size_in_bits =
-1.0 * (expected_items as f64) * false_positive_rate.ln() / 2.0f64.ln().powi(2);
let k = (size_in_bits / expected_items as f64) * 2.0f64.ln();

let hash_func = match hash_func {
// if __builtins__.hash was passed, use None instead
Some(hash_func) if !hash_func.is(get_builtin_hash_func(hash_func.py())?) => {
Some(hash_func.to_object(hash_func.py()))
}
_ => None,
};
// Create the filter
Ok(Bloom {
filter: BitLine::new(size_in_bits as u64)?,
Expand All @@ -68,10 +66,10 @@ impl Bloom {

/// Retrieve the hash_func given to __init__
#[getter]
fn hash_func<'a>(&'a self, py: Python<'a>) -> PyResult<&'a PyAny> {
fn hash_func<'py>(&self, py: Python<'py>) -> PyResult<&Bound<'py, PyAny>> {
match self.hash_func.as_ref() {
Some(hash_func) => Ok(hash_func.as_ref(py)),
None => get_builtin_hash_func(py),
Some(hash_func) => Ok(hash_func.bind(py)),
None => builtin_hash_func(py),
}
}

Expand All @@ -84,7 +82,7 @@ impl Bloom {
}

#[pyo3(signature = (o, /))]
fn add(&mut self, o: &PyAny) -> PyResult<()> {
fn add(&mut self, o: &Bound<'_, PyAny>) -> PyResult<()> {
let hash = hash(o, &self.hash_func)?;
for index in lcg::generate_indexes(hash, self.k, self.filter.len()) {
self.filter.set(index);
Expand All @@ -98,7 +96,7 @@ impl Bloom {
/// contain all items in this set), but it will not return a false negative:
/// If this returns false, this set contains an element which is not in other
#[pyo3(signature = (other, /))]
fn issubset(&self, other: &PyAny) -> PyResult<bool> {
fn issubset(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
self.with_other_as_bloom(other, |other_bloom| {
Ok(self.filter.is_subset(&other_bloom.filter))
})
Expand All @@ -110,13 +108,13 @@ impl Bloom {
/// contain all items in other), but it will not return a false negative:
/// If this returns false, other contains an element which is not in self
#[pyo3(signature = (other, /))]
fn issuperset(&self, other: &PyAny) -> PyResult<bool> {
fn issuperset(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
self.with_other_as_bloom(other, |other_bloom| {
Ok(other_bloom.filter.is_subset(&self.filter))
})
}

fn __contains__(&self, o: &PyAny) -> PyResult<bool> {
fn __contains__(&self, o: &Bound<'_, PyAny>) -> PyResult<bool> {
let hash = hash(o, &self.hash_func)?;
for index in lcg::generate_indexes(hash, self.k, self.filter.len()) {
if !self.filter.get(index) {
Expand All @@ -128,15 +126,15 @@ impl Bloom {

/// Return a new set with elements from the set and all others.
#[pyo3(signature = (*others))]
fn union(&self, others: &PyTuple) -> PyResult<Self> {
fn union(&self, others: &Bound<'_, PyTuple>) -> PyResult<Self> {
let mut result = self.clone();
result.update(others)?;
Ok(result)
}

/// Return a new set with elements common to the set and all others.
#[pyo3(signature = (*others))]
fn intersection(&self, others: &PyTuple) -> PyResult<Self> {
fn intersection(&self, others: &Bound<'_, PyTuple>) -> PyResult<Self> {
let mut result = self.clone();
result.intersection_update(others)?;
Ok(result)
Expand Down Expand Up @@ -173,7 +171,7 @@ impl Bloom {
}

#[pyo3(signature = (*others))]
fn update(&mut self, others: &PyTuple) -> PyResult<()> {
fn update(&mut self, others: &Bound<'_, PyTuple>) -> PyResult<()> {
for other in others.iter() {
// If the other object is a Bloom, use the bitwise union
if let Ok(other) = other.extract::<PyRef<Bloom>>() {
Expand All @@ -182,15 +180,15 @@ impl Bloom {
// Otherwise, iterate over the other object and add each item
else {
for obj in other.iter()? {
self.add(obj?)?;
self.add(&obj?)?;
}
}
}
Ok(())
}

#[pyo3(signature = (*others))]
fn intersection_update(&mut self, others: &PyTuple) -> PyResult<()> {
fn intersection_update(&mut self, others: &Bound<'_, PyTuple>) -> PyResult<()> {
// Lazily allocated temp bitset
let mut temp: Option<Self> = None;
for other in others.iter() {
Expand All @@ -203,7 +201,7 @@ impl Bloom {
let temp = temp.get_or_insert_with(|| self.clone());
temp.clear();
for obj in other.iter()? {
temp.add(obj?)?;
temp.add(&obj?)?;
}
self.__iand__(temp)?;
}
Expand Down Expand Up @@ -246,17 +244,21 @@ impl Bloom {
}

#[classattr]
const __hash__: Option<PyObject> = None;
const __hash__: Option<Py<PyAny>> = None;

/// Load from a file, see "Persistence" section in the README
#[classmethod]
fn load(_cls: &PyType, filepath: &str, hash_func: &PyAny) -> PyResult<Bloom> {
fn load(
_cls: &Bound<'_, PyType>,
filepath: &str,
hash_func: &Bound<'_, PyAny>,
) -> PyResult<Bloom> {
// check that the hash_func is callable
if !hash_func.is_callable() {
return Err(PyTypeError::new_err("hash_func must be callable"));
}
// check that the hash_func isn't the built-in hash function
if hash_func.is(get_builtin_hash_func(hash_func.py())?) {
if hash_func.is(builtin_hash_func(hash_func.py())?) {
return Err(PyValueError::new_err(
"Cannot load a bloom filter that uses the built-in hash function",
));
Expand All @@ -280,13 +282,17 @@ impl Bloom {

/// Load from a bytes(), see "Persistence" section in the README
#[classmethod]
fn load_bytes(_cls: &PyType, bytes: &[u8], hash_func: &PyAny) -> PyResult<Bloom> {
fn load_bytes(
_cls: &Bound<'_, PyType>,
bytes: &[u8],
hash_func: &Bound<'_, PyAny>,
) -> PyResult<Bloom> {
// check that the hash_func is callable
if !hash_func.is_callable() {
return Err(PyTypeError::new_err("hash_func must be callable"));
}
// check that the hash_func isn't the built-in hash function
if hash_func.is(get_builtin_hash_func(hash_func.py())?) {
if hash_func.is(builtin_hash_func(hash_func.py())?) {
return Err(PyValueError::new_err(
"Cannot load a bloom filter that uses the built-in hash function",
));
Expand Down Expand Up @@ -321,22 +327,27 @@ impl Bloom {
}

/// Save to a bytes(), see "Persistence" section in the README
fn save_bytes(&self, py: Python<'_>) -> PyResult<PyObject> {
fn save_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
const K_SIZE: usize = mem::size_of::<u64>();
if self.hash_func.is_none() {
return Err(PyValueError::new_err(
"Cannot save a bloom filter that uses the built-in hash function",
));
}

let serialized: Vec<u8> = [&self.k.to_le_bytes(), self.filter.bits() as &[u8]].concat();

Ok(PyBytes::new(py, &serialized).into())
debug_assert_eq!(K_SIZE, self.k.to_le_bytes().len());
let len = K_SIZE + self.filter.bits().len();
PyBytes::new_bound_with(py, len, |data| {
data[..K_SIZE].copy_from_slice(&self.k.to_le_bytes());
data[K_SIZE..].copy_from_slice(self.filter.bits());
Ok(())
})
}
}

// Non-python methods
impl Bloom {
fn hash_fn_clone(&self, py: Python<'_>) -> Option<PyObject> {
fn hash_fn_clone(&self, py: Python<'_>) -> Option<Py<PyAny>> {
self.hash_func.as_ref().map(|f| f.clone_ref(py))
}

Expand All @@ -351,7 +362,7 @@ impl Bloom {
/// Extract other as a bloom, or iterate other, and add all items to a temporary bloom
fn with_other_as_bloom<O>(
&self,
other: &PyAny,
other: &Bound<'_, PyAny>,
f: impl FnOnce(&Bloom) -> PyResult<O>,
) -> PyResult<O> {
match other.extract::<PyRef<Bloom>>() {
Expand All @@ -362,15 +373,15 @@ impl Bloom {
Err(_) => {
let mut other_bloom = self.zeroed_clone(other.py());
for obj in other.iter()? {
other_bloom.add(obj?)?;
other_bloom.add(&obj?)?;
}
f(&other_bloom)
}
}
}
}

/// This is a primitive BitVec-like structure that uses a Box<[u8]> as
/// This is a primitive BitVec-like structure that uses a `Box<[u8]>` as
/// the backing store; it exists here to avoid the need for a dependency
/// on bitvec and to act as a container around all the bit manipulation.
/// Indexing is done using u64 to avoid address space issues on 32-bit
Expand Down Expand Up @@ -585,11 +596,12 @@ mod lcg {
}
}

fn hash(o: &PyAny, hash_func: &Option<PyObject>) -> PyResult<i128> {
fn hash(o: &Bound<'_, PyAny>, hash_func: &Option<Py<PyAny>>) -> PyResult<i128> {
match hash_func {
Some(hash_func) => {
let hash = hash_func.call1(o.py(), (o,))?;
Ok(hash.extract(o.py())?)
let hash_func = hash_func.bind(o.py());
let hash = hash_func.call1((o,))?;
Ok(hash.extract()?)
}
None => Ok(o.hash()? as i128),
}
Expand All @@ -603,28 +615,32 @@ fn check_compatible(a: &Bloom, b: &Bloom) -> PyResult<()> {
}

// now only the hash function can be different
let same_hash_fn = match (&a.hash_func, &b.hash_func) {
(Some(lhs), Some(rhs)) => lhs.is(rhs),
(&None, &None) => true,
_ => false,
};

if same_hash_fn {
Ok(())
} else {
Err(PyValueError::new_err(
"Bloom filters must have the same hash function",
))
match (&a.hash_func, &b.hash_func) {
(Some(lhs), Some(rhs)) if lhs.is(rhs) => {}
(&None, &None) => {}
_ => {
return Err(PyValueError::new_err(
"Bloom filters must have the same hash function",
))
}
}

Ok(())
}

fn get_builtin_hash_func(py: Python<'_>) -> PyResult<&'_ PyAny> {
let builtins = PyModule::import(py, "builtins")?;
builtins.getattr("hash")
fn builtin_hash_func(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> {
static HASH_FUNC: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

let res = HASH_FUNC.get_or_try_init(py, || -> PyResult<_> {
let builtins = PyModule::import_bound(py, "builtins")?;
Ok(builtins.getattr("hash")?.unbind())
})?;

Ok(res.bind(py))
}

#[pymodule]
fn rbloom(_py: Python, m: &PyModule) -> PyResult<()> {
fn rbloom(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Bloom>()?;
Ok(())
}

0 comments on commit 0cd73f1

Please sign in to comment.