diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1f31db01ec7..5404beb0ec0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4257,6 +4257,14 @@ def __new__(cls, *args, **kwargs): cls._locked = False return super().__new__(cls) + @property + def batch_size(self): + return self._shape + + @batch_size.setter + def batch_size(self, value: torch.Size): + self._shape = value + @property def shape(self): return self._shape