Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Gym 'vectorized' envs compatibility #1519

Merged
merged 20 commits into from
Sep 17, 2023
1 change: 1 addition & 0 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ do
pip install gymnasium[atari]
fi
pip install mo-gymnasium
pip install gymnasium-robotics

$DIR/run_test.sh

Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ to be able to create this other composition:
TimeMaxPool
ToTensorImage
UnsqueezeTransform
VecGymEnvTransform
VecNorm
VC1Transform
VIPRewardTransform
Expand Down
26 changes: 26 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,29 @@ class MyClass:
for key in td.keys():
MyClass.__annotations__[key] = torch.Tensor
return tensorclass(MyClass)


def rollout_consistency_assertion(
    rollout, *, done_key="done", observation_key="observation"
):
    """Assert that consecutive steps of a rollout chain their observations.

    For each transition whose ``("next", done_key)`` flag is False, the
    observation stored under ``"next"`` must be close to the root observation
    of the following step. For each transition whose flag is True, the two
    observations must differ: the norm of their difference along the last
    dimension has to exceed ``1e-1``.

    Args:
        rollout: rollout data indexed as ``rollout[:, t]``; the second
            dimension is treated as the step dimension.
        done_key: key of the done flag read from the ``"next"`` entry.
        observation_key: key of the observation entry to compare.

    Raises:
        AssertionError: if either of the two checks above fails.
    """
    current = rollout[:, :-1]
    following = rollout[:, 1:]
    done = current["next", done_key].squeeze(-1)

    # not done: the output of the step is carried over to the next root
    carried_over = current["next"][~done]
    next_root = following[~done]
    torch.testing.assert_close(
        carried_over[observation_key], next_root[observation_key]
    )

    if not done.any():
        return

    # done: the next root should hold a different observation than the one
    # produced by the step
    final_obs = current["next"][done][observation_key]
    restart_obs = following[done][observation_key]
    gap = (final_obs - restart_obs).norm(dim=-1)
    assert (gap > 1e-1).all(), gap
112 changes: 111 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rollout_consistency_assertion,
)
from packaging import version
from tensordict import LazyStackedTensorDict
Expand Down Expand Up @@ -67,12 +68,14 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
set_gym_backend,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
Expand All @@ -83,7 +86,7 @@

_has_sklearn = importlib.util.find_spec("sklearn") is not None

from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
_has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None

if _has_gym:
try:
Expand Down Expand Up @@ -323,6 +326,113 @@ def test_one_hot_and_categorical(self): # noqa: F811
# versions.
return

@implement_for("gymnasium", "0.27.0", None)
# this env has Dict-based observation which is a nice thing to test
# NOTE(review): presumably the comment above refers to FetchReach-v2 - confirm
@pytest.mark.parametrize(
    "envname",
    ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
    + (["FetchReach-v2"] if _has_gym_robotics else []),
)
def test_vecenvs_wrapper(self, envname):
    """Wrap gymnasium Sync/Async vector envs in GymWrapper and check them.

    Each vector env holds two copies of ``envname``; the wrapper is expected
    to expose a batch size of 2 and to pass ``check_env_specs``.
    """
    import gymnasium

    # ``envname`` is bound as a default argument so each factory keeps its own
    # value regardless of later rebinding.
    env_fns = 2 * [lambda envname=envname: gymnasium.make(envname)]
    vector_classes = (
        gymnasium.vector.SyncVectorEnv,
        gymnasium.vector.AsyncVectorEnv,
    )
    for vector_cls in vector_classes:
        env = GymWrapper(vector_cls(env_fns))
        assert env.batch_size == torch.Size([2])
        check_env_specs(env)

@implement_for("gymnasium", "0.27.0", None)
# this env has Dict-based observation which is a nice thing to test
# NOTE(review): presumably the comment above refers to FetchReach-v2 - confirm
@pytest.mark.parametrize(
    "envname",
    ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
    + (["FetchReach-v2"] if _has_gym_robotics else []),
)
def test_vecenvs_env(self, envname):
    """Build a 2-env ``GymEnv`` on the gymnasium backend and check rollouts.

    The env specs are validated with ``check_env_specs``; a 100-step rollout
    that does not stop on done is then checked with
    ``rollout_consistency_assertion`` for every leaf observation key.
    """
    # ``rollout_consistency_assertion`` is imported at module level from
    # ``_utils_internal``; no function-local import is needed.
    with set_gym_backend("gymnasium"):
        env = GymEnv(envname, num_envs=2, from_pixels=False)
        check_env_specs(env)
        rollout = env.rollout(100, break_when_any_done=False)
        for obs_key in env.observation_spec.keys(True, True):
            rollout_consistency_assertion(
                rollout, done_key="done", observation_key=obs_key
            )

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
    "envname",
    ["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_wrapper(self, envname):  # noqa: F811
    """Wrap gym Sync/Async vector envs in GymWrapper and check them.

    Each vector env holds two copies of the parametrized ``envname``; the
    wrapper is expected to expose a batch size of 2 and to pass
    ``check_env_specs``.
    """
    import gym

    # The env name comes from the parametrization only: looping over the env
    # names again here would shadow ``envname`` and make every parametrized
    # case run all environments.
    env = GymWrapper(
        gym.vector.SyncVectorEnv(2 * [lambda envname=envname: gym.make(envname)])
    )
    assert env.batch_size == torch.Size([2])
    check_env_specs(env)
    env = GymWrapper(
        gym.vector.AsyncVectorEnv(2 * [lambda envname=envname: gym.make(envname)])
    )
    assert env.batch_size == torch.Size([2])
    check_env_specs(env)

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
    "envname",
    ["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_env(self, envname):  # noqa: F811
    """Build 2-env ``GymEnv`` instances on the gym backend and check them.

    The state-based env has its specs validated and a 100-step rollout
    checked with ``rollout_consistency_assertion`` for every leaf observation
    key. For every env other than ``CartPole-v1``, a pixel-based env is built
    as well and its specs are validated.
    """
    with set_gym_backend("gym"):
        state_env = GymEnv(envname, num_envs=2, from_pixels=False)
        check_env_specs(state_env)
        rollout = state_env.rollout(100, break_when_any_done=False)
        for obs_key in state_env.observation_spec.keys(True, True):
            rollout_consistency_assertion(
                rollout, done_key="done", observation_key=obs_key
            )

    # the pixel-based variant is not exercised for CartPole
    if envname == "CartPole-v1":
        return
    with set_gym_backend("gym"):
        pixel_env = GymEnv(envname, num_envs=2, from_pixels=True)
        check_env_specs(pixel_env)

@implement_for("gym", None, "0.18")
@pytest.mark.parametrize(
    "envname",
    ["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_wrapper(self, envname):  # noqa: F811
    """Placeholder for gym versions below 0.18, where this test is not run."""
    # Report an explicit skip: an empty body would show up as a passing test
    # and hide the fact that nothing was exercised.
    pytest.skip("skipping tests for older versions of gym")

@implement_for("gym", None, "0.18")
@pytest.mark.parametrize(
    "envname",
    ["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_env(self, envname):  # noqa: F811
    """Placeholder for gym versions below 0.18, where this test is not run."""
    # Report an explicit skip: an empty body would show up as a passing test
    # and hide the fact that nothing was exercised.
    pytest.skip("skipping tests for older versions of gym")


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
4 changes: 2 additions & 2 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def module_set(self):
cls = inspect.getmodule(self.fn)
setattr(cls, self.fn.__name__, self.fn)

@staticmethod
def import_module(module_name: Union[Callable, str]) -> str:
@classmethod
def import_module(cls, module_name: Union[Callable, str]) -> str:
"""Imports module and returns its version."""
if not callable(module_name):
module = import_module(module_name)
Expand Down
15 changes: 12 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def action_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)

if isinstance(value, CompositeSpec):
Expand Down Expand Up @@ -791,7 +791,7 @@ def reward_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, CompositeSpec):
for _ in value.values(True, True): # noqa: B007
Expand Down Expand Up @@ -820,6 +820,15 @@ def reward_spec(self, value: TensorSpec) -> None:

# done spec
def _get_done_keys(self):
if "full_done_spec" not in self.output_spec.keys():
# populate the "done" entry
# this will be raised if there is not full_done_spec (unlikely) or no done_key
# Since output_spec is lazily populated with an empty composite spec for
# done_spec, the second case is much more likely to occur.
self.done_spec = DiscreteTensorSpec(
n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
)

keys = self.output_spec["full_done_spec"].keys(True, True)
if not len(keys):
raise AttributeError("Could not find done spec")
Expand Down Expand Up @@ -967,7 +976,7 @@ def done_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, CompositeSpec):
for _ in value.values(True, True): # noqa: B007
Expand Down
39 changes: 27 additions & 12 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ class GymLikeEnv(_EnvWrapper):
It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
"""

_info_dict_reader: BaseInfoDictReader
_info_dict_reader: List[BaseInfoDictReader]

@classmethod
def __new__(cls, *args, **kwargs):
cls._info_dict_reader = None
cls._info_dict_reader = []
return super().__new__(cls, *args, _batch_locked=True, **kwargs)

def read_action(self, action):
Expand All @@ -144,7 +144,7 @@ def read_done(self, done):
done (np.ndarray, boolean or other format): done state obtained from the environment

"""
return done, done
return done, done.any() if not isinstance(done, bool) else done
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't particularly like it
One aspect to consider is when we're doing our skip_frame but one of the envs is done and not the others. Then we stop the loop. This means that a bunch of envs silently don't complete the loop. Not sure what to do because we can't really keep on applying the same action to the other env if it's in auto-reset.

One solution is that if a non-complete done is encountered we raise a warning telling users to use the skip-frame transform instead.

Copy link
Contributor

@albertbou92 albertbou92 Sep 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the situation you described is too common and silently not completing the loop is dangerous.

A bit hacky, but can't we create some kind of NO-OP step wrapper (e.g. if action is -1, just returns the last obs, reward etc, otherwise normal step)? and overwrite the actions of the envs that have reached the end of the episode before skip_frame actions are taken.

You do the loop skip_frame times for all envs but for early finishing envs a useless step will be called at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about disabling frame_skip for num_envs > 0, and just say that one needs to use the transform instead?

Copy link
Contributor

@albertbou92 albertbou92 Sep 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I always thought frame_skip as a transform was more consistent with the rest of the library. I thought there was some problem that made it less efficient or something, and that was why the option was also given in the env. But yes, if that is not the case, I think using the transform is the best way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vmoens quick heads up here. I think the frame_skip transform did not handle truncation/termination correctly the last time I checked. Also, some envs natively handle frame_skip and this runs much faster as they don't have to render the frames, e.g. ALE-v5 which skips the frames internally with the frameskip keyword in the constructor.


def read_reward(self, reward):
"""Reads the reward and maps it to the reward space.
Expand Down Expand Up @@ -231,8 +231,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size)

if self.info_dict_reader is not None and info is not None:
self.info_dict_reader(info, tensordict_out)
if self.info_dict_reader and info is not None:
for info_dict_reader in self.info_dict_reader:
out = info_dict_reader(info, tensordict_out)
if out is not None:
tensordict_out = out
tensordict_out = tensordict_out.to(self.device, non_blocking=True)
return tensordict_out

Expand All @@ -255,9 +258,12 @@ def _reset(
source=source,
batch_size=self.batch_size,
)
if self.info_dict_reader is not None and info is not None:
self.info_dict_reader(info, tensordict_out)
elif info is None and self.info_dict_reader is not None:
if self.info_dict_reader and info is not None:
for info_dict_reader in self.info_dict_reader:
out = info_dict_reader(info, tensordict_out)
if out is not None:
tensordict_out = out
elif info is None and self.info_dict_reader:
# populate the reset with the items we have not seen from info
for key, item in self.observation_spec.items(True, True):
if key not in tensordict_out.keys(True, True):
Expand Down Expand Up @@ -298,9 +304,12 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE
>>> assert "my_info_key" in tensordict.keys()

"""
self.info_dict_reader = info_dict_reader
for info_key, spec in info_dict_reader.info_spec.items():
self.observation_spec[info_key] = spec.to(self.device)
self.info_dict_reader.append(info_dict_reader)
if isinstance(info_dict_reader, BaseInfoDictReader):
# if we have a BaseInfoDictReader, we know what the specs will be
# In other cases (eg, RoboHive) we will need to figure it out empirically.
for info_key, spec in info_dict_reader.info_spec.items():
self.observation_spec[info_key] = spec.to(self.device)
return self

def __repr__(self) -> str:
Expand All @@ -314,4 +323,10 @@ def info_dict_reader(self):

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
self._info_dict_reader = value
warnings.warn(
f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
f"This method will append a reader to the list of existing readers (if any). "
f"Setting info_dict_reader directly will be soon deprecated.",
category=DeprecationWarning,
)
self._info_dict_reader.append(value)
Loading