Skip to content

Commit

Permalink
[BugFix] Gym async with _reset full of True (pytorch#2145)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 1, 2024
1 parent 69a6cb1 commit 7109a3f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 9 deletions.
109 changes: 109 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import gc
import importlib
import os
from contextlib import nullcontext
Expand Down Expand Up @@ -1094,6 +1096,113 @@ def test_vecenvs_nan(self): # noqa: F811
del c
return

def _get_dummy_gym_env(self, backend, **kwargs):
    """Build a minimal custom gym/gymnasium env for vector-env reset tests.

    The env is a step counter: the observation starts at zeros and every
    step adds 1 to each entry. When ``use_termination`` is set, the episode
    terminates once the counter reaches ``max_steps``, which lets tests
    build batches of envs that terminate at different times.

    Args:
        backend: gym backend name to activate ("gym" or "gymnasium").
        **kwargs: forwarded to ``CustomEnv`` (``dim``, ``use_termination``,
            ``max_steps``).

    Returns:
        A ``CustomEnv`` instance built under the requested backend.
    """
    with set_gym_backend(backend):

        class CustomEnv(gym_backend().Env):
            def __init__(self, dim=3, use_termination=True, max_steps=4):
                self.dim = dim
                self.use_termination = use_termination
                self.observation_space = gym_backend("spaces").Box(
                    low=-np.inf, high=np.inf, shape=(self.dim,)
                )
                self.action_space = gym_backend("spaces").Box(
                    low=-np.inf, high=np.inf, shape=(1,)
                )
                self.max_steps = max_steps

            def _get_info(self):
                # Arbitrary derived entry, used to check info propagation.
                return {"field1": self.state**2}

            def _get_obs(self):
                # Copy so callers cannot mutate the internal counter.
                return self.state.copy()

            def reset(self, seed=0, options=None):
                self.state = np.zeros(self.observation_space.shape)
                observation = self._get_obs()
                info = self._get_info()
                # A (partially) reset vector env must never hand back a
                # terminal (>= max_steps) observation through reset.
                assert (observation < self.max_steps).all()
                return observation, info

            def step(self, action):
                # The action is ignored: the observation is a step counter.
                self.state += 1
                truncated, terminated = False, False
                if self.use_termination:
                    # Fix: terminate at the configured horizon rather than a
                    # hard-coded 4, so envs built with different ``max_steps``
                    # really terminate at different times (otherwise the
                    # heterogeneous parametrization is vacuous). Default
                    # max_steps=4 keeps the homogeneous behavior unchanged.
                    terminated = self.state[0] == self.max_steps
                reward = 1 if terminated else 0  # Binary sparse rewards
                observation = self._get_obs()
                info = self._get_info()
                return observation, reward, terminated, truncated, info

        return CustomEnv(**kwargs)

@pytest.mark.parametrize("heterogeneous", [False, True])
def test_resetting_strategies(self, heterogeneous):
    """Check reset strategies agree for an async (vectorized) gym env.

    Builds a 4-env ``AsyncVectorEnv`` wrapped in ``GymWrapper`` and
    verifies that three ways of starting a rollout produce identical
    trajectories under the same seeds:

    - ``reset`` with a ``_reset`` mask full of ``True`` (the case fixed
      by this commit),
    - a plain ``reset()`` call,
    - rollout's implicit auto-reset.

    With ``heterogeneous=True`` the sub-envs are built with different
    ``max_steps`` so dones are staggered across the batch.
    """
    if _has_gymnasium:
        backend = "gymnasium"
    else:
        backend = "gym"
    with set_gym_backend(backend):
        # Older gym releases fail for unrelated reasons — skip them.
        if version.parse(gym_backend().__version__) < version.parse("0.26"):
            torchrl_logger.info(
                "Running into unrelated errors with older versions of gym."
            )
            return
        steps = 5
        if not heterogeneous:
            # Four identical envs, all terminating after 4 steps.
            env = GymWrapper(
                gym_backend().vector.AsyncVectorEnv(
                    [functools.partial(self._get_dummy_gym_env, backend=backend)]
                    * 4
                )
            )
        else:
            # Envs with max_steps 4..7 so terminations are staggered.
            env = GymWrapper(
                gym_backend().vector.AsyncVectorEnv(
                    [
                        functools.partial(
                            self._get_dummy_gym_env,
                            max_steps=i + 4,
                            backend=backend,
                        )
                        for i in range(4)
                    ]
                )
            )
        try:
            check_env_specs(env)
            td = env.rollout(steps, break_when_any_done=False)
            if not heterogeneous:
                # The terminal observation (value 4) must only appear under
                # "next": the root obs is re-observed after auto-reset.
                assert not (td["observation"] == 4).any()
                # dim=3 entries per terminal obs x 4 envs, each terminating
                # once within the 5-step rollout -> 3 * 4 entries equal 4.
                assert (td["next", "observation"] == 4).sum() == 3 * 4

            # check with manual reset
            # r0: reset driven by a _reset mask that is all True.
            # NOTE: each rollout is preceded by the same seed pair so the
            # three trajectories are comparable — keep the call order intact.
            torch.manual_seed(0)
            env.set_seed(0)
            reset = env.reset(
                TensorDict({"_reset": torch.ones(4, 1, dtype=torch.bool)}, [4])
            )
            r0 = env.rollout(
                10, break_when_any_done=False, auto_reset=False, tensordict=reset
            )
            # r1: plain reset() with no mask.
            torch.manual_seed(0)
            env.set_seed(0)
            reset = env.reset()
            r1 = env.rollout(
                10, break_when_any_done=False, auto_reset=False, tensordict=reset
            )
            # r2: rollout's own auto-reset.
            torch.manual_seed(0)
            env.set_seed(0)
            r2 = env.rollout(10, break_when_any_done=False)
            assert_allclose_td(r0, r1)
            assert_allclose_td(r1, r2)
        finally:
            # Always release the async worker processes, even on failure.
            if not env.is_closed:
                env.close()
            del env
            gc.collect()


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
22 changes: 15 additions & 7 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,10 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
f"Calling env.seed from now on."
)
self._seed_calls_reset = False
self._env.seed(seed=seed)
try:
self._env.seed(seed=seed)
except AttributeError as err2:
raise err from err2

@implement_for("gymnasium")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
Expand Down Expand Up @@ -1154,14 +1157,19 @@ def _reset(
if self._is_batched:
# batched (aka 'vectorized') env reset is a bit special: envs are
# automatically reset. What we do here is just to check if _reset
# is present. If it is not, we just reset. Otherwise we just skip.
# is present. If it is not, we just reset. Otherwise, we just skip.
if tensordict is None:
return super()._reset(tensordict)
return super()._reset(tensordict, **kwargs)
reset = tensordict.get("_reset", None)
if reset is None:
return super()._reset(tensordict)
elif reset is not None:
return tensordict.exclude("_reset")
if reset is not None:
# we must copy the tensordict because the transform
# expects a tuple (tensordict, tensordict_reset) where the
# first still carries a _reset
tensordict = tensordict.exclude("_reset")
if reset is None or reset.all():
return super()._reset(tensordict, **kwargs)
else:
return tensordict
return super()._reset(tensordict, **kwargs)


Expand Down
17 changes: 15 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6988,16 +6988,29 @@ def _reset(
if (
reset is not done
and (reset != done).any()
and (not reset.all() or not reset.any())
# it can happen that all are reset, in which case
# it's fine (doesn't need to match done)
and not reset.all()
):
raise RuntimeError(
"Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. "
"Cannot partially reset a gym(nasium) async env with a "
"reset mask that does not match the done mask. "
f"Got reset={reset}\nand done={done}"
)
# if not reset.any(), we don't need to do anything.
# if reset.all(), we don't either (bc GymWrapper will call a plain reset).
if reset is not None and reset.any():
saved_next = self._memo["saved_next"]
if saved_next is None:
if reset.all():
# We're fine: this means that a full reset was passed and the
# env was manually reset
tensordict_reset.pop(self.final_name, None)
return tensordict_reset
raise RuntimeError(
"Did not find a saved tensordict while the reset mask was "
f"not empty: reset={reset}. Done was {done}."
)
# reset = reset.view(tensordict.shape)
# we have a data container from the previous call to step
# that contains part of the observation we need.
Expand Down

0 comments on commit 7109a3f

Please sign in to comment.