Skip to content

Commit

Permalink
[BugFix] Fix robohive (pytorch#2080)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 16, 2024
1 parent bedd2b7 commit d2cfd28
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3423,6 +3423,7 @@ def test_robohive(self, envname, from_pixels):
for val in env.rollout(4).values(True):
if is_tensor_collection(val):
assert not isinstance(val, LazyStackedTensorDict)
assert not val.is_empty()
check_env_specs(env)


Expand Down
9 changes: 7 additions & 2 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
*_, info = self.env.step(self.env.action_space.sample())
info = self.read_info(info, TensorDict({}, []))
info = info.get("info")
self.observation_spec["observation"] = make_composite_from_td(info)
self.observation_spec["info"] = make_composite_from_td(info)
return out

@classmethod
Expand Down Expand Up @@ -309,6 +309,11 @@ def get_obs():
self.observation_spec.update(spec)
self.empty_cache()

def _reset_output_transform(self, reset_data):
if not (isinstance(reset_data, tuple) and len(reset_data) == 2):
return reset_data, {}
return reset_data

def set_from_pixels(self, from_pixels: bool) -> None:
"""Sets the from_pixels attribute to an existing environment.
Expand Down Expand Up @@ -353,7 +358,6 @@ def read_obs(self, observation):
return super().read_obs(out)

def read_info(self, info, tensordict_out):
out = {}
if not info:
info_spec = self.observation_spec.get("info", None)
if info_spec is None:
Expand All @@ -364,6 +368,7 @@ def read_info(self, info, tensordict_out):
TensorDict(info, [])
.filter_non_tensor_data()
.exclude("obs_dict", "done", "reward", *self._env.obs_keys, "act")
.apply(lambda x: x, filter_empty=True)
)
if "info" in self.observation_spec.keys():
info_spec = self.observation_spec["info"]
Expand Down

0 comments on commit d2cfd28

Please sign in to comment.