Skip to content

Commit

Permalink
[BugFix] Fix _reset data passing in parallel env (pytorch#1880)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 6, 2024
1 parent 62d977b commit e53eb73
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
23 changes: 23 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,29 @@ def test_parallel_env_with_policy(
# env_serial.close()
env0.close()

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("heterogeneous", [False, True])
def test_transform_env_transform_no_device(self, heterogeneous):
# Tests non-regression on 1865
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(), StepCounter(max_steps=3)
)

if heterogeneous:
make_envs = [EnvCreator(make_env), EnvCreator(make_env)]
else:
make_envs = make_env
penv = ParallelEnv(2, make_envs)
r = penv.rollout(6, break_when_any_done=False)
assert r.shape == (2, 6)
try:
env = TransformedEnv(penv)
r = env.rollout(6, break_when_any_done=False)
assert r.shape == (2, 6)
finally:
penv.close()

@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize(
"env_name",
Expand Down
25 changes: 23 additions & 2 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,22 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_, keys_to_update=list(self._selected_reset_keys)
)
continue
out = ("reset", tensordict_)
if tensordict_ is not None:
tdkeys = list(tensordict_.keys(True, True))

# This way we can avoid calling select over all the keys in the shared tensordict
def tentative_update(val, other):
if other is not None:
val.copy_(other)
return val

self.shared_tensordicts[i].apply_(
tentative_update, tensordict_, default=None
)
out = ("reset", tdkeys)
else:
out = ("reset", False)

channel.send(out)
workers.append(i)

Expand Down Expand Up @@ -1509,7 +1524,13 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
torchrl_logger.info(f"resetting worker {pid}")
if not initialized:
raise RuntimeError("call 'init' before resetting")
cur_td = env.reset(tensordict=data)
# we use 'data' to pass the keys that we need to pass to reset,
# because passing the entire buffer may have unwanted consequences
cur_td = env.reset(
tensordict=root_shared_tensordict.select(*data, strict=False)
if data
else None
)
shared_tensordict.update_(
cur_td,
keys_to_update=list(_selected_reset_keys),
Expand Down

0 comments on commit e53eb73

Please sign in to comment.