From 2e82cab191132d6272d4b598a539bf88c5e961a0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:37 +0100 Subject: [PATCH] [Feature] Composite.batch_size ghstack-source-id: 621884a559a71e80a4be36c7ba984fd08be47952 Pull Request resolved: https://github.com/pytorch/rl/pull/2597 --- torchrl/data/tensor_specs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1f31db01ec7..5404beb0ec0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4257,6 +4257,14 @@ def __new__(cls, *args, **kwargs): cls._locked = False return super().__new__(cls) + @property + def batch_size(self): + return self._shape + + @batch_size.setter + def batch_size(self, value: torch.Size): + self._shape = value + @property def shape(self): return self._shape