diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index fc8e64e10a8..160c5cadfcc 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1953,21 +1953,21 @@ def rand(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def zero(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def one(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def is_in(self, val: torch.Tensor) -> bool: