Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix info reading with async gym #2150

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed May 2, 2024
commit 4a92125da0cb0861dbb3fca987f48cc16721b16f
14 changes: 11 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,10 +1104,10 @@ def __init__(self, dim=3, use_termination=True, max_steps=4):
self.dim = dim
self.use_termination = use_termination
self.observation_space = gym_backend("spaces").Box(
low=-np.inf, high=np.inf, shape=(self.dim,)
low=-np.inf, high=np.inf, shape=(self.dim,), dtype=np.float32
)
self.action_space = gym_backend("spaces").Box(
low=-np.inf, high=np.inf, shape=(1,)
low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32
)
self.max_steps = max_steps

Expand All @@ -1118,7 +1118,9 @@ def _get_obs(self):
return self.state.copy()

def reset(self, seed=0, options=None):
self.state = np.zeros(self.observation_space.shape)
self.state = np.zeros(
self.observation_space.shape, dtype=np.float32
)
observation = self._get_obs()
info = self._get_info()
assert (observation < self.max_steps).all()
Expand Down Expand Up @@ -1197,6 +1199,12 @@ def test_resetting_strategies(self, heterogeneous):
r2 = env.rollout(10, break_when_any_done=False)
assert_allclose_td(r0, r1)
assert_allclose_td(r1, r2)
for r in (r0, r1, r2):
torch.testing.assert_close(r["field1"], r["observation"].pow(2))
torch.testing.assert_close(
r["next", "field1"], r["next", "observation"].pow(2)
)

finally:
if not env.is_closed:
env.close()
Expand Down
28 changes: 23 additions & 5 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class default_info_dict_reader(BaseInfoDictReader):
correspondent key to form a :class:`torchrl.data.CompositeSpec`.
If not provided, a composite spec with :class:`~torchrl.data.UnboundedContinuousTensorSpec`
specs will lazyly be created.
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

In cases where keys can be directly written to a tensordict (mostly if they abide to the
tensordict shape), one simply needs to indicate the keys to be registered during
Expand All @@ -74,7 +76,9 @@ def __init__(
| Dict[str, TensorSpec]
| CompositeSpec
| None = None,
ignore_private: bool = True,
):
self.ignore_private = ignore_private
self._lazy = False
if keys is None:
self._lazy = True
Expand Down Expand Up @@ -113,11 +117,17 @@ def __call__(
keys = self.keys
if keys is None:
keys = info_dict.keys()
if self.ignore_private:
keys = [key for key in keys if not key.startswith("_")]
self.keys = keys
info_spec = None if self.info_spec is not None else CompositeSpec()
for key in keys:
if key in info_dict:
tensordict.set(key, info_dict[key])
if info_dict[key].dtype == np.dtype("O"):
val = np.stack(info_dict[key])
else:
val = info_dict[key]
tensordict.set(key, val)
if info_spec is not None:
val = tensordict.get(key)
info_spec[key] = UnboundedContinuousTensorSpec(
Expand Down Expand Up @@ -422,7 +432,9 @@ def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple:
...

def set_info_dict_reader(
self, info_dict_reader: BaseInfoDictReader | None = None
self,
info_dict_reader: BaseInfoDictReader | None = None,
ignore_private: bool = True,
) -> GymLikeEnv:
"""Sets an info_dict_reader function.

Expand All @@ -436,6 +448,8 @@ def set_info_dict_reader(
This function should modify the tensordict in-place. If none is
provided, :class:`~torchrl.envs.gym_like.default_info_dict_reader`
will be used.
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

Returns: the same environment with the dict_reader registered.

Expand All @@ -456,7 +470,7 @@ def set_info_dict_reader(

"""
if info_dict_reader is None:
info_dict_reader = default_info_dict_reader()
info_dict_reader = default_info_dict_reader(ignore_private=ignore_private)
self.info_dict_reader.append(info_dict_reader)
if isinstance(info_dict_reader, BaseInfoDictReader):
# if we have a BaseInfoDictReader, we know what the specs will be
Expand All @@ -481,7 +495,7 @@ def set_info_dict_reader(

return self

def auto_register_info_dict(self):
def auto_register_info_dict(self, ignore_private: bool = True):
"""Automatically registers the info dict.

It is assumed that all the information contained in the info dict can be registered as numerical values
Expand All @@ -494,6 +508,10 @@ def auto_register_info_dict(self):
This method requires running a few iterations in the environment to
manually check that the behaviour matches expectations.

Args:
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

Examples:
>>> from torchrl.envs import GymEnv
>>> env = GymEnv("HalfCheetah-v4")
Expand All @@ -504,7 +522,7 @@ def auto_register_info_dict(self):

if self.info_dict_reader:
raise RuntimeError("The environment already has an info-dict reader.")
self.set_info_dict_reader()
self.set_info_dict_reader(ignore_private=ignore_private)
try:
check_env_specs(self)
return self
Expand Down
33 changes: 26 additions & 7 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,9 @@ def __call__(cls, *args, **kwargs):
)
add_info_dict = False
if add_info_dict:
# First register the basic info dict reader
instance.auto_register_info_dict()
# Make it so that infos are properly cast where they should at done time
instance.set_info_dict_reader(
terminal_obs_reader(instance.observation_spec, backend=backend)
)
Expand Down Expand Up @@ -1494,12 +1497,15 @@ class terminal_obs_reader(BaseInfoDictReader):
"sb3": "terminal_observation",
"gym": "final_observation",
}
backend_info_key = {
"sb3": "terminal_info",
"gym": "final_info",
}

def __init__(self, observation_spec: CompositeSpec, backend, name="final"):
self.name = name
self._info_spec = CompositeSpec(
{(self.name, key): item.clone() for key, item in observation_spec.items()},
shape=observation_spec.shape,
{name: observation_spec.clone()}, shape=observation_spec.shape
)
self.backend = backend

Expand Down Expand Up @@ -1539,14 +1545,27 @@ def _read_obs(self, obs, key, tensor, index):

def __call__(self, info_dict, tensordict):
terminal_obs = info_dict.get(self.backend_key[self.backend], None)
for key, item in self.info_spec.items(True, True):
key = (key,) if isinstance(key, str) else key
final_obs_buffer = item.zero()
terminal_info = info_dict.get(self.backend_info_key[self.backend], None)
if terminal_obs is not None:
terminal_info = {
key: np.stack([info[key] for info in terminal_info])
for key in terminal_info[0].keys()
}
else:
terminal_info = {}
obs_dict = terminal_info.copy()
if terminal_obs is not None:
obs_dict["observation"] = terminal_obs
for key, terminal_obs in obs_dict.items():
spec = self.info_spec[self.name, key]
# for key, item in self.info_spec.items(True, True):
# key = (key,) if isinstance(key, str) else key
final_obs_buffer = spec.zero()
if terminal_obs is not None:
for i, obs in enumerate(terminal_obs):
# writes final_obs inplace with terminal_obs content
self._read_obs(obs, key[-1], final_obs_buffer, index=i)
tensordict.set(key, final_obs_buffer)
self._read_obs(obs, key, final_obs_buffer, index=i)
tensordict.set((self.name, key), final_obs_buffer)
return tensordict


Expand Down
Loading