Skip to content

Commit

Permalink
[Refactoring] Replace direct gym version checks with decorated functi…
Browse files Browse the repository at this point in the history
…ons (pytorch#691)

* [Refactoring] Replace gym version checking with decorated functions (#)

Initial commit. Only tests.

* Refactoring in gym.py

* More refactoring in gym.py

* Completed refactoring

* amend

* amend
  • Loading branch information
ordinskiy authored Nov 21, 2022
1 parent a3bbba0 commit b583ac1
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 154 deletions.
34 changes: 32 additions & 2 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,43 @@
# this returns relative path from current file.
import pytest
import torch.cuda
from torchrl._utils import seed_generator
from torchrl._utils import seed_generator, implement_for
from torchrl.envs import EnvBase

from torchrl.envs.libs.gym import _has_gym

# Specified for test_utils.py
__version__ = "0.3"

# Default versions of the environments.
CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"


@implement_for("gym", None, "0.21.0")
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v0"
HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
PENDULUM_VERSIONED = "Pendulum-v0"
PONG_VERSIONED = "Pong-v4"


@implement_for("gym", "0.21.0", None)
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED

CARTPOLE_VERSIONED = "CartPole-v1"
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"


if _has_gym:
_set_gym_environments()


def get_relative_path(curr_file, *path_components):
return os.path.join(os.path.dirname(curr_file), *path_components)
Expand Down
13 changes: 1 addition & 12 deletions test/smoke_test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,10 @@
import tempfile

import pytest
from _utils_internal import PONG_VERSIONED
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv

if _has_gym:
import gym
from packaging import version

gym_version = version.parse(gym.__version__)
PONG_VERSIONED = (
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
)
else:
# placeholders
PONG_VERSIONED = "ALE/Pong-v5"

try:
from torch.utils.tensorboard import SummaryWriter

Expand Down
18 changes: 1 addition & 17 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pytest
import torch
from _utils_internal import generate_seeds
from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED
from mocking_classes import (
ContinuousActionVecMockEnv,
DiscreteActionConvMockEnv,
Expand Down Expand Up @@ -42,22 +42,6 @@
TensorDictModule,
)

if _has_gym:
import gym
from packaging import version

gym_version = version.parse(gym.__version__)
PENDULUM_VERSIONED = (
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
)
PONG_VERSIONED = (
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
)
else:
# placeholders
PENDULUM_VERSIONED = "Pendulum-v1"
PONG_VERSIONED = "ALE/Pong-v5"

# torch.set_default_dtype(torch.double)


Expand Down
31 changes: 9 additions & 22 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
import pytest
import torch
import yaml
from _utils_internal import get_available_devices
from _utils_internal import (
get_available_devices,
CARTPOLE_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
HALFCHEETAH_VERSIONED,
)
from mocking_classes import (
ActionObsMergeLinear,
DiscreteActionConvMockEnv,
Expand Down Expand Up @@ -49,30 +55,11 @@
)
from torchrl.modules.tensordict_module import WorldModelWrapper

gym_version = None
if _has_gym:
import gym

gym_version = version.parse(gym.__version__)
PENDULUM_VERSIONED = (
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
)
CARTPOLE_VERSIONED = (
"CartPole-v1" if gym_version > version.parse("0.20.0") else "CartPole-v0"
)
PONG_VERSIONED = (
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
)
HALFCHEETAH_VERSIONED = (
"HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2"
)
else:
# placeholder
gym_version = version.parse("0.0.1")

# placeholders
PENDULUM_VERSIONED = "Pendulum-v1"
CARTPOLE_VERSIONED = "CartPole-v1"
PONG_VERSIONED = "ALE/Pong-v5"

try:
this_dir = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -1048,7 +1035,7 @@ def test_batch_unlocked_with_batch_size(device):

@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.skipif(
gym_version < version.parse("0.20.0"),
gym_version is None or gym_version < version.parse("0.20.0"),
reason="older versions of half-cheetah do not have 'x_position' info key.",
)
def test_info_dict_reader(seed=0):
Expand Down
69 changes: 28 additions & 41 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from sys import platform

import numpy as np
import pytest
import torch
from _utils_internal import _test_fake_tensordict
from _utils_internal import get_available_devices
from _utils_internal import (
_test_fake_tensordict,
get_available_devices,
HALFCHEETAH_VERSIONED,
PONG_VERSIONED,
PENDULUM_VERSIONED,
)
from packaging import version
from tensordict.tensordict import assert_allclose_td
from torchrl._utils import implement_for
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper
from torchrl.envs.libs.dm_control import _has_dmc
from torchrl.envs.libs.gym import GymEnv, GymWrapper
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels
from torchrl.envs.libs.habitat import HabitatEnv, _has_habitat
from torchrl.envs.libs.jumanji import JumanjiEnv, _has_jumanji
Expand All @@ -32,39 +43,8 @@
from dm_control import suite
from dm_control.suite.wrappers import pixels

from sys import platform

from tensordict.tensordict import assert_allclose_td
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper
from torchrl.envs.libs.gym import GymEnv, GymWrapper

IS_OSX = platform == "darwin"

if _has_gym:
from packaging import version

gym_version = version.parse(gym.__version__)
PENDULUM_VERSIONED = (
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
)
HC_VERSIONED = (
"HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2"
)
PONG_VERSIONED = (
"ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4"
)

# if gym_version < version.parse("0.24.0") and torch.cuda.device_count() > 0:
# from opengl_rendering import create_opengl_context
#
# create_opengl_context()
else:
# placeholders
PENDULUM_VERSIONED = "Pendulum-v1"
HC_VERSIONED = "HalfCheetah-v4"
PONG_VERSIONED = "ALE/Pong-v5"


@pytest.mark.skipif(not _has_gym, reason="no gym library found")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -123,10 +103,7 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
base_env = gym.make(env_name, frameskip=frame_skip)
frame_skip = 1
else:
if gym_version < version.parse("0.26.0"):
base_env = gym.make(env_name)
else:
base_env = gym.make(env_name, render_mode="rgb_array")
base_env = _make_gym_environment(env_name)

if from_pixels and not _is_from_pixels(base_env):
base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only)
Expand Down Expand Up @@ -164,6 +141,16 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
_test_fake_tensordict(env)


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name)


@implement_for("gym", "0.26", None)
def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")


@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found")
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
Expand Down Expand Up @@ -270,9 +257,9 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
"env_lib,env_args,env_kwargs",
[
[DMControlEnv, ("cheetah", "run"), {"from_pixels": True}],
[GymEnv, (HC_VERSIONED,), {"from_pixels": True}],
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}],
[DMControlEnv, ("cheetah", "run"), {"from_pixels": False}],
[GymEnv, (HC_VERSIONED,), {"from_pixels": False}],
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}],
[GymEnv, (PONG_VERSIONED,), {}],
],
)
Expand Down Expand Up @@ -307,9 +294,9 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs):
"env_lib,env_args,env_kwargs",
[
[DMControlEnv, ("cheetah", "run"), {"from_pixels": True}],
[GymEnv, (HC_VERSIONED,), {"from_pixels": True}],
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}],
[DMControlEnv, ("cheetah", "run"), {"from_pixels": False}],
[GymEnv, (HC_VERSIONED,), {"from_pixels": False}],
[GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}],
[GymEnv, (PONG_VERSIONED,), {}],
],
)
Expand Down
19 changes: 6 additions & 13 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import numpy as np
import pytest
import torch
from _utils_internal import get_available_devices, retry, dtype_fixture # noqa
from _utils_internal import ( # noqa
get_available_devices,
retry,
dtype_fixture,
PENDULUM_VERSIONED,
)
from mocking_classes import (
ContinuousActionVecMockEnv,
DiscreteActionConvMockEnvNumpy,
Expand Down Expand Up @@ -59,18 +64,6 @@
)
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform

if _has_gym:
import gym
from packaging import version

gym_version = version.parse(gym.__version__)
PENDULUM_VERSIONED = (
"Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0"
)
else:
# placeholders
PENDULUM_VERSIONED = "Pendulum-v1"

TIMEOUT = 10.0


Expand Down
Loading

0 comments on commit b583ac1

Please sign in to comment.