Skip to content

Commit

Permalink
[BugFix] EnvBase._complete_done to complete "terminated" key properly (pytorch#2294)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Jul 18, 2024
1 parent f764c02 commit c4b2eb0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
56 changes: 56 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,62 @@ def test_rollout_predictability(device):
).all()


# Check that the "terminated" key is filled in automatically if only the "done"
# key is provided in `_step`.
def test_done_key_completion_done():
    """Verify that ``step`` adds "terminated" when ``_step`` only emits "done"."""

    class DoneEnv(CountingEnv):
        """Counting env whose ``_step`` output has no "terminated" entry."""

        def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
            self.count += 1
            step_data = {
                "observation": self.count.clone(),
                "done": self.count > self.max_steps,
                "reward": torch.zeros_like(self.count, dtype=torch.float),
            }
            return TensorDict(
                source=step_data,
                batch_size=self.batch_size,
                device=self.device,
            )

    env = DoneEnv(max_steps=torch.tensor([[0], [1]]), batch_size=(2,))
    td = env.reset()
    env.rand_action(td)
    td = env.step(td)

    # Both flags are expected to carry the value that `_step` wrote to "done".
    expected = torch.tensor([[True], [False]])
    for flag in ("done", "terminated"):
        assert torch.equal(td[("next", flag)], expected)


# Check that the "done" key is filled in automatically if only the "terminated"
# key is provided in `_step`.
def test_done_key_completion_terminated():
    """Verify that ``step`` adds "done" when ``_step`` only emits "terminated"."""

    class TerminatedEnv(CountingEnv):
        """Counting env whose ``_step`` output has no "done" entry."""

        def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
            self.count += 1
            observation = self.count.clone()
            terminated = self.count > self.max_steps
            reward = torch.zeros_like(self.count, dtype=torch.float)
            return TensorDict(
                source={
                    "observation": observation,
                    "terminated": terminated,
                    "reward": reward,
                },
                batch_size=self.batch_size,
                device=self.device,
            )

    max_steps = torch.tensor([[0], [1]])
    env = TerminatedEnv(max_steps=max_steps, batch_size=(2,))
    td = env.reset()
    env.rand_action(td)
    td = env.step(td)

    # Both flags are expected to carry the value `_step` wrote to "terminated".
    expected = torch.tensor([[True], [False]])
    next_done = td[("next", "done")]
    next_terminated = td[("next", "terminated")]
    assert torch.equal(next_done, expected)
    assert torch.equal(next_terminated, expected)


@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
@pytest.mark.parametrize("frame_skip", [1])
Expand Down
13 changes: 7 additions & 6 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,8 @@ def _complete_done(
shape = (*leading_dim, *item.shape)
if val is not None:
if val.shape != shape:
data.set(key, val.reshape(shape))
val = val.reshape(shape)
data.set(key, val)
vals[key] = val

if len(vals) < i + 1:
Expand All @@ -1535,18 +1536,18 @@ def _complete_done(
"Cannot infer the value of terminated when only done and truncated are present."
)
data.set("terminated", val)
data_keys.add("terminated")
elif (
key == "terminated"
and val is not None
and "done" in done_spec_keys
and "done" not in data_keys
):
if "truncated" in data_keys:
done = val | data.get("truncated")
data.set("done", done)
else:
data.set("done", val)
elif val is None:
val = val | data.get("truncated")
data.set("done", val)
data_keys.add("done")
elif val is None and key not in data_keys:
# we must keep this here: we only want to fill with 0s if we're sure
# done should not be copied to terminated or terminated to done
# in this case, just fill with 0s
Expand Down

0 comments on commit c4b2eb0

Please sign in to comment.