Skip to content

Commit

Permalink
Sync latest RLMeta SegmentTree implementation to TorchRL
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaomengy committed May 7, 2022
1 parent f0d7e1c commit 7ed9cf1
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 113 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
**/.cache/*
**/__pycache__/*
**/build/*
**/outputs/*
*.egg-info/*
*.eggs/*
*.so
38 changes: 38 additions & 0 deletions torchrl/csrc/numpy_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <pybind11/numpy.h>

#include <algorithm>
#include <cstdint>
#include <vector>

namespace py = pybind11;

namespace torchrl {
namespace utils {

template <typename T>
std::vector<int64_t> NumpyArrayShape(const py::array_t<T>& arr) {
const int64_t ndim = arr.ndim();
std::vector<int64_t> shape(ndim);
for (int64_t i = 0; i < ndim; ++i) {
shape[i] = static_cast<int64_t>(arr.shape(i));
}
return shape;
}

template <typename T_SRC, typename T_DST = T_SRC>
py::array_t<T_DST> NumpyEmptyLike(const py::array_t<T_SRC>& src) {
py::array_t<T_DST> dst(src.size());
const std::vector<int64_t> shape = NumpyArrayShape(src);
dst.resize(shape);
return dst;
}

} // namespace utils
} // namespace torchrl
132 changes: 77 additions & 55 deletions torchrl/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,122 +19,144 @@ PYBIND11_MODULE(_torchrl, m) {
std::shared_ptr<torchrl::SumSegmentTree<float>>>(m,
"SumSegmentTree")
.def(py::init<int64_t>())
.def_property_readonly("size", &torchrl::SumSegmentTree<float>::size)
.def_property_readonly("capacity",
&torchrl::SumSegmentTree<float>::capacity)
.def_property_readonly("identity_element",
&torchrl::SumSegmentTree<float>::identity_element)
.def("__len__", &torchrl::SumSegmentTree<float>::size)
.def("size", &torchrl::SumSegmentTree<float>::size)
.def("capacity", &torchrl::SumSegmentTree<float>::capacity)
.def("identity_element",
&torchrl::SumSegmentTree<float>::identity_element)
.def("__getitem__", py::overload_cast<int64_t>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t> &>(
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor &>(
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&torchrl::SumSegmentTree<float>::At,
py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t> &>(
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor &>(
.def("at", py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::At, py::const_))
.def("__setitem__", py::overload_cast<int64_t, const float &>(
.def("__setitem__", py::overload_cast<int64_t, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t> &, const float &>(
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<float> &>(
.def("__setitem__", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor &, const float &>(
&torchrl::SumSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<int64_t, const float &>(
.def("update", py::overload_cast<int64_t, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update",
py::overload_cast<const py::array_t<int64_t> &, const float &>(
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<float> &>(
.def("update", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update", py::overload_cast<const torch::Tensor &, const float &>(
.def("update", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::SumSegmentTree<float>::Update))
.def("update",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<int64_t> &>(
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::Query, py::const_))
.def("scan_lower_bound",
py::overload_cast<const float &>(
py::overload_cast<const float&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def("scan_lower_bound",
py::overload_cast<const py::array_t<float> &>(
py::overload_cast<const py::array_t<float>&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def("scan_lower_bound",
py::overload_cast<const torch::Tensor &>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_));
py::overload_cast<const torch::Tensor&>(
&torchrl::SumSegmentTree<float>::ScanLowerBound, py::const_))
.def(py::pickle(
[](const SumSegmentTree<float>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
SumSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));

py::class_<torchrl::MinSegmentTree<float>,
std::shared_ptr<torchrl::MinSegmentTree<float>>>(m,
"MinSegmentTree")
.def(py::init<int64_t>())
.def_property_readonly("size", &torchrl::MinSegmentTree<float>::size)
.def_property_readonly("capacity",
&torchrl::MinSegmentTree<float>::capacity)
.def_property_readonly("identity_element",
&torchrl::MinSegmentTree<float>::identity_element)
.def("__len__", &torchrl::MinSegmentTree<float>::size)
.def("size", &torchrl::MinSegmentTree<float>::size)
.def("capacity", &torchrl::MinSegmentTree<float>::capacity)
.def("identity_element",
&torchrl::MinSegmentTree<float>::identity_element)
.def("__getitem__", py::overload_cast<int64_t>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t> &>(
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor &>(
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&torchrl::MinSegmentTree<float>::At,
py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t> &>(
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor &>(
.def("at", py::overload_cast<const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::At, py::const_))
.def("__setitem__", py::overload_cast<int64_t, const float &>(
.def("__setitem__", py::overload_cast<int64_t, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t> &, const float &>(
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<float> &>(
.def("__setitem__", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor &, const float &>(
&torchrl::MinSegmentTree<float>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<int64_t, const float &>(
.def("update", py::overload_cast<int64_t, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update",
py::overload_cast<const py::array_t<int64_t> &, const float &>(
py::overload_cast<const py::array_t<int64_t>&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<float> &>(
.def("update", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<float>&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update", py::overload_cast<const torch::Tensor &, const float &>(
.def("update", py::overload_cast<const torch::Tensor&, const float&>(
&torchrl::MinSegmentTree<float>::Update))
.def("update",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t> &,
const py::array_t<int64_t> &>(
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor &, const torch::Tensor &>(
&torchrl::MinSegmentTree<float>::Query, py::const_));
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&torchrl::MinSegmentTree<float>::Query, py::const_))
.def(py::pickle(
[](const MinSegmentTree<float>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
MinSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
}
Loading

0 comments on commit 7ed9cf1

Please sign in to comment.