Skip to content

Commit

Permalink
[BugFix, Test] Fix flaky gym vecenvs tests (pytorch#1727)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 4, 2023
1 parent d432a9c commit 7166f3c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 19 deletions.
15 changes: 10 additions & 5 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class MyClass:


def rollout_consistency_assertion(
rollout, *, done_key="done", observation_key="observation"
rollout, *, done_key="done", observation_key="observation", done_strict=False
):
"""Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""

Expand All @@ -335,19 +335,24 @@ def rollout_consistency_assertion(
# data resulting from step, when it's not done, after step_mdp
r_not_done_tp1 = rollout[:, 1:][~done]
torch.testing.assert_close(
r_not_done[observation_key], r_not_done_tp1[observation_key]
r_not_done[observation_key],
r_not_done_tp1[observation_key],
msg=f"Key {observation_key} did not match",
)

if not done.any():
return
if done_strict and not done.any():
raise RuntimeError("No done detected, test could not complete.")

# data resulting from step, when it's done
r_done = rollout[:, :-1]["next"][done]
# data resulting from step, when it's done, after step_mdp and reset
r_done_tp1 = rollout[:, 1:][done]
assert (
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1)
).all(), (
f"Entries in next tensordict do not match entries in root "
f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}"
)


def rand_reset(env):
Expand Down
27 changes: 19 additions & 8 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,21 +400,23 @@ def test_vecenvs_wrapper(self, envname):
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
+ (["FetchReach-v2"] if _has_gym_robotics else []),
)
@pytest.mark.flaky(reruns=8, reruns_delay=1)
def test_vecenvs_env(self, envname):
from _utils_internal import rollout_consistency_assertion

with set_gym_backend("gymnasium"):
env = GymEnv(envname, num_envs=2, from_pixels=False)

env.set_seed(0)
assert env.get_library_name(env._env) == "gymnasium"
# rollouts can be executed without decorator
check_env_specs(env)
rollout = env.rollout(100, break_when_any_done=False)
for obs_key in env.observation_spec.keys(True, True):
rollout_consistency_assertion(
rollout, done_key="done", observation_key=obs_key
rollout,
done_key="done",
observation_key=obs_key,
done_strict="CartPole" in envname,
)
env.close()
del env

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
Expand All @@ -441,30 +443,39 @@ def test_vecenvs_wrapper(self, envname): # noqa: F811
)
assert env.batch_size == torch.Size([2])
check_env_specs(env)
env.close()
del env

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
"envname",
["CartPole-v1", "HalfCheetah-v4"],
)
@pytest.mark.flaky(reruns=3, reruns_delay=1)
def test_vecenvs_env(self, envname): # noqa: F811
with set_gym_backend("gym"):
env = GymEnv(envname, num_envs=2, from_pixels=False)

env.set_seed(0)
assert env.get_library_name(env._env) == "gym"
# rollouts can be executed without decorator
check_env_specs(env)
rollout = env.rollout(100, break_when_any_done=False)
for obs_key in env.observation_spec.keys(True, True):
rollout_consistency_assertion(
rollout, done_key="done", observation_key=obs_key
rollout,
done_key="done",
observation_key=obs_key,
done_strict="CartPole" in envname,
)
env.close()
del env
if envname != "CartPole-v1":
with set_gym_backend("gym"):
env = GymEnv(envname, num_envs=2, from_pixels=True)
env.set_seed(0)
# rollouts can be executed without decorator
check_env_specs(env)
env.close()
del env

@implement_for("gym", None, "0.18")
@pytest.mark.parametrize(
Expand Down
13 changes: 7 additions & 6 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ def _read_obs(self, obs, key, tensor, index):
# Simplest case: there is one observation,
# presented as a np.ndarray. The key should be pixels or observation.
# We just write that value at its location in the tensor
tensor[index] = torch.as_tensor(obs, device=tensor.device)
tensor[index] = torch.tensor(obs, device=tensor.device)
elif isinstance(obs, dict):
if key not in obs:
raise KeyError(
Expand All @@ -1171,13 +1171,13 @@ def _read_obs(self, obs, key, tensor, index):
# if the obs is a dict, we expect that the key points also to
# a value in the obs. We retrieve this value and write it in the
# tensor
tensor[index] = torch.as_tensor(subobs, device=tensor.device)
tensor[index] = torch.tensor(subobs, device=tensor.device)

elif isinstance(obs, (list, tuple)):
# tuples are stacked along the first dimension when passing gym spaces
# to torchrl specs. As such, we can simply stack the tuple and set it
# at the relevant index (assuming stacking can be achieved)
tensor[index] = torch.as_tensor(obs, device=tensor.device)
tensor[index] = torch.tensor(obs, device=tensor.device)
else:
raise NotImplementedError(
f"Observations of type {type(obs)} are not supported yet."
Expand All @@ -1186,11 +1186,12 @@ def _read_obs(self, obs, key, tensor, index):
def __call__(self, info_dict, tensordict):
terminal_obs = info_dict.get(self.backend_key[self.backend], None)
for key, item in self.info_spec.items(True, True):
final_obs = item.zero()
final_obs_buffer = item.zero()
if terminal_obs is not None:
for i, obs in enumerate(terminal_obs):
self._read_obs(obs, key[-1], final_obs, index=i)
tensordict.set(key, final_obs)
# writes final_obs inplace with terminal_obs content
self._read_obs(obs, key[-1], final_obs_buffer, index=i)
tensordict.set(key, final_obs_buffer)
return tensordict


Expand Down

0 comments on commit 7166f3c

Please sign in to comment.