Skip to content

Commit

Permalink
[BugFix] Make sure ParallelEnv does not overflow mem when policy requ…
Browse files Browse the repository at this point in the history
…ires grad (pytorch#1909)
  • Loading branch information
vmoens authored Feb 15, 2024
1 parent bd7e268 commit 0314e05
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
91 changes: 90 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import gc
import os.path
import re
Expand Down Expand Up @@ -65,7 +66,14 @@
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv
from torchrl.envs import (
CatTensors,
DoubleToFloat,
EnvBase,
EnvCreator,
ParallelEnv,
SerialEnv,
)
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
Expand Down Expand Up @@ -2473,6 +2481,87 @@ def test_auto_cast_to_device(break_when_any_done):
assert_allclose_td(rollout0, rollout1)


@pytest.mark.parametrize("device", get_default_devices())
def test_backprop(device):
# Tests that backprop through a series of single envs and through a serial env are identical
# Also tests that no backprop can be achieved with parallel env.
class DifferentiableEnv(EnvBase):
def __init__(self, device):
super().__init__(device=device)
self.observation_spec = CompositeSpec(
observation=UnboundedContinuousTensorSpec(3, device=device),
device=device,
)
self.action_spec = CompositeSpec(
action=UnboundedContinuousTensorSpec(3, device=device), device=device
)
self.reward_spec = CompositeSpec(
reward=UnboundedContinuousTensorSpec(1, device=device), device=device
)
self.seed = 0

def _set_seed(self, seed):
self.seed = seed
return seed

def _reset(self, tensordict):
td = self.observation_spec.zero().update(self.done_spec.zero())
td["observation"] = (
td["observation"].clone() + self.seed % 10
).requires_grad_()
return td

def _step(self, tensordict):
action = tensordict.get("action")
obs = (tensordict.get("observation") + action) / action.norm()
return TensorDict(
{
"reward": action.sum().unsqueeze(0),
**self.full_done_spec.zero(),
"observation": obs,
}
)

torch.manual_seed(0)
policy = Actor(torch.nn.Linear(3, 3, device=device))
env0 = DifferentiableEnv(device=device)
seed = env0.set_seed(0)
env1 = DifferentiableEnv(device=device)
env1.set_seed(seed)
r0 = env0.rollout(10, policy)
r1 = env1.rollout(10, policy)
r = torch.stack([r0, r1])
g = torch.autograd.grad(r["next", "reward"].sum(), policy.parameters())

def make_env(seed, device=device):
env = DifferentiableEnv(device=device)
env.set_seed(seed)
return env

serial_env = SerialEnv(
2,
[functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
device=device,
)
r_serial = serial_env.rollout(10, policy)

g_serial = torch.autograd.grad(
r_serial["next", "reward"].sum(), policy.parameters()
)
torch.testing.assert_close(g, g_serial)

p_env = ParallelEnv(
2,
[functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
device=device,
)
try:
r_parallel = p_env.rollout(10, policy)
assert not r_parallel.exclude("action").requires_grad
finally:
p_env.close()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 9 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,12 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
>>> # If no cuda device is available
>>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
.. warning::
ParallelEnv disable gradients in all operations (:meth:`~.step`,
:meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients
cannot be passed through :class:`multiprocessing.Pipe` objects.
Only :class:`~torchrl.envs.SerialEnv` will support backpropagation.
"""

def _start_workers(self) -> None:
Expand Down Expand Up @@ -1143,6 +1149,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
event.wait()
event.clear()

@torch.no_grad()
@_check_start
def step_and_maybe_reset(
self, tensordict: TensorDictBase
Expand Down Expand Up @@ -1205,6 +1212,7 @@ def step_and_maybe_reset(
tensordict.set("next", next_td)
return tensordict, tensordict_

@torch.no_grad()
@_check_start
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We must use the in_keys and nothing else for the following reasons:
Expand Down Expand Up @@ -1261,6 +1269,7 @@ def select_and_clone(name, tensor):
out = out.to(device, non_blocking=self.non_blocking)
return out

@torch.no_grad()
@_check_start
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict is not None:
Expand Down

0 comments on commit 0314e05

Please sign in to comment.