[Feature] Fine control over devices in collectors #1835

Merged
merged 69 commits into from
Jan 30, 2024
Commits
35d8847
init
vmoens Jan 25, 2024
0f9fad2
amend
vmoens Jan 25, 2024
213c477
amend
vmoens Jan 25, 2024
a1fd23f
amend
vmoens Jan 25, 2024
17aef47
clear_device_
vmoens Jan 25, 2024
601d6d6
stream
vmoens Jan 25, 2024
439022b
amend
vmoens Jan 25, 2024
d26a8d2
amend
vmoens Jan 25, 2024
0e20e03
amend
vmoens Jan 26, 2024
c56fda4
amend
vmoens Jan 26, 2024
bd9ee28
amend
vmoens Jan 26, 2024
61b1f87
amend
vmoens Jan 26, 2024
03a3747
amend
vmoens Jan 26, 2024
515de7d
amend
vmoens Jan 26, 2024
b994996
amend
vmoens Jan 26, 2024
dc264c6
amend
vmoens Jan 26, 2024
12a8ae1
amend
vmoens Jan 28, 2024
2150f0b
amend
vmoens Jan 28, 2024
0759eb1
lint
vmoens Jan 28, 2024
5d3d9c4
amend
vmoens Jan 28, 2024
c1870f9
amend
vmoens Jan 29, 2024
b033426
amend
vmoens Jan 29, 2024
0dcbc0e
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
7759b3f
amend
vmoens Jan 29, 2024
bd06a5e
amend
vmoens Jan 29, 2024
f0406eb
amend
vmoens Jan 29, 2024
b63aca3
amend
vmoens Jan 29, 2024
0f08d26
amend
vmoens Jan 29, 2024
2e4445c
amend
vmoens Jan 29, 2024
14ab95c
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
a4d1df6
amend
vmoens Jan 29, 2024
5976de0
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
b215157
amend
vmoens Jan 29, 2024
5560b70
amend
vmoens Jan 30, 2024
b98a806
amend
vmoens Jan 30, 2024
1f1dd92
amend
vmoens Jan 30, 2024
fc228d9
amend
vmoens Jan 30, 2024
1ef6ce1
amend
vmoens Jan 30, 2024
6258a92
amend
vmoens Jan 30, 2024
1e2fc62
amend
vmoens Jan 30, 2024
4feac80
amend
vmoens Jan 30, 2024
b12e564
amend
vmoens Jan 30, 2024
90d2343
amend
vmoens Jan 30, 2024
78d7c43
Merge remote-tracking branch 'origin/fix-device-collectors' into fix-…
vmoens Jan 30, 2024
8fd0ba3
amend
vmoens Jan 30, 2024
cec3ab1
amend
vmoens Jan 30, 2024
377944c
lint
vmoens Jan 30, 2024
4688f98
amend
vmoens Jan 30, 2024
9ec6786
amend
vmoens Jan 30, 2024
3b0653a
amend
vmoens Jan 30, 2024
9dcc6c3
amend
vmoens Jan 30, 2024
0d921c5
amend
vmoens Jan 30, 2024
8412a0c
amend
vmoens Jan 30, 2024
27c8605
amend
vmoens Jan 30, 2024
6326c6d
amend
vmoens Jan 30, 2024
5350bcf
amend
vmoens Jan 30, 2024
ecaa8b0
amend
vmoens Jan 30, 2024
5312610
amend
vmoens Jan 30, 2024
1221b4e
amend
vmoens Jan 30, 2024
6df335a
amend
vmoens Jan 30, 2024
7cd328e
amend
vmoens Jan 30, 2024
6f4c637
amend
vmoens Jan 30, 2024
52a1197
amend
vmoens Jan 30, 2024
f66c73b
amend
vmoens Jan 30, 2024
abe7436
lint
vmoens Jan 30, 2024
5f86a24
amend
vmoens Jan 30, 2024
205d224
lint
vmoens Jan 30, 2024
76bb4cf
amend
vmoens Jan 30, 2024
a7110b0
remove asserts
vmoens Jan 30, 2024
lint
vmoens committed Jan 30, 2024
commit 377944c9141c635678d070d42ff537140a669c1f
14 changes: 10 additions & 4 deletions torchrl/collectors/collectors.py
@@ -22,7 +22,6 @@
 from copy import deepcopy
 from multiprocessing import connection, queues
 from multiprocessing.managers import SyncManager
-from torch.utils._pytree import tree_map
 from textwrap import indent
 from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union

@@ -34,6 +33,7 @@
 from tensordict.nn import TensorDictModule, TensorDictModuleBase
 from tensordict.utils import NestedKey
 from torch import multiprocessing as mp
+from torch.utils._pytree import tree_map
 from torch.utils.data import IterableDataset

 from torchrl._utils import (
@@ -244,7 +244,9 @@ def _make_compatible_policy(self, policy, observation_spec=None):
 else:
     policy_device = None
 if policy_device:
-    next_observation = tree_map(lambda x: x.to(policy_device), next_observation)
+    next_observation = tree_map(
+        lambda x: x.to(policy_device), next_observation
+    )
 output = policy(**next_observation)

 if isinstance(output, tuple):
@@ -967,7 +969,9 @@ def rollout(self) -> TensorDictBase:
 policy_output = self.policy(policy_input)
 if self._shuttle is not policy_output:
     # ad-hoc update shuttle
-    self._shuttle = self._shuttle._fast_apply(self._update_device_wise, policy_output)
+    self._shuttle = self._shuttle._fast_apply(
+        self._update_device_wise, policy_output
+    )
     # # update is a no-op if identities match, so this is safe and efficient in all cases
     # # We could remove the inplace and update just the keys that have been updated by
     # # policy, but that would require some tricks to check if the policy
@@ -988,7 +992,9 @@ def rollout(self) -> TensorDictBase:
 env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
 if self._shuttle is not env_output:
     # ad-hoc update shuttle
-    self._shuttle = self._shuttle._fast_apply(self._update_device_wise, env_output)
+    self._shuttle = self._shuttle._fast_apply(
+        self._update_device_wise, env_output
+    )
     # # Here we could update only the leaves that are part of the env output
     # # since we have access to them and update() supports lists of keys
     # # to update.
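
For readers unfamiliar with the `tree_map` call that the `_make_compatible_policy` hunk rewraps, here is a minimal standalone sketch of the pattern: apply `x.to(policy_device)` to every tensor leaf of a nested container before the policy is called. The dictionary keys and tensor shapes are hypothetical and chosen only for illustration. The device selection is likewise an assumption made so the snippet runs without a GPU; it is not how the collector determines `policy_device`.

```python
import torch
from torch.utils._pytree import tree_map

# Hypothetical observation container; the keys and shapes are illustrative only.
next_observation = {
    "observation": torch.zeros(4),
    "pixels": torch.zeros(3, 8, 8),
}

# Assumption: pick a device that exists on this machine so the sketch runs anywhere.
policy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tree_map applies the function to every leaf and rebuilds the same container
# structure, so every tensor ends up on the policy's device.
next_observation = tree_map(lambda x: x.to(policy_device), next_observation)

devices = {key: value.device.type for key, value in next_observation.items()}
print(devices)
```

The three-line form introduced by this lint commit behaves identically to the one-line form; only the line wrapping changes.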