diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1322f627ea5..99741a8add8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1547,6 +1547,15 @@ def __init__( if cat_dim > 0: raise ValueError(self._CAT_DIM_ERR) self.cat_dim = cat_dim + for in_key in self.in_keys: + buffer_name = f"_cat_buffers_{in_key}" + setattr( + self, + buffer_name, + torch.nn.parameter.UninitializedBuffer( + device=torch.device("cpu"), dtype=torch.get_default_dtype() + ), + ) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Resets _buffers.""" @@ -1554,12 +1563,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1: for in_key in self.in_keys: buffer_name = f"_cat_buffers_{in_key}" - try: - buffer = getattr(self, buffer_name) - buffer.fill_(0.0) - except AttributeError: - # we'll instantiate later, when needed - pass + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + continue + buffer.fill_(0.0) # Batched environments else: @@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) for in_key in self.in_keys: buffer_name = f"_cat_buffers_{in_key}" - try: - buffer = getattr(self, buffer_name) - buffer[_reset] = 0.0 - except AttributeError: - # we'll instantiate later, when needed - pass + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + continue + buffer[_reset] = 0.0 return tensordict @@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name): d = shape[self.cat_dim] shape[self.cat_dim] = d * self.N shape = torch.Size(shape) - self.register_buffer( - buffer_name, - torch.zeros( - shape, - dtype=data.dtype, - device=data.device, - ), - ) - buffer = getattr(self, buffer_name) + getattr(self, buffer_name).materialize(shape) + buffer = getattr(self, buffer_name).to(data.dtype).to(data.device).zero_() + setattr(self, buffer_name, buffer) return buffer def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: buffer_name = f"_cat_buffers_{in_key}" data = tensordict[in_key] d = data.size(self.cat_dim) - try: - buffer = getattr(self, buffer_name) + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + buffer = self._make_missing_buffer(data, buffer_name) + else: # shift obs 1 position to the right buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.cat_dim)) - except AttributeError: - buffer = self._make_missing_buffer(data, buffer_name) # add new obs idx = self.cat_dim if idx < 0: