Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Gym 'vectorized' envs compatibility #1519

Merged
merged 20 commits into from
Sep 17, 2023
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Sep 15, 2023
commit a9a0ebeef3b37e224a97861682a97cb571673779
7 changes: 6 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env

_has_d4rl = importlib.util.find_spec("d4rl") is not None

Expand Down Expand Up @@ -395,6 +395,11 @@ def test_vecenvs(self): # noqa: F811
env = GymEnv(envname, num_envs=2, from_pixels=True)
check_env_specs(env)

@implement_for("gym", None, "0.18")
def test_vecenvs(self): # noqa: F811
# skipping tests for older versions of gym
return


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
6 changes: 2 additions & 4 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import torch

from tensordict import TensorDictBase
from torchrl.envs.batched_envs import CloudpickleWrapper

from torchrl.envs.utils import _classproperty

from torchrl._utils import implement_for
from torchrl.data.tensor_specs import (
Expand All @@ -31,14 +28,15 @@
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.batched_envs import CloudpickleWrapper

from torchrl.envs.gym_like import (
BaseInfoDictReader,
default_info_dict_reader,
GymLikeEnv,
)

from torchrl.envs.utils import _classproperty
from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv

try:
from torch.utils._contextlib import _DecoratorContextManager
Expand Down