Skip to content

Commit

Permalink
[BugFix] Loading state_dict on uninitialized CatFrames (pytorch#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 22, 2023
1 parent aa971a7 commit 8efbb26
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,19 +1547,26 @@ def __init__(
if cat_dim > 0:
raise ValueError(self._CAT_DIM_ERR)
self.cat_dim = cat_dim
for in_key in self.in_keys:
buffer_name = f"_cat_buffers_{in_key}"
setattr(
self,
buffer_name,
torch.nn.parameter.UninitializedBuffer(
device=torch.device("cpu"), dtype=torch.get_default_dtype()
),
)

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Resets _buffers."""
# Non-batched environments
if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
for in_key in self.in_keys:
buffer_name = f"_cat_buffers_{in_key}"
try:
buffer = getattr(self, buffer_name)
buffer.fill_(0.0)
except AttributeError:
# we'll instantiate later, when needed
pass
buffer = getattr(self, buffer_name)
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
continue
buffer.fill_(0.0)

# Batched environments
else:
Expand All @@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
)
for in_key in self.in_keys:
buffer_name = f"_cat_buffers_{in_key}"
try:
buffer = getattr(self, buffer_name)
buffer[_reset] = 0.0
except AttributeError:
# we'll instantiate later, when needed
pass
buffer = getattr(self, buffer_name)
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
continue
buffer[_reset] = 0.0

return tensordict

Expand All @@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name):
d = shape[self.cat_dim]
shape[self.cat_dim] = d * self.N
shape = torch.Size(shape)
self.register_buffer(
buffer_name,
torch.zeros(
shape,
dtype=data.dtype,
device=data.device,
),
)
buffer = getattr(self, buffer_name)
getattr(self, buffer_name).materialize(shape)
buffer = getattr(self, buffer_name).to(data.dtype).to(data.device).zero_()
setattr(self, buffer_name, buffer)
return buffer

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand All @@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
buffer_name = f"_cat_buffers_{in_key}"
data = tensordict[in_key]
d = data.size(self.cat_dim)
try:
buffer = getattr(self, buffer_name)
buffer = getattr(self, buffer_name)
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
buffer = self._make_missing_buffer(data, buffer_name)
else:
# shift obs 1 position to the right
buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.cat_dim))
except AttributeError:
buffer = self._make_missing_buffer(data, buffer_name)
# add new obs
idx = self.cat_dim
if idx < 0:
Expand Down

0 comments on commit 8efbb26

Please sign in to comment.