
[Feature] Allow usage of a different device on main and sub-envs in ParallelEnv and SerialEnv #1626

Merged
merged 71 commits on Nov 30, 2023
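
As a quick illustration of the feature named in the title, here is a hypothetical usage sketch (not taken from the PR) of a ParallelEnv whose own device differs from the one its sub-envs run on. It assumes GymEnv and ParallelEnv accept a `device` argument as they do in torchrl, that a CUDA device is available, and that gymnasium's "Pendulum-v1" is installed.

import torch
from torchrl.envs import GymEnv, ParallelEnv

def make_env():
    # each worker env is created on CPU
    return GymEnv("Pendulum-v1", device="cpu")

if torch.cuda.is_available():
    # the batched env itself sits on cuda:0; inputs and outputs are cast
    # to that device, independently of the workers' device
    env = ParallelEnv(2, make_env, device="cuda:0")
    rollout = env.rollout(3)
    print(rollout.device)  # expected: cuda:0
    env.close()
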
Changes from 1 commit
Commits
71 commits
1f539dd  init (vmoens, Oct 6, 2023)
d4c16e1  amend (vmoens, Oct 6, 2023)
d2321aa  amend (vmoens, Oct 6, 2023)
565115a  amend (vmoens, Oct 6, 2023)
3c46136  amend (vmoens, Oct 6, 2023)
78cfa41  amend (vmoens, Oct 6, 2023)
a6bd8eb  amend (vmoens, Oct 6, 2023)
04d4ae7  amend (vmoens, Oct 6, 2023)
3e31963  Merge remote-tracking branch 'origin/main' into step_maybe_reset (vmoens, Oct 10, 2023)
f1b0ea4  tensordict_ (vmoens, Oct 10, 2023)
16b3538  amend rollout logic (vmoens, Oct 10, 2023)
7ad0864  amend (vmoens, Oct 10, 2023)
bcac398  amend (vmoens, Oct 10, 2023)
428f8ee  inference (vmoens, Oct 10, 2023)
6fbd0bd  cpu -> cuda (vmoens, Oct 10, 2023)
02db623  checks (vmoens, Oct 10, 2023)
08a8f47  using pipe instead of event (vmoens, Oct 10, 2023)
45e64f7  amend (vmoens, Oct 10, 2023)
7dd4821  amend (vmoens, Oct 10, 2023)
e1a2206  rm cuda event (vmoens, Oct 10, 2023)
dc2caab  amend (vmoens, Oct 10, 2023)
01ffbf9  amend (vmoens, Oct 10, 2023)
ac76ec3  amend (vmoens, Oct 10, 2023)
ceab010  amend (vmoens, Oct 10, 2023)
f0327c9  amend (vmoens, Oct 10, 2023)
518b3d1  amend (vmoens, Oct 10, 2023)
47dd93b  amend (vmoens, Oct 10, 2023)
354fb6f  amend (vmoens, Oct 10, 2023)
53d5f9a  amend (vmoens, Oct 10, 2023)
78c00e8  amend (vmoens, Oct 10, 2023)
f63480e  amend (vmoens, Oct 10, 2023)
9a3631f  amend (vmoens, Oct 10, 2023)
2ceb438  amend (vmoens, Oct 10, 2023)
6ecebda  amend (vmoens, Oct 10, 2023)
5c613c3  amend (vmoens, Oct 10, 2023)
9f97e58  amend (vmoens, Oct 10, 2023)
9cbcbb0  amend (vmoens, Oct 10, 2023)
72c4163  amend (vmoens, Oct 10, 2023)
6f4c374  amend (vmoens, Oct 10, 2023)
3657b41  empty (vmoens, Oct 11, 2023)
9206b93  amend (vmoens, Oct 11, 2023)
897123f  Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor (vmoens, Nov 10, 2023)
2bac78c  amend (vmoens, Nov 10, 2023)
51a8856  amend (vmoens, Nov 10, 2023)
7a300f5  amend (vmoens, Nov 10, 2023)
b25921d  amend (vmoens, Nov 10, 2023)
d42348d  amend (vmoens, Nov 10, 2023)
f3421aa  amend (vmoens, Nov 10, 2023)
939ece4  amend (vmoens, Nov 10, 2023)
082ba9a  amend (vmoens, Nov 10, 2023)
4fd670f  amend (vmoens, Nov 10, 2023)
2a773f3  amend (vmoens, Nov 10, 2023)
e7cb5dd  Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor (vmoens, Nov 27, 2023)
ea88bb4  Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor (vmoens, Nov 27, 2023)
7c04f62  amend (vmoens, Nov 27, 2023)
b422775  amend (vmoens, Nov 27, 2023)
ed0287d  amend (vmoens, Nov 27, 2023)
05b0c08  amend (vmoens, Nov 27, 2023)
33de0fc  amend (vmoens, Nov 28, 2023)
492a884  amend (vmoens, Nov 28, 2023)
ff4799d  amend (vmoens, Nov 28, 2023)
f25b957  amend (vmoens, Nov 28, 2023)
8899dbd  amend (vmoens, Nov 28, 2023)
65c9deb  amend (vmoens, Nov 28, 2023)
77c2d6b  amend (vmoens, Nov 28, 2023)
7928744  amend (vmoens, Nov 28, 2023)
e7fda36  amend (vmoens, Nov 29, 2023)
fb9a03a  amend (vmoens, Nov 29, 2023)
d73ca22  amend (vmoens, Nov 29, 2023)
c4d4c6b  amend (vmoens, Nov 29, 2023)
b2840b0  doc (vmoens, Nov 29, 2023)
Viewing changes from commit 1f539dd4f7a10c7c8ee67d5dc1d6eb60ad6c49f0 ("init")
vmoens committed Oct 6, 2023
9 changes: 1 addition & 8 deletions torchrl/collectors/collectors.py
@@ -794,16 +794,9 @@ def _step_and_maybe_reset(self) -> None:
traj_ids = traj_ids.clone()
# collectors do not support passing other tensors than `"_reset"`
# to `reset()`.
traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys)
td_reset = self.env.reset(td_reset)

if td_reset.batch_dims:
# better cloning here than when passing the td for stacking
# cloning is necessary to avoid modifying entries in-place
self._tensordict = torch.where(traj_sop, td_reset, self._tensordict)
else:
self._tensordict.update(td_reset)

traj_sop = _aggregate_resets(td_reset, reset_keys=self.env.reset_keys)
traj_ids[traj_sop] = traj_ids.max() + torch.arange(
1, traj_sop.sum() + 1, device=traj_ids.device
)
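
Aside (not part of the diff): the trajectory-id renumbering kept above gives every freshly reset worker a new id above the current maximum. A standalone sketch of the same arithmetic:

import torch

traj_ids = torch.tensor([0, 1, 2, 3])
traj_sop = torch.tensor([False, True, False, True])  # workers 1 and 3 just reset
# same renumbering as in _step_and_maybe_reset above
traj_ids[traj_sop] = traj_ids.max() + torch.arange(1, traj_sop.sum() + 1)
print(traj_ids)  # tensor([0, 4, 2, 5])
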
39 changes: 31 additions & 8 deletions torchrl/envs/common.py
@@ -7,7 +7,7 @@

import abc
from copy import deepcopy
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple

import numpy as np
import torch
@@ -26,6 +26,7 @@
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_replace_last,
_update_during_reset,
get_available_libraries,
step_mdp,
terminated_or_truncated,
@@ -1535,13 +1536,11 @@ def reset(
raise RuntimeError(
f"Env done entry '{done_key}' was (partially) True after a call to reset(). This is not allowed."
)

if tensordict is not None:
tensordict.update(tensordict_reset)
else:
tensordict = tensordict_reset
tensordict.exclude(*self.reset_keys, inplace=True)
return tensordict
return (
_update_during_reset(tensordict_reset, tensordict, self.reset_keys)
if tensordict is not None
else tensordict_reset
)

def numel(self) -> int:
return prod(self.batch_size)
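
With this change, reset() no longer unconditionally copies the reset output over the input: when the input tensordict carries a partial "_reset" mask, _update_during_reset (added in torchrl/envs/utils.py below) merges only the sub-envs that actually reset and keeps the other entries from the input. A hedged call-site sketch (not from the PR), assuming gymnasium's "Pendulum-v1" is installed:

import torch
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
td = env.reset()                                 # full reset of both sub-envs
td["_reset"] = torch.tensor([[True], [False]])   # ask to reset sub-env 0 only
td_out = env.reset(td)
# td_out holds a fresh observation for sub-env 0, while sub-env 1 keeps the
# entries it already had in `td`; the "_reset" flag itself is consumed.
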
@@ -1836,6 +1835,30 @@ def policy(td):
out_td.refine_names(..., "time")
return out_td

def step_and_maybe_reset(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
tensordict = self.step(tensordict)
tensordict_ = step_mdp(
tensordict,
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=self.reward_keys,
action_keys=self.action_keys,
done_keys=self.done_keys,
)
# done and truncated are in done_keys
# We read if any key is done.
any_done = terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_reset",
)
if any_done:
tensordict_ = self.reset(tensordict_)
return tensordict, tensordict_

@property
def reset_keys(self) -> List[NestedKey]:
"""Returns a list of reset keys.
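
Usage-wise, step_and_maybe_reset returns both the post-step tensordict (with its "next" entries, ready for storage) and the tensordict to feed into the next step, already reset wherever a done flag was raised. A hedged sketch of the intended collection loop (not from the PR), again assuming "Pendulum-v1" is installed:

import torch
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
td = env.reset()
steps = []
for _ in range(10):
    td["action"] = env.action_spec.rand()
    td, td_next = env.step_and_maybe_reset(td)
    steps.append(td)   # full transition, including the "next" entries
    td = td_next       # root for the next iteration, reset if an episode ended
data = torch.stack(steps)  # a [10]-shaped tensordict of transitions
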
36 changes: 36 additions & 0 deletions torchrl/envs/utils.py
@@ -899,3 +899,39 @@ def skim_through(td, reset=reset):

reset = skim_through(data)
return reset


def _update_during_reset(
tensordict_reset: TensorDictBase,
tensordict: TensorDictBase,
reset_keys: List[NestedKey],
):
for reset_key in reset_keys:
# get the node of the reset key
if isinstance(reset_key, tuple):
# the reset key *must* have gone through unravel_key
# we don't test it to avoid induced overhead
node_key = reset_key[:-1]
node_reset = tensordict_reset.get(node_key)
node = tensordict.get(node_key)
else:
node_reset = tensordict_reset
node = tensordict
# get the reset signal
reset = tensordict.pop(reset_key, None)
if reset is None or reset.all():
# perform simple update, at a single level.
# by contract, a reset signal at one level cannot
# be followed by other resets at nested levels, so it's safe to
# simply update
node.update(node_reset)
else:
# there can be two cases: (1) the key is present in both tds,
# in which case we use the reset mask to update
# (2) the key is not present in the input tensordict, in which
# case we just return the data

# empty tensordicts won't be returned
reset = reset.reshape(node)
node.where(reset, node_reset, out=node, pad=0)
return tensordict
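
A standalone sketch (not from the diff) of what _update_during_reset does to plain tensordicts: entries are taken from the reset output where the popped "_reset" mask is True and kept from the running tensordict elsewhere. It is written against the helper as it ships in torchrl; the exact masking line was still being amended in later commits of this PR.

import torch
from tensordict import TensorDict
from torchrl.envs.utils import _update_during_reset

tensordict = TensorDict(
    {"obs": torch.tensor([1.0, 2.0]), "_reset": torch.tensor([True, False])},
    batch_size=[2],
)
tensordict_reset = TensorDict({"obs": torch.tensor([10.0, 20.0])}, batch_size=[2])
out = _update_during_reset(tensordict_reset, tensordict, reset_keys=["_reset"])
print(out["obs"])                  # expected: tensor([10., 2.]); only entry 0 was reset
assert "_reset" not in out.keys()  # the reset flag has been consumed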