Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix info reading with async gym #2150

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,10 +1104,10 @@ def __init__(self, dim=3, use_termination=True, max_steps=4):
self.dim = dim
self.use_termination = use_termination
self.observation_space = gym_backend("spaces").Box(
low=-np.inf, high=np.inf, shape=(self.dim,)
low=-np.inf, high=np.inf, shape=(self.dim,), dtype=np.float32
)
self.action_space = gym_backend("spaces").Box(
low=-np.inf, high=np.inf, shape=(1,)
low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32
)
self.max_steps = max_steps

Expand All @@ -1118,7 +1118,9 @@ def _get_obs(self):
return self.state.copy()

def reset(self, seed=0, options=None):
self.state = np.zeros(self.observation_space.shape)
self.state = np.zeros(
self.observation_space.shape, dtype=np.float32
)
observation = self._get_obs()
info = self._get_info()
assert (observation < self.max_steps).all()
Expand Down Expand Up @@ -1197,6 +1199,12 @@ def test_resetting_strategies(self, heterogeneous):
r2 = env.rollout(10, break_when_any_done=False)
assert_allclose_td(r0, r1)
assert_allclose_td(r1, r2)
for r in (r0, r1, r2):
torch.testing.assert_close(r["field1"], r["observation"].pow(2))
torch.testing.assert_close(
r["next", "field1"], r["next", "observation"].pow(2)
)

finally:
if not env.is_closed:
env.close()
Expand Down
35 changes: 30 additions & 5 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class default_info_dict_reader(BaseInfoDictReader):
correspondent key to form a :class:`torchrl.data.CompositeSpec`.
If not provided, a composite spec with :class:`~torchrl.data.UnboundedContinuousTensorSpec`
specs will lazily be created.
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

In cases where keys can be directly written to a tensordict (mostly if they abide to the
tensordict shape), one simply needs to indicate the keys to be registered during
Expand All @@ -74,7 +76,9 @@ def __init__(
| Dict[str, TensorSpec]
| CompositeSpec
| None = None,
ignore_private: bool = True,
):
self.ignore_private = ignore_private
self._lazy = False
if keys is None:
self._lazy = True
Expand Down Expand Up @@ -113,16 +117,29 @@ def __call__(
keys = self.keys
if keys is None:
keys = info_dict.keys()
if self.ignore_private:
keys = [key for key in keys if not key.startswith("_")]
self.keys = keys
# create an info_spec only if there is none
info_spec = None if self.info_spec is not None else CompositeSpec()
for key in keys:
if key in info_dict:
tensordict.set(key, info_dict[key])
if info_dict[key].dtype == np.dtype("O"):
val = np.stack(info_dict[key])
else:
val = info_dict[key]
tensordict.set(key, val)
if info_spec is not None:
val = tensordict.get(key)
info_spec[key] = UnboundedContinuousTensorSpec(
val.shape, device=val.device, dtype=val.dtype
)
elif self.info_spec is not None:
# Fill missing with 0s
tensordict.set(key, self.info_spec[key].zero())
else:
raise KeyError(f"The key {key} could not be found or inferred.")
# set the info spec if there wasn't any - this should occur only once in this class
if info_spec is not None:
if tensordict.device is not None:
info_spec = info_spec.to(tensordict.device)
Expand Down Expand Up @@ -422,7 +439,9 @@ def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple:
...

def set_info_dict_reader(
self, info_dict_reader: BaseInfoDictReader | None = None
self,
info_dict_reader: BaseInfoDictReader | None = None,
ignore_private: bool = True,
) -> GymLikeEnv:
"""Sets an info_dict_reader function.

Expand All @@ -436,6 +455,8 @@ def set_info_dict_reader(
This function should modify the tensordict in-place. If none is
provided, :class:`~torchrl.envs.gym_like.default_info_dict_reader`
will be used.
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

Returns: the same environment with the dict_reader registered.

Expand All @@ -456,7 +477,7 @@ def set_info_dict_reader(

"""
if info_dict_reader is None:
info_dict_reader = default_info_dict_reader()
info_dict_reader = default_info_dict_reader(ignore_private=ignore_private)
self.info_dict_reader.append(info_dict_reader)
if isinstance(info_dict_reader, BaseInfoDictReader):
# if we have a BaseInfoDictReader, we know what the specs will be
Expand All @@ -481,7 +502,7 @@ def set_info_dict_reader(

return self

def auto_register_info_dict(self):
def auto_register_info_dict(self, ignore_private: bool = True):
"""Automatically registers the info dict.

It is assumed that all the information contained in the info dict can be registered as numerical values
Expand All @@ -494,6 +515,10 @@ def auto_register_info_dict(self):
This method requires running a few iterations in the environment to
manually check that the behaviour matches expectations.

Args:
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.

Examples:
>>> from torchrl.envs import GymEnv
>>> env = GymEnv("HalfCheetah-v4")
Expand All @@ -504,7 +529,7 @@ def auto_register_info_dict(self):

if self.info_dict_reader:
raise RuntimeError("The environment already has an info-dict reader.")
self.set_info_dict_reader()
self.set_info_dict_reader(ignore_private=ignore_private)
try:
check_env_specs(self)
return self
Expand Down
40 changes: 33 additions & 7 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,9 @@ def __call__(cls, *args, **kwargs):
)
add_info_dict = False
if add_info_dict:
# First register the basic info dict reader
instance.auto_register_info_dict()
# Make it so that infos are properly cast where they should at done time
instance.set_info_dict_reader(
terminal_obs_reader(instance.observation_spec, backend=backend)
)
Expand Down Expand Up @@ -1494,12 +1497,15 @@ class terminal_obs_reader(BaseInfoDictReader):
"sb3": "terminal_observation",
"gym": "final_observation",
}
backend_info_key = {
"sb3": "terminal_info",
"gym": "final_info",
}

def __init__(self, observation_spec: CompositeSpec, backend, name="final"):
self.name = name
self._info_spec = CompositeSpec(
{(self.name, key): item.clone() for key, item in observation_spec.items()},
shape=observation_spec.shape,
{name: observation_spec.clone()}, shape=observation_spec.shape
)
self.backend = backend

Expand Down Expand Up @@ -1539,14 +1545,34 @@ def _read_obs(self, obs, key, tensor, index):

def __call__(self, info_dict, tensordict):
terminal_obs = info_dict.get(self.backend_key[self.backend], None)
for key, item in self.info_spec.items(True, True):
key = (key,) if isinstance(key, str) else key
final_obs_buffer = item.zero()
terminal_info = info_dict.get(self.backend_info_key[self.backend], None)
if terminal_info is not None:
# terminal_info is a list of items that can be None or not
# If they're not None, they are a dict of values that we want to put in a root dict
keys = set()
for info in terminal_info:
if info is None:
continue
keys = keys.union(info.keys())
terminal_info = {
key: [info[key] if info is not None else info for info in terminal_info]
for key in keys
}
else:
terminal_info = {}
obs_dict = terminal_info.copy()
if terminal_obs is not None:
obs_dict["observation"] = terminal_obs
for key, terminal_obs in obs_dict.items():
spec = self.info_spec[self.name, key]
# for key, item in self.info_spec.items(True, True):
# key = (key,) if isinstance(key, str) else key
final_obs_buffer = spec.zero()
if terminal_obs is not None:
for i, obs in enumerate(terminal_obs):
# writes final_obs inplace with terminal_obs content
self._read_obs(obs, key[-1], final_obs_buffer, index=i)
tensordict.set(key, final_obs_buffer)
self._read_obs(obs, key, final_obs_buffer, index=i)
tensordict.set((self.name, key), final_obs_buffer)
return tensordict


Expand Down
10 changes: 5 additions & 5 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7000,13 +7000,13 @@ def _reset(
# if not reset.any(), we don't need to do anything.
# if reset.all(), we don't either (bc GymWrapper will call a plain reset).
if reset is not None and reset.any():
if reset.all():
# We're fine: this means that a full reset was passed and the
# env was manually reset
tensordict_reset.pop(self.final_name, None)
return tensordict_reset
saved_next = self._memo["saved_next"]
if saved_next is None:
if reset.all():
# We're fine: this means that a full reset was passed and the
# env was manually reset
tensordict_reset.pop(self.final_name, None)
return tensordict_reset
raise RuntimeError(
"Did not find a saved tensordict while the reset mask was "
f"not empty: reset={reset}. Done was {done}."
Expand Down
Loading