Skip to content

Commit

Permalink
[BugFix] Fix sampling of values from NonTensorSpec (pytorch#2169)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 22, 2024
1 parent eaa3dd8 commit a93063b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
8 changes: 8 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3404,6 +3404,14 @@ def test_project(self, shape, device, spectype, rand_shape, n=5):
assert (sp != s).all()


class TestNonTensorSpec:
def test_sample(self):
nts = NonTensorSpec(shape=(3, 4))
assert nts.one((2,)).shape == (2, 3, 4)
assert nts.rand((2,)).shape == (2, 3, 4)
assert nts.zero((2,)).shape == (2, 3, 4)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
24 changes: 18 additions & 6 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,14 +1949,26 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
def clone(self) -> NonTensorSpec:
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

def rand(self, shape):
return NonTensorData(data=None, batch_size=self.shape, device=self.device)
def rand(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, self.shape), device=self.device
)

def zero(self, batch_size):
return NonTensorData(data=None, batch_size=self.shape, device=self.device)
def zero(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, self.shape), device=self.device
)

def one(self, batch_size):
return NonTensorData(data=None, batch_size=self.shape, device=self.device)
def one(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, self.shape), device=self.device
)

def is_in(self, val: torch.Tensor) -> bool:
shape = torch.broadcast_shapes(self.shape, val.shape)
Expand Down

0 comments on commit a93063b

Please sign in to comment.