Skip to content

Commit

Permalink
[Feature] Return depth from RoboHiveEnv (pytorch#2058)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
sriramsk1999 and vmoens authored Apr 22, 2024
1 parent ee0bfe5 commit 68ef60b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
7 changes: 4 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3396,10 +3396,11 @@ class TestRoboHive:
# In the CI, robohive should not coexist with other libs so that's fine.
# Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("from_depths", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
def test_robohive(self, envname, from_pixels):
def test_robohive(self, envname, from_pixels, from_depths):
with set_gym_backend("gymnasium"):
torchrl_logger.info(f"{envname}-{from_pixels}")
torchrl_logger.info(f"{envname}-{from_pixels}-{from_depths}")
if any(
substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")
):
Expand All @@ -3415,7 +3416,7 @@ def test_robohive(self, envname, from_pixels):
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
Expand Down
36 changes: 26 additions & 10 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
be returned (by default under the ``"pixels"`` entry in the output tensordict).
If ``False``, observations (eg, states) and pixels will be returned
whenever ``from_pixels=True``. Defaults to ``True``.
from_depths (bool, optional): if ``True``, an attempt to return the depth
observations from the env will be performed. By default, these observations
will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``.
Defaults to ``False``.
frame_skip (int, optional): if provided, indicates for how many steps the
same action is to be repeated. The observation returned will be the
last observation of the sequence, whereas the reward will be the sum
Expand Down Expand Up @@ -155,6 +159,7 @@ def _build_env( # noqa: F811
env_name: str,
from_pixels: bool = False,
pixels_only: bool = False,
from_depths: bool = False,
**kwargs,
) -> "gym.core.Env": # noqa: F821
if from_pixels:
Expand All @@ -168,7 +173,9 @@ def _build_env( # noqa: F811
)
kwargs["cameras"] = self.get_available_cams(env_name)
cams = list(kwargs.pop("cameras"))
env_name = self.register_visual_env(cams=cams, env_name=env_name)
env_name = self.register_visual_env(
cams=cams, env_name=env_name, from_depths=from_depths
)

elif "cameras" in kwargs and kwargs["cameras"]:
raise RuntimeError("Got a list of cameras but from_pixels is set to False.")
Expand All @@ -194,10 +201,6 @@ def _build_env( # noqa: F811
**kwargs,
)
self.wrapper_frame_skip = 1
if env.visual_keys:
from_pixels = bool(len(env.visual_keys))
else:
from_pixels = False
except TypeError as err:
if "unexpected keyword argument 'frameskip" not in str(err):
raise err
Expand All @@ -209,6 +212,7 @@ def _build_env( # noqa: F811
# except Exception as err:
# raise RuntimeError(f"Failed to build env {env_name}.") from err
self.from_pixels = from_pixels
self.from_depths = from_depths
self.render_device = render_device
if kwargs.get("read_info", True):
self.set_info_dict_reader(self.read_info)
Expand All @@ -224,7 +228,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
return out

@classmethod
def register_visual_env(cls, env_name, cams):
def register_visual_env(cls, env_name, cams, from_depths):
with set_directory(cls.CURR_DIR):
from robohive.envs.env_variants import register_env_variant

Expand All @@ -233,9 +237,9 @@ def register_visual_env(cls, env_name, cams):
cams = sorted(cams)
cams_rep = [i.replace("A:", "A_") for i in cams]
new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name])
if new_env_name in cls.env_list:
return new_env_name
visual_keys = [f"rgb:{c}:224x224:2d" for c in cams]
if from_depths:
visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])
register_env_variant(
env_name,
variants={
Expand All @@ -262,20 +266,26 @@ def get_obs():
if self.from_pixels:
visual = self.env.get_exteroception()
obs_dict.update(visual)
pixel_list = []
pixel_list, depth_list = [], []
for obs_key in obs_dict:
if obs_key.startswith("rgb"):
pix = obs_dict[obs_key]
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif obs_key.startswith("d:"):
dep = obs_dict[obs_key]
dep = dep[None]
depth_list.append(dep)
elif obs_key in env.obs_keys:
value = env.obs_dict[obs_key]
if not value.shape:
value = value[None]
_dict[obs_key] = value
if pixel_list:
_dict["pixels"] = np.concatenate(pixel_list, 0)
if depth_list:
_dict["depths"] = np.concatenate(depth_list, 0)
return _dict

for i in range(3):
Expand Down Expand Up @@ -335,7 +345,7 @@ def read_obs(self, observation):
pass
# recover vec
obsdict = {}
pixel_list = []
pixel_list, depth_list = [], []
if self.from_pixels:
visual = self.env.get_exteroception()
observations.update(visual)
Expand All @@ -345,6 +355,10 @@ def read_obs(self, observation):
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif key.startswith("d:"):
dep = observations[key]
dep = dep[None]
depth_list.append(dep)
elif key in self._env.obs_keys:
value = observations[key]
if not value.shape:
Expand All @@ -354,6 +368,8 @@ def read_obs(self, observation):
# obsvec = np.concatenate(obsvec, 0)
if self.from_pixels:
obsdict.update({"pixels": np.concatenate(pixel_list, 0)})
if self.from_pixels and self.from_depths:
obsdict.update({"depths": np.concatenate(depth_list, 0)})
out = obsdict
return super().read_obs(out)

Expand Down

0 comments on commit 68ef60b

Please sign in to comment.