Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Fine control over devices in collectors #1835

Merged
merged 69 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
35d8847
init
vmoens Jan 25, 2024
0f9fad2
amend
vmoens Jan 25, 2024
213c477
amend
vmoens Jan 25, 2024
a1fd23f
amend
vmoens Jan 25, 2024
17aef47
clear_device_
vmoens Jan 25, 2024
601d6d6
stream
vmoens Jan 25, 2024
439022b
amend
vmoens Jan 25, 2024
d26a8d2
amend
vmoens Jan 25, 2024
0e20e03
amend
vmoens Jan 26, 2024
c56fda4
amend
vmoens Jan 26, 2024
bd9ee28
amend
vmoens Jan 26, 2024
61b1f87
amend
vmoens Jan 26, 2024
03a3747
amend
vmoens Jan 26, 2024
515de7d
amend
vmoens Jan 26, 2024
b994996
amend
vmoens Jan 26, 2024
dc264c6
amend
vmoens Jan 26, 2024
12a8ae1
amend
vmoens Jan 28, 2024
2150f0b
amend
vmoens Jan 28, 2024
0759eb1
lint
vmoens Jan 28, 2024
5d3d9c4
amend
vmoens Jan 28, 2024
c1870f9
amend
vmoens Jan 29, 2024
b033426
amend
vmoens Jan 29, 2024
0dcbc0e
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
7759b3f
amend
vmoens Jan 29, 2024
bd06a5e
amend
vmoens Jan 29, 2024
f0406eb
amend
vmoens Jan 29, 2024
b63aca3
amend
vmoens Jan 29, 2024
0f08d26
amend
vmoens Jan 29, 2024
2e4445c
amend
vmoens Jan 29, 2024
14ab95c
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
a4d1df6
amend
vmoens Jan 29, 2024
5976de0
Merge remote-tracking branch 'origin/main' into fix-device-collectors
vmoens Jan 29, 2024
b215157
amend
vmoens Jan 29, 2024
5560b70
amend
vmoens Jan 30, 2024
b98a806
amend
vmoens Jan 30, 2024
1f1dd92
amend
vmoens Jan 30, 2024
fc228d9
amend
vmoens Jan 30, 2024
1ef6ce1
amend
vmoens Jan 30, 2024
6258a92
amend
vmoens Jan 30, 2024
1e2fc62
amend
vmoens Jan 30, 2024
4feac80
amend
vmoens Jan 30, 2024
b12e564
amend
vmoens Jan 30, 2024
90d2343
amend
vmoens Jan 30, 2024
78d7c43
Merge remote-tracking branch 'origin/fix-device-collectors' into fix-…
vmoens Jan 30, 2024
8fd0ba3
amend
vmoens Jan 30, 2024
cec3ab1
amend
vmoens Jan 30, 2024
377944c
lint
vmoens Jan 30, 2024
4688f98
amend
vmoens Jan 30, 2024
9ec6786
amend
vmoens Jan 30, 2024
3b0653a
amend
vmoens Jan 30, 2024
9dcc6c3
amend
vmoens Jan 30, 2024
0d921c5
amend
vmoens Jan 30, 2024
8412a0c
amend
vmoens Jan 30, 2024
27c8605
amend
vmoens Jan 30, 2024
6326c6d
amend
vmoens Jan 30, 2024
5350bcf
amend
vmoens Jan 30, 2024
ecaa8b0
amend
vmoens Jan 30, 2024
5312610
amend
vmoens Jan 30, 2024
1221b4e
amend
vmoens Jan 30, 2024
6df335a
amend
vmoens Jan 30, 2024
7cd328e
amend
vmoens Jan 30, 2024
6f4c637
amend
vmoens Jan 30, 2024
52a1197
amend
vmoens Jan 30, 2024
f66c73b
amend
vmoens Jan 30, 2024
abe7436
lint
vmoens Jan 30, 2024
5f86a24
amend
vmoens Jan 30, 2024
205d224
lint
vmoens Jan 30, 2024
76bb4cf
amend
vmoens Jan 30, 2024
a7110b0
remove asserts
vmoens Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Jan 29, 2024
commit b63aca3460873436b5c497145aebd741a457c70e
37 changes: 20 additions & 17 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,15 @@ def __init__(self, out_features: int):
self.linear = nn.LazyLinear(out_features)

def forward(self, tensordict):
return TensorDict(
{self.out_keys[0]: self.linear(tensordict.get(self.in_keys[0]))},
[],
return tensordict.set(
self.out_keys[0], self.linear(tensordict.get(self.in_keys[0]))
)


class UnwrappablePolicy(nn.Module):
def __init__(self, out_features: int):
super().__init__()
self.linear = nn.LazyLinear(out_features)
self.linear = nn.Linear(2, out_features)

def forward(self, observation, other_stuff):
return self.linear(observation), other_stuff.sum()
Expand Down Expand Up @@ -1495,28 +1494,33 @@ def test_no_wrap_compatible_module(self, collector_class, env_maker):
)

if collector_class is not SyncDataCollector:
assert all(
isinstance(p, TensorDictCompatiblePolicy)
for p in collector._policy_dict.values()
)
assert all(
p.out_keys == ["action"] for p in collector._policy_dict.values()
)
assert all(p is policy for p in collector._policy_dict.values())
# We now do the casting only on the remote workers
pass
else:
assert isinstance(collector.policy, TensorDictCompatiblePolicy)
assert collector.policy.out_keys == ["action"]
assert collector.policy is policy

for i, data in enumerate(collector):
if i == 0:
assert (data["action"] != 0).any()
for p in policy.parameters():
p.data.zero_()
assert p.device == torch.device("cpu")
collector.update_policy_weights_()
elif i == 4:
assert (data["action"] == 0).all()
break

collector.shutdown()
del collector

def test_auto_wrap_error(self, collector_class, env_maker):
policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])

with pytest.raises(
TypeError,
match=(r"Arguments to policy.forward are incompatible with entries in"),
):
) if collector_class is SyncDataCollector else pytest.raises(EOFError):
collector_class(
**self._create_collector_kwargs(env_maker, collector_class, policy)
)
Expand Down Expand Up @@ -1631,8 +1635,7 @@ def env_fn(seed):
frames_per_batch=frames_per_batch,
init_random_frames=-1,
reset_at_each_iter=False,
devices=get_default_devices()[0],
storing_devices=get_default_devices()[0],
device=get_default_devices()[0],
split_trajs=False,
preemptive_threshold=0.0, # stop after one iteration
)
Expand Down Expand Up @@ -1818,7 +1821,7 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20):
)


class TestHetEnvsCollector:
class TestHeterogeneousEnvsCollector:
@pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)])
@pytest.mark.parametrize("frames_per_batch", [4, 8, 16])
def test_collector_heterogeneous_env(
Expand Down
Loading