[BugFix] Batched envs compatibility with custom keys (#1348)
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini authored Jul 3, 2023
1 parent 9813a0e commit 7889223
Showing 4 changed files with 103 additions and 30 deletions.
14 changes: 10 additions & 4 deletions test/mocking_classes.py
@@ -1127,10 +1127,15 @@ def __init__(
shape=self.batch_size,
)

def _reset(self, td):
if self.nested_done and td is not None and "_reset" in td.keys():
td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
td = super()._reset(td)
def _reset(self, tensordict):
if (
self.nested_done
and tensordict is not None
and "_reset" in tensordict.keys()
):
tensordict = tensordict.clone()
tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool)
td = super()._reset(tensordict)
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
@@ -1149,6 +1154,7 @@ def _reset(self, td):

def _step(self, td):
if self.nested_obs_action:
td = td.clone()
td["data"].batch_size = self.batch_size
td[self.action_key] = td[self.action_key].max(-2)[0]
td_root = super()._step(td)
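
A note on the hunks above: the mock env now clones the incoming tensordict before mutating it in place, which keeps the caller's tensordict (for example the one a SerialEnv/ParallelEnv shares with its workers) unchanged. A minimal sketch of that clone-before-mutate pattern, written as a standalone helper with a hypothetical name rather than the actual NestedCountingEnv code:

from typing import Optional

import torch
from tensordict import TensorDict


def collapse_reset_mask(tensordict: Optional[TensorDict]) -> Optional[TensorDict]:
    # The "_reset" mask may carry an extra nested dimension, e.g.
    # (*batch_size, nested_dim, 1); collapse it to the root batch size
    # on a clone so the caller's tensordict is left untouched.
    if tensordict is not None and "_reset" in tensordict.keys():
        tensordict = tensordict.clone()
        tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool)
    return tensordict
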
60 changes: 60 additions & 0 deletions test/test_env.py
@@ -943,6 +943,66 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3):
assert (td_reset["done"][~_reset] == 1).all()
assert (td_reset["observation"][~_reset] == max_steps + 1).all()

@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@pytest.mark.parametrize("nested_reward", [True, False])
@pytest.mark.parametrize("env_type", ["serial", "parallel"])
def test_parallel_env_nested(
self,
nested_obs_action,
nested_done,
nested_reward,
env_type,
n_envs=2,
batch_size=(32,),
nested_dim=5,
rollout_length=3,
seed=1,
):
env_fn = lambda: NestedCountingEnv(
nest_done=nested_done,
nest_reward=nested_reward,
nest_obs_action=nested_obs_action,
batch_size=batch_size,
nested_dim=nested_dim,
)
if env_type == "serial":
env = SerialEnv(n_envs, env_fn)
else:
env = ParallelEnv(n_envs, env_fn)
env.set_seed(seed)

batch_size = (n_envs, *batch_size)

td = env.reset()
assert td.batch_size == batch_size
if nested_done or nested_obs_action:
assert td["data"].batch_size == (*batch_size, nested_dim)
if not nested_done and not nested_reward and not nested_obs_action:
assert "data" not in td.keys()

policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
td = env.rollout(rollout_length, policy)
assert td.batch_size == (*batch_size, rollout_length)
if nested_done or nested_obs_action:
assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
if nested_reward or nested_done or nested_obs_action:
assert td["next", "data"].batch_size == (
*batch_size,
rollout_length,
nested_dim,
)
if not nested_done and not nested_reward and not nested_obs_action:
assert "data" not in td.keys()
assert "data" not in td["next"].keys()

if nested_obs_action:
assert "observation" not in td.keys()
assert (td[..., -1]["data", "states"] == 2).all()
else:
assert ("data", "states") not in td.keys(True, True)
assert (td[..., -1]["observation"] == 2).all()


@pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)])
def test_env_base_reset_flag(batch_size, max_steps=3):
8 changes: 7 additions & 1 deletion torchrl/envs/utils.py
@@ -223,7 +223,10 @@ def step_mdp(
return out


def _set_single_key(source, dest, key):
def _set_single_key(source, dest, key, clone=False):
# key should be already unraveled
if isinstance(key, str):
key = (key,)
for k in key:
val = source.get(k)
if is_tensor_collection(val):
@@ -234,6 +237,8 @@ def _set_single_key(source, dest, key):
source = val
dest = new_val
else:
if clone:
val = val.clone()
dest._set(k, val)


@@ -482,6 +487,7 @@ def __get__(self, owner_self, owner_cls):

def _sort_keys(element):
if isinstance(element, tuple):
element = unravel_keys(element)
return "_-|-_".join(element)
return element

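
For reference, a minimal usage sketch of the updated helper, with hypothetical keys and tensors and assuming `_set_single_key` behaves as the hunks above suggest: with `clone=True` the value written into the destination is a copy rather than a view of the source entry.

import torch
from tensordict import TensorDict
from torchrl.envs.utils import _set_single_key

source = TensorDict({"next": {"reward": torch.zeros(4, 1)}}, batch_size=[4])
dest = TensorDict({}, batch_size=[4])

# Copy a single nested entry; clone=True decouples it from `source`'s storage.
# (Illustrative sketch of a private helper, not a documented public API.)
_set_single_key(source, dest, ("next", "reward"), clone=True)
assert (dest["next", "reward"] == source["next", "reward"]).all()
assert dest["next", "reward"].data_ptr() != source["next", "reward"].data_ptr()
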
51 changes: 26 additions & 25 deletions torchrl/envs/vec_env.py
@@ -20,6 +20,7 @@
import torch
from tensordict import TensorDict
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_keys
from torch import multiprocessing as mp

from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -33,8 +34,7 @@
from torchrl.envs.common import _EnvWrapper, EnvBase
from torchrl.envs.env_creator import get_env_metadata

from torchrl.envs.utils import _sort_keys

from torchrl.envs.utils import _set_single_key, _sort_keys

_has_envpool = importlib.util.find_spec("envpool")

@@ -324,46 +324,47 @@ def _create_td(self) -> None:

if self._single_task:
self.env_input_keys = sorted(
list(self.input_spec["_action_spec"].keys(True))
+ list(self.state_spec.keys(True)),
list(self.input_spec["_action_spec"].keys(True, True))
+ list(self.state_spec.keys(True, True)),
key=_sort_keys,
)
self.env_output_keys = []
self.env_obs_keys = []
for key in self.output_spec["_observation_spec"].keys(True):
if isinstance(key, str):
key = (key,)
self.env_output_keys.append(("next", *key))
for key in self.output_spec["_observation_spec"].keys(True, True):
self.env_output_keys.append(unravel_keys(("next", key)))
self.env_obs_keys.append(key)
self.env_output_keys.append(("next", "reward"))
self.env_output_keys.append(("next", "done"))
self.env_output_keys.append(unravel_keys(("next", self.reward_key)))
self.env_output_keys.append(unravel_keys(("next", self.done_key)))
else:
env_input_keys = set()
for meta_data in self.meta_data:
if meta_data.specs["input_spec", "_state_spec"] is not None:
env_input_keys = env_input_keys.union(
meta_data.specs["input_spec", "_state_spec"].keys(True)
meta_data.specs["input_spec", "_state_spec"].keys(True, True)
)
env_input_keys = env_input_keys.union(
meta_data.specs["input_spec", "_action_spec"].keys(True)
meta_data.specs["input_spec", "_action_spec"].keys(True, True)
)
env_output_keys = set()
env_obs_keys = set()
for meta_data in self.meta_data:
env_obs_keys = env_obs_keys.union(
key
for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
True
True, True
)
)
env_output_keys = env_output_keys.union(
("next", key) if isinstance(key, str) else ("next", *key)
unravel_keys(("next", key))
for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
True
True, True
)
)
env_output_keys = env_output_keys.union(
{("next", "reward"), ("next", "done")}
{
unravel_keys(("next", self.reward_key)),
unravel_keys(("next", self.done_key)),
}
)
self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
self.env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -374,10 +375,10 @@ def _create_td(self) -> None:
.union(self.env_input_keys)
.union(self.env_obs_keys)
)
self._selected_keys.add("done")
self._selected_keys.add(self.done_key)
self._selected_keys.add("_reset")

self._selected_reset_keys = self.env_obs_keys + ["done"] + ["_reset"]
self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"]
self._selected_step_keys = self.env_output_keys

if self._single_task:
@@ -550,7 +551,7 @@ def _step(
if self._single_task:
out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
for key in self._selected_step_keys:
out._set(key, self.shared_tensordict_parent.get(key).clone())
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
else:
# strict=False ensures that non-homogeneous keys are still there
out = self.shared_tensordict_parent.select(
@@ -619,7 +620,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
for key in self._selected_reset_keys:
if key != "_reset":
out._set(key, self.shared_tensordict_parent.get(key).clone())
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
return out
else:
return self.shared_tensordict_parent.select(
@@ -790,7 +791,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
if self._single_task:
out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
for key in self._selected_step_keys:
out._set(key, self.shared_tensordict_parent.get(key).clone())
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
else:
# strict=False ensures that non-homogeneous keys are still there
out = self.shared_tensordict_parent.select(
@@ -853,7 +854,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
for key in self._selected_reset_keys:
if key != "_reset":
out._set(key, self.shared_tensordict_parent.get(key).clone())
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
return out
else:
return self.shared_tensordict_parent.select(
@@ -1187,7 +1188,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:

@torch.no_grad()
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
action = tensordict.get("action")
action = tensordict.get(self.action_key)
# Action needs to be moved to CPU and converted to numpy before being passed to envpool
action = action.to(torch.device("cpu"))
step_output = self._env.step(action.numpy())
@@ -1285,7 +1286,7 @@ def _transform_reset_output(
)

obs = self.obs.clone(False)
obs.update({"done": self.done_spec.zero()})
obs.update({self.done_key: self.done_spec.zero()})
return obs

def _transform_step_output(
Expand All @@ -1295,7 +1296,7 @@ def _transform_step_output(
obs, reward, done, *_ = envpool_output

obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
obs.update({"reward": torch.tensor(reward), "done": done})
obs.update({self.reward_key: torch.tensor(reward), self.done_key: done})
self.obs = tensordict_out = TensorDict(
obs,
batch_size=self.batch_size,
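
Throughout vec_env.py, hard-coded "action"/"done"/"reward" strings are replaced by the env's action_key/done_key/reward_key, and unravel_keys is used to flatten the ("next", key) prefix when the key is itself a nested tuple. A small, hypothetical sketch of that normalization, assuming unravel_keys flattens nested key tuples as it is used above:

from tensordict.utils import unravel_keys

# A custom, nested done key, as an env with non-standard keys might declare.
done_key = ("data", "done")

# ("next", done_key) is a tuple containing a tuple; unravel_keys flattens it
# into a plain tuple that can index a TensorDict.
next_done_key = unravel_keys(("next", done_key))
print(next_done_key)  # expected: ("next", "data", "done")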
