Skip to content

Commit

Permalink
[Feature]: faster meta-tensor API for TensorDict (pytorch#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 14, 2022
1 parent 4c81c6f commit e5bea04
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 119 deletions.
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_tensordict_set(device):
torch.randn(4, 5, 1, 2, dtype=torch.double, device=device),
inplace=False,
)
assert td._tensordict_meta["key1"].shape == td._tensordict["key1"].shape
assert td._dict_meta["key1"].shape == td._tensordict["key1"].shape


def test_pad():
Expand Down
1 change: 1 addition & 0 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, fun=lambda x: x):

def __missing__(self, key):
    """Compute, memoize and return the value for an absent *key*.

    Invoked by ``dict.__getitem__`` on a miss; the result of ``self.fun``
    is stored so subsequent lookups for the same key hit the cache.
    """
    self[key] = result = self.fun(key)
    return result


Expand Down
2 changes: 2 additions & 0 deletions torchrl/data/tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class MemmapTensor(object):
"""

requires_grad = False

def __init__(
self,
elem: Union[torch.Tensor, MemmapTensor],
Expand Down
14 changes: 8 additions & 6 deletions torchrl/data/tensordict/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
self,
*shape: Union[int, torch.Tensor, "MemmapTensor"],
device: Optional[DEVICE_TYPING] = "cpu",
dtype: torch.dtype = torch.get_default_dtype(),
dtype: torch.dtype = None,
requires_grad: bool = False,
_is_shared: Optional[bool] = None,
_is_memmap: Optional[bool] = None,
Expand All @@ -81,8 +81,8 @@ def __init__(
_is_memmap = isinstance(tensor, MemmapTensor)
device = tensor.device if not tensor.is_meta else device
if _is_tensordict is None:
_is_tensordict = not isinstance(tensor, (MemmapTensor, torch.Tensor))
if isinstance(tensor, (MemmapTensor, torch.Tensor)):
_is_tensordict = not _is_memmap and not isinstance(tensor, torch.Tensor)
if not _is_tensordict:
dtype = tensor.dtype
else:
dtype = None
Expand All @@ -97,11 +97,11 @@ def __init__(
if not isinstance(shape, torch.Size):
shape = torch.Size(shape)
self.shape = shape
self.device = torch.device(device)
self.dtype = dtype
self.device = device
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
self.requires_grad = requires_grad
self._ndim = len(shape)
self._numel = np.prod(shape)
self._numel = None
self._is_shared = bool(_is_shared)
self._is_memmap = bool(_is_memmap)
self._is_tensordict = bool(_is_tensordict)
Expand Down Expand Up @@ -155,6 +155,8 @@ def is_tensordict(self) -> bool:
return self._is_tensordict

def numel(self) -> int:
if self._numel is None:
self._numel = np.prod(self.shape)
return self._numel

def ndimension(self) -> int:
Expand Down
Loading

0 comments on commit e5bea04

Please sign in to comment.