Skip to content

Commit

Permalink
[Refactor] Faster and more generic multi-agent nets (pytorch#1921)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 20, 2024
1 parent 799f939 commit ca42794
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 144 deletions.
2 changes: 1 addition & 1 deletion test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial
with pytest.raises(
ValueError, match="Attempted to use an uninitialized parameter"
):
pre_init_state_dict = t_env.transform.state_dict()
t_env.transform.state_dict()
return
pre_init_state_dict = t_env.transform.state_dict()
initialize_observation_norm_transforms(
Expand Down
110 changes: 94 additions & 16 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,24 +858,15 @@ def _get_mock_input_td(
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
def test_mlp(
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_multiagent_mlp(
self,
n_agents,
centralised,
share_params,
batch,
n_agent_inputs=6,
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(0)
Expand All @@ -887,6 +878,8 @@ def test_mlp(
share_params=share_params,
depth=2,
)
if n_agent_inputs is None:
n_agent_inputs = 6
td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)
obs = td.get(("agents", "observation"))

Expand Down Expand Up @@ -921,17 +914,63 @@ def test_mlp(
# same input different output
assert not torch.allclose(out[..., i, :], out[..., j, :])

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
n_agent_outputs=6,
n_agents=3,
centralised=True,
share_params=False,
depth=2,
)
optim = torch.optim.Adam(mlp.parameters())
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for _ in range(2):
td = self._get_mock_input_td(3, 4, batch=(10,))
obs = td.get(("agents", "observation"))
out = mlp(obs)
out.mean().backward()
optim.step()
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("channels", [3, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_cnn(
self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3
def test_multiagent_cnn(
self,
n_agents,
centralised,
share_params,
batch,
channels,
x=50,
y=50,
):
torch.manual_seed(0)
cnn = MultiAgentConvNet(
n_agents=n_agents, centralised=centralised, share_params=share_params
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
in_features=channels,
)
if channels is None:
channels = 3
td = TensorDict(
{
"agents": TensorDict(
Expand Down Expand Up @@ -973,6 +1012,45 @@ def test_cnn(
# same input different output
assert not torch.allclose(out[..., i, :], out[..., j, :])

def test_multiagent_cnn_lazy(self):
cnn = MultiAgentConvNet(
n_agents=5,
centralised=False,
share_params=False,
in_features=None,
)
optim = torch.optim.Adam(cnn.parameters())
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
else:
raise AssertionError("No UninitializedParameter found")
for _ in range(2):
td = TensorDict(
{
"agents": TensorDict(
{"observation": torch.randn(10, 5, 3, 50, 50)},
[10, 5],
)
},
batch_size=[10],
)
obs = td[("agents", "observation")]
out = cnn(obs)
out.mean().backward()
optim.step()
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")
for p in optim.param_groups[0]["params"]:
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize(
"batch",
Expand Down
Loading

0 comments on commit ca42794

Please sign in to comment.