From 78b4717831361654678ea56847445d68143635af Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 May 2024 01:13:35 +0100 Subject: [PATCH] [BugFix] Fix sampling in NonTensorSpec (#2172) --- torchrl/data/tensor_specs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index fc8e64e10a8..160c5cadfcc 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1953,21 +1953,21 @@ def rand(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def zero(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def one(self, shape=None): if shape is None: shape = () return NonTensorData( - data=None, batch_size=(*shape, self.shape), device=self.device + data=None, batch_size=(*shape, *self.shape), device=self.device ) def is_in(self, val: torch.Tensor) -> bool: