Skip to content

Commit

Permalink
bf
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed May 9, 2022
1 parent b447a61 commit 7082c14
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 146 deletions.
1 change: 0 additions & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,3 @@ StatementMacros:
TabWidth: 8
UseTab: Never
...

148 changes: 5 additions & 143 deletions torchrl/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,148 +15,10 @@
namespace py = pybind11;

PYBIND11_MODULE(_torchrl, m) {
py::class_<torchrl::SumSegmentTree<float>,
std::shared_ptr<torchrl::SumSegmentTree<float>>>(m,
"SumSegmentTree")
.def(py::init<int64_t>())
.def_property_readonly("size", &torchrl::SumSegmentTree<float>::size)
.def_property_readonly("capacity",
&torchrl::SumSegmentTree<float>::capacity)
.def_property_readonly("identity_element",
&torchrl::SumSegmentTree<float>::identity_element)
.def("__len__", &torchrl::SumSegmentTree<float>::size)
.def("__getitem__", py::overload_cast<int64_t>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&torchrl::SumSegmentTree<float>::At,
py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__setitem__", py::overload_cast<int64_t, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<int64_t, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update",
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("scan_lower_bound",
py::overload_cast<const float&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def("scan_lower_bound",
py::overload_cast<const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def("scan_lower_bound",
py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def(py::pickle(
[](const SumSegmentTree<float>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
SumSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
torchrl::DefineSumSegmentTree<float>("Fp32", m);
torchrl::DefineSumSegmentTree<double>("Fp64", m);

torchrl::DefineMinSegmentTree<float>("Fp32", m);
torchrl::DefineMinSegmentTree<double>("Fp64", m);

py::class_<torchrl::MinSegmentTree<float>,
std::shared_ptr<torchrl::MinSegmentTree<float>>>(m,
"MinSegmentTree")
.def(py::init<int64_t>())
.def_property_readonly("size", &torchrl::MinSegmentTree<float>::size)
.def_property_readonly("capacity",
&torchrl::MinSegmentTree<float>::capacity)
.def_property_readonly("identity_element",
&torchrl::MinSegmentTree<float>::identity_element)
.def("__len__", &torchrl::MinSegmentTree<float>::size)
.def("__getitem__", py::overload_cast<int64_t>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&torchrl::MinSegmentTree<float>::At,
py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__setitem__", py::overload_cast<int64_t, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<int64_t, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update",
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def(py::pickle(
[](const MinSegmentTree<float>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
MinSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
}
153 changes: 151 additions & 2 deletions torchrl/csrc/segment_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#include <limits>
#include <vector>

#include "numpy_utils.h"
#include "torch_utils.h"
#include "torchrl/csrc/numpy_utils.h"
#include "torchrl/csrc/torch_utils.h"

namespace py = pybind11;

Expand Down Expand Up @@ -307,4 +307,153 @@ class MinSegmentTree final : public SegmentTree<T, MinOp<T>> {
: SegmentTree<T, MinOp<T>>(size, std::numeric_limits<T>::max()) {}
};

template <typename T>
void DefineSumSegmentTree(const std::string& type, py::module& m) {
const std::string pyclass = "SumSegmentTree" + type;
py::class_<SumSegmentTree<T>, std::shared_ptr<SumSegmentTree<T>>>(
m, pyclass.c_str())
.def(py::init<int64_t>())
.def_property_readonly("size", &SumSegmentTree<T>::size)
.def_property_readonly("capacity", &SumSegmentTree<T>::capacity)
.def_property_readonly("identity_element",
&SumSegmentTree<T>::identity_element)
.def("__len__", &SumSegmentTree<T>::size)
.def("__getitem__",
py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&SumSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&SumSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&SumSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor&>(&SumSegmentTree<T>::At,
py::const_))
.def("__setitem__",
py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const T&>(
&SumSegmentTree<T>::Update))
.def(
"__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&SumSegmentTree<T>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
&SumSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&SumSegmentTree<T>::Update))
.def("update",
py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
&SumSegmentTree<T>::Update))
.def(
"update",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&SumSegmentTree<T>::Update))
.def("update", py::overload_cast<const torch::Tensor&, const T&>(
&SumSegmentTree<T>::Update))
.def("update",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&SumSegmentTree<T>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&SumSegmentTree<T>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&SumSegmentTree<T>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&SumSegmentTree<T>::Query, py::const_))
.def("scan_lower_bound",
py::overload_cast<const T&>(&SumSegmentTree<T>::ScanLowerBound,
py::const_))
.def("scan_lower_bound",
py::overload_cast<const py::array_t<T>&>(
&SumSegmentTree<T>::ScanLowerBound, py::const_))
.def("scan_lower_bound",
py::overload_cast<const torch::Tensor&>(
&SumSegmentTree<T>::ScanLowerBound, py::const_))
.def(py::pickle(
[](const SumSegmentTree<T>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
SumSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
}

template <typename T>
void DefineMinSegmentTree(const std::string& type, py::module& m) {
const std::string pyclass = "MinSegmentTree" + type;
py::class_<MinSegmentTree<T>, std::shared_ptr<MinSegmentTree<T>>>(
m, pyclass.c_str())
.def(py::init<int64_t>())
.def_property_readonly("size", &MinSegmentTree<T>::size)
.def_property_readonly("capacity", &MinSegmentTree<T>::capacity)
.def_property_readonly("identity_element",
&MinSegmentTree<T>::identity_element)
.def("__len__", &MinSegmentTree<T>::size)
.def("__getitem__",
py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&MinSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor&>(&MinSegmentTree<T>::At,
py::const_))
.def("__setitem__",
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const T&>(
&MinSegmentTree<T>::Update))
.def(
"__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&MinSegmentTree<T>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
&MinSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Update))
.def("update",
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
&MinSegmentTree<T>::Update))
.def(
"update",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&MinSegmentTree<T>::Update))
.def("update", py::overload_cast<const torch::Tensor&, const T&>(
&MinSegmentTree<T>::Update))
.def("update",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&MinSegmentTree<T>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&MinSegmentTree<T>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Query, py::const_))
.def(py::pickle(
[](const MinSegmentTree<T>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
MinSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
}

} // namespace torchrl

0 comments on commit 7082c14

Please sign in to comment.