Skip to content

Commit

Permalink
Fixing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed May 9, 2022
1 parent 7082c14 commit c4acfcf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
1 change: 0 additions & 1 deletion torchrl/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@ PYBIND11_MODULE(_torchrl, m) {

torchrl::DefineMinSegmentTree<float>("Fp32", m);
torchrl::DefineMinSegmentTree<double>("Fp64", m);

}
20 changes: 17 additions & 3 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
import torch
from torch import Tensor

from torchrl._torchrl import MinSegmentTree, SumSegmentTree
from torchrl._torchrl import (
MinSegmentTreeFp32,
MinSegmentTreeFp64,
SumSegmentTreeFp32,
SumSegmentTreeFp64,
)
from torchrl.data.replay_buffers.utils import (
cat_fields_to_device,
to_numpy,
Expand Down Expand Up @@ -302,6 +307,7 @@ def __init__(
alpha: float,
beta: float,
eps: float = 1e-8,
dtype: torch.dtype = torch.float,
collate_fn=None,
pin_memory: bool = False,
prefetch: Optional[int] = None,
Expand All @@ -319,8 +325,16 @@ def __init__(
self._alpha = alpha
self._beta = beta
self._eps = eps
self._sum_tree = SumSegmentTree(size)
self._min_tree = MinSegmentTree(size)
if dtype in (torch.float, torch.FloatType, torch.float32):
self._sum_tree = SumSegmentTreeFp32(size)
self._min_tree = MinSegmentTreeFp32(size)
elif dtype in (torch.double, torch.DoubleTensor, torch.float64):
self._sum_tree = SumSegmentTreeFp64(size)
self._min_tree = MinSegmentTreeFp64(size)
else:
raise NotImplementedError(
f"dtype {dtype} not supported by PrioritizedReplayBuffer"
)
self._max_priority = 1.0

@pin_memory_output
Expand Down

0 comments on commit c4acfcf

Please sign in to comment.