diff --git a/test/test_specs.py b/test/test_specs.py
index 058144c1a94..c8b780e49a1 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3404,6 +3404,14 @@ def test_project(self, shape, device, spectype, rand_shape, n=5):
         assert (sp != s).all()
 
 
+class TestNonTensorSpec:
+    def test_sample(self):
+        nts = NonTensorSpec(shape=(3, 4))
+        assert nts.one((2,)).shape == (2, 3, 4)
+        assert nts.rand((2,)).shape == (2, 3, 4)
+        assert nts.zero((2,)).shape == (2, 3, 4)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 50999eb5f36..fc8e64e10a8 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1949,14 +1949,26 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
     def clone(self) -> NonTensorSpec:
         return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
 
-    def rand(self, shape):
-        return NonTensorData(data=None, batch_size=self.shape, device=self.device)
+    def rand(self, shape=None):
+        if shape is None:
+            shape = ()
+        return NonTensorData(
+            data=None, batch_size=(*shape, *self.shape), device=self.device
+        )
 
-    def zero(self, batch_size):
-        return NonTensorData(data=None, batch_size=self.shape, device=self.device)
+    def zero(self, shape=None):
+        if shape is None:
+            shape = ()
+        return NonTensorData(
+            data=None, batch_size=(*shape, *self.shape), device=self.device
+        )
 
-    def one(self, batch_size):
-        return NonTensorData(data=None, batch_size=self.shape, device=self.device)
+    def one(self, shape=None):
+        if shape is None:
+            shape = ()
+        return NonTensorData(
+            data=None, batch_size=(*shape, *self.shape), device=self.device
+        )
 
     def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.broadcast_shapes(self.shape, val.shape)