Skip to content

Commit

Permalink
[Refactor] Clearer separation between single_task and share_individua…
Browse files Browse the repository at this point in the history
…l_td (pytorch#2026)
  • Loading branch information
vmoens authored Mar 20, 2024
1 parent c77c711 commit effd868
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 48 deletions.
2 changes: 1 addition & 1 deletion test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,7 @@ def test_reset_heterogeneous_envs(
cls = ParallelEnv
else:
cls = SerialEnv
env = cls(2, [env1, env2], device=env_device)
env = cls(2, [env1, env2], device=env_device, share_individual_td=True)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
Expand Down
120 changes: 105 additions & 15 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
CatFrames,
CatTensors,
DoubleToFloat,
EnvBase,
EnvCreator,
ParallelEnv,
SerialEnv,
)
from torchrl.envs.batched_envs import _stackable
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
Expand Down Expand Up @@ -498,19 +500,6 @@ def env_make():
lambda task=task: DMControlEnv("humanoid", task) for task in tasks
]

if not share_individual_td and not single_task:
with pytest.raises(
ValueError, match="share_individual_td must be set to None"
):
SerialEnv(3, env_make, share_individual_td=share_individual_td)
with pytest.raises(
ValueError, match="share_individual_td must be set to None"
):
maybe_fork_ParallelEnv(
3, env_make, share_individual_td=share_individual_td
)
return

env_serial = SerialEnv(3, env_make, share_individual_td=share_individual_td)
env_serial.start()
assert env_serial._single_task is single_task
Expand Down Expand Up @@ -2617,7 +2606,8 @@ def test_auto_cast_to_device(break_when_any_done):


@pytest.mark.parametrize("device", get_default_devices())
def test_backprop(device, maybe_fork_ParallelEnv):
@pytest.mark.parametrize("share_individual_td", [True, False])
def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td):
# Tests that backprop through a series of single envs and through a serial env are identical
# Also tests that no backprop can be achieved with parallel env.
class DifferentiableEnv(EnvBase):
Expand Down Expand Up @@ -2677,8 +2667,14 @@ def make_env(seed, device=device):
2,
[functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
device=device,
share_individual_td=share_individual_td,
)
r_serial = serial_env.rollout(10, policy)
if share_individual_td:
r_serial = serial_env.rollout(10, policy)
else:
with pytest.raises(RuntimeError, match="Cannot update a view of a tensordict"):
r_serial = serial_env.rollout(10, policy)
return

g_serial = torch.autograd.grad(
r_serial["next", "reward"].sum(), policy.parameters()
Expand Down Expand Up @@ -2735,6 +2731,100 @@ def test_parallel_another_ctx():
pass


@pytest.mark.skipif(not _has_gym, reason="gym not found")
def test_single_task_share_individual_td():
    """Checks how SerialEnv resolves ``share_individual_td`` and ``_single_task``.

    The parent tensordict is expected to be a dense ``TensorDict`` when the
    sub-envs stack densely, and a ``LazyStackedTensorDict`` when each worker
    keeps its own tensordict.
    """
    cartpole = CARTPOLE_VERSIONED()

    def check(env, *, shared, single, parent_cls):
        # Assert the resolved flags, run a short rollout, then check the
        # container type holding the workers' data.
        assert bool(env.share_individual_td) is shared
        assert bool(env._single_task) is single
        env.rollout(2)
        assert isinstance(env.shared_tensordict_parent, parent_cls)

    # Single callable: single task, dense stack by default.
    check(
        SerialEnv(2, lambda: GymEnv(cartpole)),
        shared=False,
        single=True,
        parent_cls=TensorDict,
    )
    check(
        SerialEnv(2, lambda: GymEnv(cartpole), share_individual_td=True),
        shared=True,
        single=True,
        parent_cls=LazyStackedTensorDict,
    )

    # List of identical callables: still recognized as a single task.
    check(
        SerialEnv(2, [lambda: GymEnv(cartpole)] * 2),
        shared=False,
        single=True,
        parent_cls=TensorDict,
    )
    check(
        SerialEnv(2, [lambda: GymEnv(cartpole)] * 2, share_individual_td=True),
        shared=True,
        single=True,
        parent_cls=LazyStackedTensorDict,
    )

    # Distinct EnvCreator instances: not a single task, but still stackable.
    check(
        SerialEnv(2, [EnvCreator(lambda: GymEnv(cartpole)) for _ in range(2)]),
        shared=False,
        single=False,
        parent_cls=TensorDict,
    )
    check(
        SerialEnv(
            2,
            [EnvCreator(lambda: GymEnv(cartpole)) for _ in range(2)],
            share_individual_td=True,
        ),
        shared=True,
        single=False,
        parent_cls=LazyStackedTensorDict,
    )

    def make_heterogeneous():
        # CatFrames changes the observation shape in the second env, so the
        # two sub-envs' outputs cannot be stacked densely.
        return [
            EnvCreator(lambda: GymEnv(cartpole)),
            EnvCreator(
                lambda: TransformedEnv(
                    GymEnv(cartpole),
                    CatFrames(N=4, dim=-1, in_keys=["observation"]),
                )
            ),
        ]

    # Non-stackable results force share_individual_td=True automatically.
    check(
        SerialEnv(2, make_heterogeneous()),
        shared=True,
        single=False,
        parent_cls=LazyStackedTensorDict,
    )

    # Forcing dense stacking on non-stackable envs must fail loudly.
    with pytest.raises(ValueError, match="share_individual_td=False"):
        SerialEnv(2, make_heterogeneous(), share_individual_td=False)


def test_stackable():
    """Exercises the ``_stackable`` util on representative tensordict pairs."""
    cases = [
        # exclusive keys are not stackable
        (({"a": 0}, {"b": 1}), False),
        # mismatched shapes for the same key
        (({"a": [0]}, {"a": 1}), False),
        (({"a": [0]}, {"a": [1]}), True),
        # an extra empty sub-tensordict does not prevent stacking
        (({"a": [0]}, {"a": [1], "b": {}}), True),
        (({"a": {"b": [0]}}, {"a": {"b": [1]}}), True),
        (({"a": {"b": [0]}}, {"a": {"b": 1}}), False),
        # non-tensor (string) entries stack fine
        (({"a": "a string"}, {"a": "another string"}), True),
    ]
    for (left, right), expected in cases:
        assert _stackable(TensorDict(left), TensorDict(right)) is expected


if __name__ == "__main__":
    # Allow running this test module directly: unknown CLI args are forwarded
    # to pytest so e.g. `-k pattern` works from the command line.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 8 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,13 @@ def __init__(
            # if we did not receive an env device, we use the device of the env
self.env_device = self.env.device

# If the storing device is not the same as the policy device, we have
# no guarantee that the "next" entry from the policy will be on the
# same device as the collector metadata.
self._cast_to_env_device = self._cast_to_policy_device or (
self.env.device != self.storing_device
)

self.max_frames_per_traj = (
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
)
Expand Down Expand Up @@ -923,7 +930,7 @@ def rollout(self) -> TensorDictBase:
policy_output, keys_to_update=self._policy_output_keys
)

if self._cast_to_policy_device:
if self._cast_to_env_device:
if self.env_device is not None:
env_input = self._shuttle.to(self.env_device, non_blocking=True)
elif self.env_device is None:
Expand Down
86 changes: 56 additions & 30 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,11 @@ def __init__(
self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
if callable(create_env_fn):
create_env_fn = [create_env_fn for _ in range(num_workers)]
else:
if len(create_env_fn) != num_workers:
raise RuntimeError(
f"num_workers and len(create_env_fn) mismatch, "
f"got {len(create_env_fn)} and {num_workers}"
)
if (
share_individual_td is False and not self._single_task
): # then it has been explicitly set by the user
raise ValueError(
"share_individual_td must be set to None or True when using multi-task batched environments"
)
share_individual_td = True
elif len(create_env_fn) != num_workers:
raise RuntimeError(
f"num_workers and len(create_env_fn) mismatch, "
f"got {len(create_env_fn)} and {num_workers}"
)
create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
if isinstance(create_env_kwargs, dict):
create_env_kwargs = [
Expand All @@ -322,7 +314,8 @@ def __init__(
if pin_memory:
raise ValueError("pin_memory for batched envs is deprecated")

self.share_individual_td = bool(share_individual_td)
# if share_individual_td is None, we will assess later if the output can be stacked
self.share_individual_td = share_individual_td
self._share_memory = shared_memory
self._memmap = memmap
self.allow_step_when_done = allow_step_when_done
Expand Down Expand Up @@ -365,13 +358,25 @@ def _get_metadata(
self.meta_data = meta_data.expand(
*(self.num_workers, *meta_data.batch_size)
)
if self.share_individual_td is None:
self.share_individual_td = False
else:
n_tasks = len(create_env_fn)
self.meta_data = []
for i in range(n_tasks):
self.meta_data.append(
get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
)
if self.share_individual_td is not True:
share_individual_td = not _stackable(
*[meta_data.tensordict for meta_data in self.meta_data]
)
if share_individual_td and self.share_individual_td is False:
raise ValueError(
"share_individual_td=False was provided but share_individual_td must "
"be True to accomodate non-stackable tensors."
)
self.share_individual_td = share_individual_td
self._set_properties()

def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
Expand Down Expand Up @@ -484,9 +489,14 @@ def map_device(key, value, device_map=device_map):
self.done_spec = output_spec["full_done_spec"]

self._dummy_env_str = str(meta_data[0])
self._env_tensordict = LazyStackedTensorDict.lazy_stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
if self.share_individual_td:
self._env_tensordict = LazyStackedTensorDict.lazy_stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
else:
self._env_tensordict = torch.stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
self._batch_locked = meta_data[0].batch_locked
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)

Expand All @@ -503,14 +513,11 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:

def _create_td(self) -> None:
"""Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations."""
if self._single_task:
shared_tensordict_parent = self._env_tensordict.clone()
if not self._env_tensordict.shape[0] == self.num_workers:
raise RuntimeError(
"batched environment base tensordict has the wrong shape"
)
else:
shared_tensordict_parent = self._env_tensordict.clone()
shared_tensordict_parent = self._env_tensordict.clone()
if self._env_tensordict.shape[0] != self.num_workers:
raise RuntimeError(
"batched environment base tensordict has the wrong shape"
)

if self._single_task:
self._env_input_keys = sorted(
Expand All @@ -525,6 +532,7 @@ def _create_td(self) -> None:
self._env_obs_keys.append(key)
self._env_output_keys += self.reward_keys + self.done_keys
else:
# this is only possible if _single_task=False
env_input_keys = set()
for meta_data in self.meta_data:
if meta_data.specs["input_spec", "full_state_spec"] is not None:
Expand Down Expand Up @@ -577,7 +585,7 @@ def _create_td(self) -> None:
# output keys after step
self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}

if self._single_task:
if not self.share_individual_td:
shared_tensordict_parent = shared_tensordict_parent.select(
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
Expand Down Expand Up @@ -807,10 +815,19 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_ = None

_td = _env.reset(tensordict=tensordict_, **kwargs)
self.shared_tensordicts[i].update_(
_td,
keys_to_update=list(self._selected_reset_keys_filt),
)
try:
self.shared_tensordicts[i].update_(
_td,
keys_to_update=list(self._selected_reset_keys_filt),
)
except RuntimeError as err:
if "no_grad mode" in str(err):
raise RuntimeError(
"Cannot update a view of a tensordict when gradients are required. "
"To collect gradient across sub-environments, please set the "
"share_individual_td argument to True."
)
raise
selected_output_keys = self._selected_reset_keys_filt
device = self.device

Expand Down Expand Up @@ -1703,5 +1720,14 @@ def _filter_empty(tensordict):
return tensordict.select(*tensordict.keys(True, True))


def _stackable(*tensordicts):
    """Return ``True`` if the tensordicts can be stacked into a dense tensordict.

    The check is done by probing: a lazy stack is built and materialized via
    ``contiguous()``; any ``RuntimeError`` raised along the way (or the
    presence of exclusive keys) means the inputs are not densely stackable.
    """
    try:
        stacked = LazyStackedTensorDict(*tensordicts, stack_dim=0)
        stacked.contiguous()
        return not stacked._has_exclusive_keys
    except RuntimeError:
        return False


# Create an alias for possible imports
# (backward compatibility for code importing the older private name).
_BatchedEnv = BatchedEnvBase
1 change: 0 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,7 +2080,6 @@ def reset(
raise RuntimeError(
f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
)

return self._reset_proc_data(tensordict, tensordict_reset)

def _reset_proc_data(self, tensordict, tensordict_reset):
Expand Down

0 comments on commit effd868

Please sign in to comment.