Skip to content

Commit

Permalink
[Quality] Fix repr of MARL modules (#2192)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 31, 2024
1 parent 1405600 commit 8d99026
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 51 deletions.
50 changes: 32 additions & 18 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import re

from numbers import Number

Expand Down Expand Up @@ -759,13 +760,13 @@ def _get_mock_input_td(
@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_mlp(
self,
n_agents,
centralised,
centralized,
share_params,
batch,
n_agent_inputs,
Expand All @@ -776,7 +777,7 @@ def test_multiagent_mlp(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
n_agents=n_agents,
centralised=centralised,
centralized=centralized,
share_params=share_params,
depth=2,
)
Expand All @@ -788,7 +789,7 @@ def test_multiagent_mlp(
out = mlp(obs)
assert out.shape == (*batch, n_agents, n_agent_outputs)
for i in range(n_agents):
if centralised and share_params:
if centralized and share_params:
assert torch.allclose(out[..., i, :], out[..., 0, :])
else:
for j in range(i + 1, n_agents):
Expand All @@ -797,7 +798,7 @@ def test_multiagent_mlp(
obs[..., 0, 0] += 1
out2 = mlp(obs)
for i in range(n_agents):
if centralised:
if centralized:
# a modification to the input of agent 0 will impact all agents
assert not torch.allclose(out[..., i, :], out2[..., i, :])
elif i > 0:
Expand All @@ -817,13 +818,26 @@ def test_multiagent_mlp(
for j in range(i + 1, n_agents):
# same input different output
assert not torch.allclose(out[..., i, :], out[..., j, :])
pattern = rf"""MultiAgentMLP\(
MLP\(
\(0\): Linear\(in_features=\d+, out_features=32, bias=True\)
\(1\): Tanh\(\)
\(2\): Linear\(in_features=32, out_features=32, bias=True\)
\(3\): Tanh\(\)
\(4\): Linear\(in_features=32, out_features=2, bias=True\)
\),
n_agents={n_agents},
share_params={share_params},
centralized={centralized},
agent_dim={-2}\)"""
assert re.match(pattern, str(mlp), re.DOTALL)

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
n_agent_outputs=6,
n_agents=3,
centralised=True,
centralized=True,
share_params=False,
depth=2,
)
Expand Down Expand Up @@ -858,19 +872,19 @@ def test_multiagent_mlp_lazy(self):

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
def test_multiagent_reset_mlp(
self,
n_agents,
centralised,
centralized,
share_params,
):
actor_net = MultiAgentMLP(
n_agent_inputs=4,
n_agent_outputs=6,
num_cells=(4, 4),
n_agents=n_agents,
centralised=centralised,
centralized=centralized,
share_params=share_params,
)
params_before = actor_net.params.clone()
Expand All @@ -888,13 +902,13 @@ def test_multiagent_reset_mlp(

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("channels", [3, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_cnn(
self,
n_agents,
centralised,
centralized,
share_params,
batch,
channels,
Expand All @@ -904,7 +918,7 @@ def test_multiagent_cnn(
torch.manual_seed(0)
cnn = MultiAgentConvNet(
n_agents=n_agents,
centralised=centralised,
centralized=centralized,
share_params=share_params,
in_features=channels,
kernel_sizes=3,
Expand All @@ -923,15 +937,15 @@ def test_multiagent_cnn(
obs = td[("agents", "observation")]
out = cnn(obs)
assert out.shape[:-1] == (*batch, n_agents)
if centralised and share_params:
if centralized and share_params:
torch.testing.assert_close(out, out[..., :1, :].expand_as(out))
else:
for i in range(n_agents):
for j in range(i + 1, n_agents):
assert not torch.allclose(out[..., i, :], out[..., j, :])
obs[..., 0, 0, 0, 0] += 1
out2 = cnn(obs)
if centralised:
if centralized:
# a modification to the input of agent 0 will impact all agents
assert not torch.isclose(out, out2).all()
elif n_agents > 1:
Expand All @@ -956,7 +970,7 @@ def test_multiagent_cnn_lazy(self):
n_channels = 3
cnn = MultiAgentConvNet(
n_agents=n_agents,
centralised=False,
centralized=False,
share_params=False,
in_features=None,
kernel_sizes=3,
Expand Down Expand Up @@ -1000,18 +1014,18 @@ def test_multiagent_cnn_lazy(self):

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
def test_multiagent_reset_cnn(
self,
n_agents,
centralised,
centralized,
share_params,
):
actor_net = MultiAgentConvNet(
in_features=4,
num_cells=[5, 5],
n_agents=n_agents,
centralised=centralised,
centralized=centralized,
share_params=share_params,
)
params_before = actor_net.params.clone()
Expand Down
Loading

0 comments on commit 8d99026

Please sign in to comment.