Skip to content

Commit

Permalink
[Refactor] Fix imports (pytorch#1551)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 20, 2023
1 parent d517fc3 commit 4162e6d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
13 changes: 11 additions & 2 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
Expand Down Expand Up @@ -1977,9 +1977,18 @@ def test_collector(self, task, parallel):
break


@pytest.mark.skipif(not _has_robohive, reason="SMACv2 not found")
class TestRoboHive:
@pytest.mark.parametrize("envname", RoboHiveEnv.env_list)
# unfortunately we must import robohive to get the available envs
# and this import will occur whenever pytest is run on this file.
# The other option would be not to use parametrize but that also
# means less informative error trace stacks.
# In the CI, robohive should not coexist with other libs so that's fine.
# Locally these imports can be annoying, especially given the amount of
# stuff printed by robohive.
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
@pytest.mark.parametrize("from_pixels", [True, False])
@set_gym_backend("gym")
def test_robohive(self, envname, from_pixels):
if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")):
print("not testing envs with prebuilt rendering")
Expand Down
2 changes: 0 additions & 2 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
if torch.cuda.device_count() > 1:
n = torch.cuda.device_count() - 1
os.environ["MUJOCO_EGL_DEVICE_ID"] = str(1 + (os.getpid() % n))
# if VERBOSE:
print("MUJOCO_EGL_DEVICE_ID: ", os.environ["MUJOCO_EGL_DEVICE_ID"])

from ._extension import _init_extension

Expand Down
12 changes: 7 additions & 5 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,15 @@ def CURR_DIR(cls):
else:
return None

@_classproperty
def available_envs(cls):
if not _has_robohive:
return
RoboHiveEnv.register_envs()
yield from cls.env_list

@classmethod
def register_envs(cls):

if not _has_robohive:
raise ImportError(
"Cannot load robohive from the current virtual environment."
Expand Down Expand Up @@ -333,7 +339,3 @@ def get_available_cams(cls, env_name):
env = gym.make(env_name)
cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)]
return cams


if _has_robohive:
RoboHiveEnv.register_envs()

0 comments on commit 4162e6d

Please sign in to comment.