[BugFix] Fix offline CatFrames for pixels (#1964)
vmoens authored Feb 25, 2024
1 parent 931f70a · commit 3df6d9f
Showing 2 changed files with 29 additions and 18 deletions.
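For context: "offline" CatFrames means calling the transform directly on an already-collected rollout rather than inside the environment, and "pixels" means channels-first image observations stacked along dim=-3. A minimal sketch of the usage this commit fixes, assuming `r` is a rollout collected from a pixel-based env (so it carries the usual ("next", ...) entries and done flags the transform relies on); the key names are illustrative, not taken from the diff:

from torchrl.envs.transforms import CatFrames

# Offline application: stack the last 5 frames along the channel dim (-3)
# for both the root and the ("next", ...) pixel entries of the rollout `r`.
cat = CatFrames(
    N=5,
    dim=-3,
    in_keys=["pixels", ("next", "pixels")],
    out_keys=["pixels_cat", ("next", "pixels_cat")],
    padding="constant",
)
r = cat(r)  # writes "pixels_cat" and ("next", "pixels_cat")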
26 changes: 14 additions & 12 deletions test/test_transforms.py
@@ -967,17 +967,21 @@ def test_transform_no_env(self, device, d, batch_size, dim, N):
 
     @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
     @pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
-    def test_tranform_offline_against_online(self, padding):
+    @pytest.mark.parametrize("envtype", ["gym", "conv"])
+    def test_tranform_offline_against_online(self, padding, envtype):
         torch.manual_seed(0)
+        key = "observation" if envtype == "gym" else "pixels"
         env = SerialEnv(
             3,
             lambda: TransformedEnv(
-                GymEnv("CartPole-v1"),
+                GymEnv("CartPole-v1")
+                if envtype == "gym"
+                else DiscreteActionConvMockEnv(),
                 CatFrames(
-                    dim=-1,
+                    dim=-3 if envtype == "conv" else -1,
                     N=5,
-                    in_keys=["observation"],
-                    out_keys=["observation_cat"],
+                    in_keys=[key],
+                    out_keys=[f"{key}_cat"],
                     padding=padding,
                 ),
             ),
@@ -987,19 +991,17 @@ def test_tranform_offline_against_online(self, padding):
         r = env.rollout(100, break_when_any_done=False)
 
         c = CatFrames(
-            dim=-1,
+            dim=-3 if envtype == "conv" else -1,
             N=5,
-            in_keys=["observation", ("next", "observation")],
-            out_keys=["observation_cat2", ("next", "observation_cat2")],
+            in_keys=[key, ("next", key)],
+            out_keys=[f"{key}_cat2", ("next", f"{key}_cat2")],
             padding=padding,
         )
 
         r2 = c(r)
 
-        torch.testing.assert_close(r2["observation_cat2"], r2["observation_cat"])
-        assert (r2["observation_cat2"] == r2["observation_cat"]).all()
-
-        assert (r2["next", "observation_cat2"] == r2["next", "observation_cat"]).all()
+        torch.testing.assert_close(r2[f"{key}_cat2"], r2[f"{key}_cat"])
+        torch.testing.assert_close(r2["next", f"{key}_cat2"], r2["next", f"{key}_cat"])
 
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)])
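Why dim=-3 in the conv branch: CatFrames concatenates the last N observations along the given dimension, which for channels-first images is the channel axis, while the gym/vector branch keeps dim=-1. A toy shape check in plain torch (illustrative sizes, not taken from the test):

import torch

# Five channels-first frames [C, H, W] = [3, 84, 84] concatenated on the
# channel axis yield a [15, 84, 84] stacked observation.
frames = [torch.zeros(3, 84, 84) for _ in range(5)]
stacked = torch.cat(frames, dim=-3)
print(stacked.shape)  # torch.Size([15, 84, 84])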
21 changes: 15 additions & 6 deletions torchrl/envs/transforms/transforms.py
@@ -29,7 +29,7 @@
 )
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.nn import dispatch, TensorDictModuleBase
-from tensordict.utils import expand_as_right, NestedKey
+from tensordict.utils import expand_as_right, expand_right, NestedKey
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map
 from torchrl._utils import _replace_last
@@ -2978,7 +2978,12 @@ def unfold_done(done, N):
             data = data.unfold(tensordict.ndim - 1, self.N, 1)
 
             # Place -1 dim at self.dim place before squashing
-            done_mask_expand = expand_as_right(done_mask, data)
+            done_mask_expand = done_mask.view(
+                *done_mask.shape[: tensordict.ndim],
+                *(1,) * (data.ndim - 1 - tensordict.ndim),
+                done_mask.shape[-1],
+            )
+            done_mask_expand = expand_as_right(done_mask_expand, data)
             data = data.permute(
                 *range(0, data.ndim + self.dim - 1),
                 -1,
@@ -2994,11 +2999,13 @@ def unfold_done(done, N):
             else:
                 # TODO: This is a pretty bad implementation, could be
                 # made more efficient but it works!
-                reset_vals = list(data_orig[reset.squeeze(-1)].unbind(0))
+                reset_any = reset.any(-1, False)
+                reset_vals = list(data_orig[reset_any].unbind(0))
                 j_ = float("inf")
                 reps = []
                 d = data.ndim + self.dim - 1
-                for j in done_mask_expand.sum(d).sum(d).view(-1) // n_feat:
+                n_feat = data.shape[data.ndim + self.dim :].numel()
+                for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
                     if j > j_:
                         reset_vals = reset_vals[1:]
                     reps.extend([reset_vals[0]] * int(j))
@@ -3008,8 +3015,10 @@
 
             if first_val is not None:
                 # Aggregate reset along last dim
-                reset = reset.any(-1, True)
-                rexp = reset.expand(*reset.shape[:-1], n_feat)
+                reset_any = reset.any(-1, False)
+                rexp = expand_right(
+                    reset_any, (*reset_any.shape, *data.shape[data.ndim + self.dim :])
+                )
                 rexp = torch.cat(
                     [
                         torch.zeros_like(
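The core of the fix: once the data has been unfolded, pixel observations carry extra feature dims (C, H, W) between the tensordict batch/time dims and the new stacking dim, so the done mask can no longer be expanded with a single expand_as_right call as it was for flat observations. A toy sketch of the reshaping the new code performs, written in plain torch instead of tensordict's expand_right/expand_as_right, with made-up sizes:

import torch

B, T, C, H, W, N = 2, 7, 3, 4, 4, 5
data = torch.randn(B, T, C, H, W, N)                # unfolded pixel frames
done_mask = torch.zeros(B, T, N, dtype=torch.bool)  # done flags per unfolded window

# Insert singleton dims standing in for (C, H, W) between the (B, T) dims and
# the unfold dim N, then broadcast the mask to the data shape.
td_ndim = 2  # number of tensordict batch dims (B, T)
done_mask_expand = done_mask.view(B, T, *(1,) * (data.ndim - 1 - td_ndim), N)
done_mask_expand = done_mask_expand.expand_as(data)
assert done_mask_expand.shape == data.shape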
