diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8182fbb3aa4..c5e7bb6ea45 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,13 +11,43 @@ # this returns relative path from current file. import pytest import torch.cuda -from torchrl._utils import seed_generator +from torchrl._utils import seed_generator, implement_for from torchrl.envs import EnvBase - +from torchrl.envs.libs.gym import _has_gym # Specified for test_utils.py __version__ = "0.3" +# Default versions of the environments. +CARTPOLE_VERSIONED = "CartPole-v1" +HALFCHEETAH_VERSIONED = "HalfCheetah-v4" +PENDULUM_VERSIONED = "Pendulum-v1" +PONG_VERSIONED = "ALE/Pong-v5" + + +@implement_for("gym", None, "0.21.0") +def _set_gym_environments(): # noqa: F811 + global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + + CARTPOLE_VERSIONED = "CartPole-v0" + HALFCHEETAH_VERSIONED = "HalfCheetah-v2" + PENDULUM_VERSIONED = "Pendulum-v0" + PONG_VERSIONED = "Pong-v4" + + +@implement_for("gym", "0.21.0", None) +def _set_gym_environments(): # noqa: F811 + global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + + CARTPOLE_VERSIONED = "CartPole-v1" + HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + PENDULUM_VERSIONED = "Pendulum-v1" + PONG_VERSIONED = "ALE/Pong-v5" + + +if _has_gym: + _set_gym_environments() + def get_relative_path(curr_file, *path_components): return os.path.join(os.path.dirname(curr_file), *path_components) diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index 03caa3e8d39..56463039bf4 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -2,21 +2,10 @@ import tempfile import pytest +from _utils_internal import PONG_VERSIONED from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PONG_VERSIONED = ( - 
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) -else: - # placeholders - PONG_VERSIONED = "ALE/Pong-v5" - try: from torch.utils.tensorboard import SummaryWriter diff --git a/test/test_collector.py b/test/test_collector.py index 0176338950e..f7f94035e0e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -8,7 +8,7 @@ import numpy as np import pytest import torch -from _utils_internal import generate_seeds +from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( ContinuousActionVecMockEnv, DiscreteActionConvMockEnv, @@ -42,22 +42,6 @@ TensorDictModule, ) -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) -else: - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - PONG_VERSIONED = "ALE/Pong-v5" - # torch.set_default_dtype(torch.double) diff --git a/test/test_env.py b/test/test_env.py index aefeb4f36de..fa1607041ae 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -11,7 +11,13 @@ import pytest import torch import yaml -from _utils_internal import get_available_devices +from _utils_internal import ( + get_available_devices, + CARTPOLE_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + HALFCHEETAH_VERSIONED, +) from mocking_classes import ( ActionObsMergeLinear, DiscreteActionConvMockEnv, @@ -49,30 +55,11 @@ ) from torchrl.modules.tensordict_module import WorldModelWrapper +gym_version = None if _has_gym: import gym gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - CARTPOLE_VERSIONED = ( - "CartPole-v1" if gym_version > version.parse("0.20.0") else "CartPole-v0" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if 
gym_version > version.parse("0.20.0") else "Pong-v4" - ) - HALFCHEETAH_VERSIONED = ( - "HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2" - ) -else: - # placeholder - gym_version = version.parse("0.0.1") - - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - CARTPOLE_VERSIONED = "CartPole-v1" - PONG_VERSIONED = "ALE/Pong-v5" try: this_dir = os.path.dirname(os.path.realpath(__file__)) @@ -1048,7 +1035,7 @@ def test_batch_unlocked_with_batch_size(device): @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.skipif( - gym_version < version.parse("0.20.0"), + gym_version is None or gym_version < version.parse("0.20.0"), reason="older versions of half-cheetah do not have 'x_position' info key.", ) def test_info_dict_reader(seed=0): diff --git a/test/test_libs.py b/test/test_libs.py index ae99b26dbf4..bb853f9642d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3,16 +3,27 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import argparse +from sys import platform import numpy as np import pytest import torch -from _utils_internal import _test_fake_tensordict -from _utils_internal import get_available_devices +from _utils_internal import ( + _test_fake_tensordict, + get_available_devices, + HALFCHEETAH_VERSIONED, + PONG_VERSIONED, + PENDULUM_VERSIONED, +) from packaging import version +from tensordict.tensordict import assert_allclose_td +from torchrl._utils import implement_for from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import RandomPolicy +from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper from torchrl.envs.libs.dm_control import _has_dmc +from torchrl.envs.libs.gym import GymEnv, GymWrapper from torchrl.envs.libs.gym import _has_gym, _is_from_pixels from torchrl.envs.libs.habitat import HabitatEnv, _has_habitat from torchrl.envs.libs.jumanji import JumanjiEnv, _has_jumanji @@ -32,39 +43,8 @@ from dm_control import suite from dm_control.suite.wrappers import pixels -from sys import platform - -from tensordict.tensordict import assert_allclose_td -from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper -from torchrl.envs.libs.gym import GymEnv, GymWrapper - IS_OSX = platform == "darwin" -if _has_gym: - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - HC_VERSIONED = ( - "HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) - - # if gym_version < version.parse("0.24.0") and torch.cuda.device_count() > 0: - # from opengl_rendering import create_opengl_context - # - # create_opengl_context() -else: - # placeholders - PENDULUM_VERSIONED = 
"Pendulum-v1" - HC_VERSIONED = "HalfCheetah-v4" - PONG_VERSIONED = "ALE/Pong-v5" - @pytest.mark.skipif(not _has_gym, reason="no gym library found") @pytest.mark.parametrize( @@ -123,10 +103,7 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): base_env = gym.make(env_name, frameskip=frame_skip) frame_skip = 1 else: - if gym_version < version.parse("0.26.0"): - base_env = gym.make(env_name) - else: - base_env = gym.make(env_name, render_mode="rgb_array") + base_env = _make_gym_environment(env_name) if from_pixels and not _is_from_pixels(base_env): base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only) @@ -164,6 +141,16 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): _test_fake_tensordict(env) +@implement_for("gym", None, "0.26") +def _make_gym_environment(env_name): # noqa: F811 + return gym.make(env_name) + + +@implement_for("gym", "0.26", None) +def _make_gym_environment(env_name): # noqa: F811 + return gym.make(env_name, render_mode="rgb_array") + + @pytest.mark.skipif(not _has_dmc, reason="no dm_control library found") @pytest.mark.parametrize("env_name,task", [["cheetah", "run"]]) @pytest.mark.parametrize("frame_skip", [1, 3]) @@ -270,9 +257,9 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): "env_lib,env_args,env_kwargs", [ [DMControlEnv, ("cheetah", "run"), {"from_pixels": True}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": True}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}], [DMControlEnv, ("cheetah", "run"), {"from_pixels": False}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": False}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}], [GymEnv, (PONG_VERSIONED,), {}], ], ) @@ -307,9 +294,9 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): "env_lib,env_args,env_kwargs", [ [DMControlEnv, ("cheetah", "run"), {"from_pixels": True}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": True}], + [GymEnv, (HALFCHEETAH_VERSIONED,), 
{"from_pixels": True}], [DMControlEnv, ("cheetah", "run"), {"from_pixels": False}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": False}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}], [GymEnv, (PONG_VERSIONED,), {}], ], ) diff --git a/test/test_transforms.py b/test/test_transforms.py index 4f268031d2f..d06ac84b528 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,7 +8,12 @@ import numpy as np import pytest import torch -from _utils_internal import get_available_devices, retry, dtype_fixture # noqa +from _utils_internal import ( # noqa + get_available_devices, + retry, + dtype_fixture, + PENDULUM_VERSIONED, +) from mocking_classes import ( ContinuousActionVecMockEnv, DiscreteActionConvMockEnvNumpy, @@ -59,18 +64,6 @@ ) from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) -else: - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - TIMEOUT = 10.0 diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 27db9edf1b7..cf284f4e7db 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -19,13 +19,13 @@ TensorSpec, UnboundedContinuousTensorSpec, ) +from ..._utils import implement_for from ...data.utils import numpy_to_torch_dtype_dict from ..gym_like import GymLikeEnv, default_info_dict_reader from ..utils import _classproperty try: import gym - from packaging import version _has_gym = True except ImportError: @@ -48,9 +48,6 @@ from torchrl.envs.libs.utils import ( GymPixelObservationWrapper as PixelObservationWrapper, ) - gym_version = version.parse(gym.__version__) - if gym_version >= version.parse("0.26.0"): - from gym.wrappers.compatibility import EnvCompatibility __all__ = ["GymWrapper", "GymEnv"] @@ -103,16 +100,22 @@ def _gym_to_torchrl_spec_transform( def 
_get_envs(to_dict=False) -> List: - if gym_version < version.parse("0.26.0"): - envs = gym.envs.registration.registry.env_specs.keys() - else: - envs = gym.envs.registration.registry.keys() - + envs = _get_gym_envs() envs = list(envs) envs = sorted(envs) return envs +@implement_for("gym", None, "0.26.0") +def _get_gym_envs(): # noqa: F811 + return gym.envs.registration.registry.env_specs.keys() + + +@implement_for("gym", "0.26.0", None) +def _get_gym_envs(): # noqa: F811 + return gym.envs.registration.registry.keys() + + def _get_gym(): if _has_gym: return gym @@ -186,20 +189,30 @@ def _build_env( "PixelObservationWrapper cannot be used to wrap an environment" "that is already a PixelObservationWrapper instance." ) - if gym_version >= version.parse("0.26.0") and not env.render_mode: - warnings.warn( - "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper " - "should be created with `gym.make(env_name, render_mode=mode)` where possible," - 'where mode is either "rgb_array" or any other supported mode.' 
- ) - # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper - env = EnvCompatibility(env) - env.reset() - env = LegacyPixelObservationWrapper(env, pixels_only=pixels_only) - else: - env = PixelObservationWrapper(env, pixels_only=pixels_only) + env = self._build_gym_env(env, pixels_only) return env + @implement_for("gym", None, "0.26.0") + def _build_gym_env(self, env, pixels_only): # noqa: F811 + return PixelObservationWrapper(env, pixels_only=pixels_only) + + @implement_for("gym", "0.26.0", None) + def _build_gym_env(self, env, pixels_only): # noqa: F811 + from gym.wrappers.compatibility import EnvCompatibility + + if env.render_mode: + return PixelObservationWrapper(env, pixels_only=pixels_only) + + warnings.warn( + "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper " + "should be created with `gym.make(env_name, render_mode=mode)` where possible," + 'where mode is either "rgb_array" or any other supported mode.' + ) + # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper + env = EnvCompatibility(env) + env.reset() + return LegacyPixelObservationWrapper(env, pixels_only=pixels_only) + @_classproperty def available_envs(cls) -> List[str]: return _get_envs() @@ -208,29 +221,35 @@ def available_envs(cls) -> List[str]: def lib(self) -> ModuleType: return gym - def _set_seed(self, seed: int) -> int: - skip = False + def _set_seed(self, seed: int) -> int: # noqa: F811 if self._seed_calls_reset is None: - if gym_version < version.parse("0.19.0"): - self._seed_calls_reset = False - self._env.seed(seed=seed) - else: - try: - self.reset(seed=seed) - skip = True - self._seed_calls_reset = True - except TypeError as err: - warnings.warn( - f"reset with seed kwarg returned an exception: {err}.\n" - f"Calling env.seed from now on." - ) - self._seed_calls_reset = False - if self._seed_calls_reset and not skip: + # Determine, based on the gym version, whether `reset` is called when setting the seed. 
+ self._set_seed_initial(seed) + elif self._seed_calls_reset: self.reset(seed=seed) - elif not self._seed_calls_reset: + else: self._env.seed(seed=seed) + return seed + @implement_for("gym", None, "0.19.0") + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + self._seed_calls_reset = False + self._env.seed(seed=seed) + + @implement_for("gym", "0.19.0", None) + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + try: + self.reset(seed=seed) + self._seed_calls_reset = True + except TypeError as err: + warnings.warn( + f"reset with seed kwarg returned an exception: {err}.\n" + f"Calling env.seed from now on." + ) + self._seed_calls_reset = False + self._env.seed(seed=seed) + def _make_specs(self, env: "gym.Env") -> None: self.action_spec = _gym_to_torchrl_spec_transform( env.action_space, @@ -294,15 +313,25 @@ class GymEnv(GymWrapper): def __init__(self, env_name, disable_env_checker=None, **kwargs): kwargs["env_name"] = env_name - if gym_version >= version.parse("0.24.0"): - kwargs["disable_env_checker"] = ( - disable_env_checker if disable_env_checker is not None else True - ) - elif disable_env_checker is not None: + self._set_gym_args(kwargs, disable_env_checker) + super().__init__(**kwargs) + + @implement_for("gym", None, "0.24.0") + def _set_gym_args( # noqa: F811 + self, kwargs, disable_env_checker: bool = None + ) -> None: + if disable_env_checker is not None: raise RuntimeError( "disable_env_checker should only be set if gym version is > 0.24" ) - super().__init__(**kwargs) + + @implement_for("gym", "0.24.0", None) + def _set_gym_args( # noqa: F811 + self, kwargs, disable_env_checker: bool = None + ) -> None: + kwargs["disable_env_checker"] = ( + disable_env_checker if disable_env_checker is not None else True + ) def _build_env( self, @@ -316,8 +345,7 @@ def _build_env( f" {self.git_url}" ) from_pixels = kwargs.get("from_pixels", False) - if from_pixels and gym_version > version.parse("0.25.0"): - kwargs.setdefault("render_mode", 
"rgb_array") + self._set_gym_default(kwargs, from_pixels) if "from_pixels" in kwargs: del kwargs["from_pixels"] pixels_only = kwargs.get("pixels_only", True) @@ -350,6 +378,16 @@ def _build_env( raise err return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + @implement_for("gym", None, "0.25.1") + def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 + # Do nothing for older gym versions. + pass + + @implement_for("gym", "0.25.1", None) + def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 + if from_pixels: + kwargs.setdefault("render_mode", "rgb_array") + @property def env_name(self): return self._constructor_kwargs["env_name"]