Skip to content

Commit

Permalink
[BugFix] Make DMControlEnv aware of truncated signals (pytorch#2196)
Browse files (browse the repository at this point in the history)
  • Loading branch information
vmoens authored Jun 3, 2024
1 parent 8d99026 commit 2370d6e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
23 changes: 18 additions & 5 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,12 +1307,12 @@ def _make_gym_environment(env_name): # noqa: F811


@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found")
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only", [[True, True], [True, False], [False, False]]
)
class TestDMControl:
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only", [[True, True], [True, False], [False, False]]
)
def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only):
if from_pixels and (not torch.has_cuda or not torch.cuda.device_count()):
raise pytest.skip("no cuda device")
Expand Down Expand Up @@ -1384,6 +1384,11 @@ def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only):
assert final_seed0 == final_seed2
assert_allclose_td(rollout0, rollout2)

@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only", [[True, True], [True, False], [False, False]]
)
def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
if from_pixels and not torch.cuda.device_count():
raise pytest.skip("no cuda device")
Expand All @@ -1397,6 +1402,14 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
)
check_env_specs(env)

def test_truncated(self):
    """Check that a DMControlEnv rollout ends with a truncation signal.

    A rollout of up to 1001 steps is requested on the ``walker``/``walk``
    task. The test expects it to stop after 1000 steps, with the last
    ``"next"`` entry flagged as ``truncated`` and ``done`` but not as
    ``terminated``.
    """
    walker = DMControlEnv("walker", "walk")
    rollout = walker.rollout(1001)
    # One step fewer than requested: the rollout is expected to stop early.
    assert rollout.shape == (1000,)
    final_step = rollout[-1]
    assert final_step["next", "truncated"]
    assert final_step["next", "done"]
    assert not final_step["next", "terminated"]


params = []
if _has_dmc:
Expand Down
10 changes: 9 additions & 1 deletion torchrl/envs/libs/dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import torch
from dm_env import StepType

from torchrl._utils import logger as torchrl_logger, VERBOSE

Expand Down Expand Up @@ -321,7 +322,14 @@ def _output_transform(
timestep_tuple = (timestep_tuple,)
reward = timestep_tuple[0].reward

done = truncated = terminated = False # dm_control envs are non-terminating
truncated = terminated = False
if timestep_tuple[0].step_type == StepType.LAST:
if np.isclose(timestep_tuple[0].discount, 1):
truncated = True
else:
terminated = True
done = truncated or terminated

observation = timestep_tuple[0].observation
info = {}

Expand Down

0 comments on commit 2370d6e

Please sign in to comment.