Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Store MARL parameters in module #2351

Merged
merged 2 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,51 @@ def one_outofplace(mod):
mlp.from_stateful_net(snet)
assert (mlp.params == 1).all()

@retry(AssertionError, 5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the effect of this?

@pytest.mark.parametrize("n_agents", [3])
@pytest.mark.parametrize("share_params", [True])
@pytest.mark.parametrize("centralized", [True])
@pytest.mark.parametrize("n_agent_inputs", [6])
@pytest.mark.parametrize("batch", [(4,)])
@pytest.mark.parametrize("tdparams", [True, False])
def test_multiagent_mlp_tdparams(
    self,
    n_agents,
    centralized,
    share_params,
    batch,
    n_agent_inputs,
    tdparams,
    n_agent_outputs=2,
):
    """Check where MultiAgentMLP stores its parameters based on ``use_td_params``.

    With ``use_td_params=True`` the parameters live in the ``params``
    TensorDictParams container and the skeleton net is empty; with ``False``
    they are registered on the inner module instead.
    """
    torch.manual_seed(1)
    mlp = MultiAgentMLP(
        n_agent_inputs=n_agent_inputs,
        n_agent_outputs=n_agent_outputs,
        n_agents=n_agents,
        centralized=centralized,
        share_params=share_params,
        depth=2,
        use_td_params=tdparams,
    )
    if tdparams:
        # TD-params mode: the skeleton net holds no parameters of its own;
        # everything reachable through the module is exposed via `mlp.params`.
        assert list(mlp._empty_net.parameters()) == []
        assert list(mlp.params.parameters()) == list(mlp.parameters())
    else:
        # Plain mode: parameters are owned by the inner net, and `mlp.params`
        # is a regular TensorDict with no `parameters` method.
        assert list(mlp._empty_net.parameters()) == list(mlp.parameters())
        assert not hasattr(mlp.params, "parameters")
    # Device round-trip check below requires an accelerator; skip otherwise.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        return
    mlp = nn.Sequential(mlp)
    mlp_device = mlp.to(device)
    param_set = set(mlp.parameters())
    # After moving the wrapper, every leaf in the params container must still
    # be one of the module's registered parameters (no silent copies).
    for p in mlp[0].params.values(True, True):
        assert p in param_set

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
Expand Down
30 changes: 28 additions & 2 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
share_params: bool | None = None,
agent_dim: int | None = None,
vmap_randomness: str = "different",
use_td_params: bool = True,
**kwargs,
):
super().__init__()
Expand All @@ -53,6 +54,7 @@ def __init__(
if agent_dim is None:
raise TypeError("agent_dim arg must be passed.")

self.use_td_params = use_td_params
self.n_agents = n_agents
self.share_params = share_params
self.centralized = centralized
Expand All @@ -70,6 +72,7 @@ def __init__(
break
self.initialized = initialized
self._make_params(agent_networks)

# We make sure all params and buffers are on 'meta' device
# To do this, we set the device keyword arg to 'meta', we also temporarily change
# the default device. Finally, we convert all params to 'meta' tensors that are not params.
Expand All @@ -87,6 +90,8 @@ def __init__(
TensorDict.from_module(self._empty_net).data.to("meta").to_module(
self._empty_net
)
if not self.use_td_params:
self.params.to_module(self._empty_net)

@property
def vmap_randomness(self):
Expand All @@ -100,9 +105,13 @@ def vmap_randomness(self):

def _make_params(self, agent_networks):
if self.share_params:
self.params = TensorDict.from_module(agent_networks[0], as_module=True)
self.params = TensorDict.from_module(
agent_networks[0], as_module=self.use_td_params
)
else:
self.params = TensorDict.from_modules(*agent_networks, as_module=True)
self.params = TensorDict.from_modules(
*agent_networks, as_module=self.use_td_params
)

@abc.abstractmethod
def _build_single_net(self, *, device, **kwargs):
Expand Down Expand Up @@ -289,6 +298,8 @@ class MultiAgentMLP(MultiAgentNetBase):
the number of inputs is lazily instantiated during the first call.
n_agent_outputs (int): number of outputs for each agent.
n_agents (int): number of agents.

Keyword Args:
centralized (bool): If `centralized` is True, each agent will use the inputs of all agents to compute its output
(n_agent_inputs * n_agents will be the number of inputs for one agent).
Otherwise, each agent will only use its data as input.
Expand All @@ -307,6 +318,11 @@ class MultiAgentMLP(MultiAgentNetBase):
default: 32.
activation_class (Type[nn.Module]): activation class to be used.
default: nn.Tanh.
use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a
:class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
``use_td_params=True`` cannot be used when ``use_td_params=False``.
**kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs.

.. note:: to initialize the MARL module parameters with the `torch.nn.init`
Expand Down Expand Up @@ -399,12 +415,14 @@ def __init__(
n_agent_inputs: int | None,
n_agent_outputs: int,
n_agents: int,
*,
centralized: bool | None = None,
share_params: bool | None = None,
device: Optional[DEVICE_TYPING] = None,
depth: Optional[int] = None,
num_cells: Optional[Union[Sequence, int]] = None,
activation_class: Optional[Type[nn.Module]] = nn.Tanh,
use_td_params: bool = True,
**kwargs,
):
self.n_agents = n_agents
Expand All @@ -422,6 +440,7 @@ def __init__(
share_params=share_params,
device=device,
agent_dim=-2,
use_td_params=use_td_params,
**kwargs,
)

Expand Down Expand Up @@ -483,6 +502,11 @@ class MultiAgentConvNet(MultiAgentNetBase):
Defaults to ``2``.
activation_class (Type[nn.Module]): activation class to be used.
Default to :class:`torch.nn.ELU`.
use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a
:class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
``use_td_params=True`` cannot be used when ``use_td_params=False``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring should also state the default: `use_td_params` defaults to ``True``, but that is missing from the docs.

**kwargs: for :class:`~torchrl.modules.models.ConvNet` can be passed to customize the ConvNet.


Expand Down Expand Up @@ -611,6 +635,7 @@ def __init__(
strides: Union[Sequence, int] = 2,
paddings: Union[Sequence, int] = 0,
activation_class: Type[nn.Module] = nn.ELU,
use_td_params: bool = True,
**kwargs,
):
self.in_features = in_features
Expand All @@ -625,6 +650,7 @@ def __init__(
share_params=share_params,
device=device,
agent_dim=-4,
use_td_params=use_td_params,
**kwargs,
)

Expand Down
Loading