Skip to content

Commit

Permalink
[BugFix] Fix support for MiniGrid envs (pytorch#2416)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Sep 4, 2024
1 parent 5a81930 commit 60cd104
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 13 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_libs/scripts_gym/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- pip:
# Initial version is required to install Atari ROMS in setup_env.sh
- gym[atari]==0.13
- minigrid
- hypothesis
- future
- cloudpickle
Expand Down
28 changes: 28 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@

_has_meltingpot = importlib.util.find_spec("meltingpot") is not None

_has_minigrid = importlib.util.find_spec("minigrid") is not None


@pytest.fixture(scope="session", autouse=True)
def maybe_init_minigrid():
if _has_minigrid and _has_gymnasium:
import minigrid

minigrid.register_minigrid_envs()


def get_gym_pixel_wrapper():
try:
Expand Down Expand Up @@ -1279,6 +1289,24 @@ def test_resetting_strategies(self, heterogeneous):
gc.collect()


@pytest.mark.skipif(
not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found"
)
class TestMiniGrid:
@pytest.mark.parametrize(
"id",
[
"BabyAI-KeyCorridorS6R3-v0",
"MiniGrid-Empty-16x16-v0",
"MiniGrid-BlockedUnlockPickup-v0",
],
)
def test_minigrid(self, id):
env_base = gymnasium.make(id)
env = GymWrapper(env_base)
check_env_specs(env)


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
gym = gym_backend()
Expand Down
13 changes: 8 additions & 5 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import NonTensorData, TensorDict, TensorDictBase
from torchrl._utils import logger as torchrl_logger

from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.common import _EnvWrapper, EnvBase


Expand Down Expand Up @@ -283,9 +283,12 @@ def read_obs(
observations = observations_dict
else:
for key, val in observations.items():
observations[key] = self.observation_spec[key].encode(
val, ignore_device=True
)
if isinstance(self.observation_spec[key], NonTensor):
observations[key] = NonTensorData(val)
else:
observations[key] = self.observation_spec[key].encode(
val, ignore_device=True
)
return observations

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand Down
33 changes: 25 additions & 8 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Composite,
MultiCategorical,
MultiOneHot,
NonTensor,
OneHot,
TensorSpec,
Unbounded,
Expand All @@ -55,6 +56,14 @@

_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
_has_minigrid = importlib.util.find_spec("minigrid") is not None


def _minigrid_lib():
assert _has_minigrid, "minigrid not found"
import minigrid

return minigrid


class set_gym_backend(_DecoratorContextManager):
Expand Down Expand Up @@ -369,6 +378,8 @@ def _gym_to_torchrl_spec_transform(
categorical_action_encoding=categorical_action_encoding,
remap_state_to_observation=remap_state_to_observation,
)
elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace):
return NonTensor((), device=device)
else:
raise NotImplementedError(
f"spec of type {type(spec).__name__} is currently unaccounted for"
Expand Down Expand Up @@ -766,14 +777,20 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
self._seed_calls_reset = None
self._categorical_action_encoding = categorical_action_encoding
if env is not None:
if "EnvCompatibility" in str(
env
): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
raise ValueError(
"GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
"If this feature is needed, detail your use case in an issue of "
"https://github.com/pytorch/rl/issues."
)
try:
env_str = str(env)
except TypeError:
# MiniGrid has a bug where the __str__ method fails
pass
else:
if (
"EnvCompatibility" in env_str
): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
raise ValueError(
"GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
"If this feature is needed, detail your use case in an issue of "
"https://github.com/pytorch/rl/issues."
)
libname = self.get_library_name(env)
with set_gym_backend(libname):
kwargs["env"] = env
Expand Down

0 comments on commit 60cd104

Please sign in to comment.