Skip to content

Commit

Permalink
[Feature] Store MARL parameters in module (pytorch#2351)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 3, 2024
1 parent 812f936 commit 3267533
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
45 changes: 45 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,51 @@ def one_outofplace(mod):
mlp.from_stateful_net(snet)
assert (mlp.params == 1).all()

@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [3])
@pytest.mark.parametrize("share_params", [True])
@pytest.mark.parametrize("centralized", [True])
@pytest.mark.parametrize("n_agent_inputs", [6])
@pytest.mark.parametrize("batch", [(4,)])
@pytest.mark.parametrize("tdparams", [True, False])
def test_multiagent_mlp_tdparams(
self,
n_agents,
centralized,
share_params,
batch,
n_agent_inputs,
tdparams,
n_agent_outputs=2,
):
torch.manual_seed(1)
mlp = MultiAgentMLP(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
depth=2,
use_td_params=tdparams,
)
if tdparams:
assert list(mlp._empty_net.parameters()) == []
assert list(mlp.params.parameters()) == list(mlp.parameters())
else:
assert list(mlp._empty_net.parameters()) == list(mlp.parameters())
assert not hasattr(mlp.params, "parameters")
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
return
mlp = nn.Sequential(mlp)
mlp_device = mlp.to(device)
param_set = set(mlp.parameters())
for p in mlp[0].params.values(True, True):
assert p in param_set

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
Expand Down
30 changes: 28 additions & 2 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
share_params: bool | None = None,
agent_dim: int | None = None,
vmap_randomness: str = "different",
use_td_params: bool = True,
**kwargs,
):
super().__init__()
Expand All @@ -53,6 +54,7 @@ def __init__(
if agent_dim is None:
raise TypeError("agent_dim arg must be passed.")

self.use_td_params = use_td_params
self.n_agents = n_agents
self.share_params = share_params
self.centralized = centralized
Expand All @@ -70,6 +72,7 @@ def __init__(
break
self.initialized = initialized
self._make_params(agent_networks)

# We make sure all params and buffers are on 'meta' device
# To do this, we set the device keyword arg to 'meta', we also temporarily change
# the default device. Finally, we convert all params to 'meta' tensors that are not params.
Expand All @@ -87,6 +90,8 @@ def __init__(
TensorDict.from_module(self._empty_net).data.to("meta").to_module(
self._empty_net
)
if not self.use_td_params:
self.params.to_module(self._empty_net)

@property
def vmap_randomness(self):
Expand All @@ -100,9 +105,13 @@ def vmap_randomness(self):

def _make_params(self, agent_networks):
if self.share_params:
self.params = TensorDict.from_module(agent_networks[0], as_module=True)
self.params = TensorDict.from_module(
agent_networks[0], as_module=self.use_td_params
)
else:
self.params = TensorDict.from_modules(*agent_networks, as_module=True)
self.params = TensorDict.from_modules(
*agent_networks, as_module=self.use_td_params
)

@abc.abstractmethod
def _build_single_net(self, *, device, **kwargs):
Expand Down Expand Up @@ -289,6 +298,8 @@ class MultiAgentMLP(MultiAgentNetBase):
the number of inputs is lazily instantiated during the first call.
n_agent_outputs (int): number of outputs for each agent.
n_agents (int): number of agents.
Keyword Args:
centralized (bool): If `centralized` is True, each agent will use the inputs of all agents to compute its output
(n_agent_inputs * n_agents will be the number of inputs for one agent).
Otherwise, each agent will only use its data as input.
Expand All @@ -307,6 +318,11 @@ class MultiAgentMLP(MultiAgentNetBase):
default: 32.
activation_class (Type[nn.Module]): activation class to be used.
default: nn.Tanh.
use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a
:class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
``use_td_params=True`` cannot be used when ``use_td_params=False``.
**kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs.
.. note:: to initialize the MARL module parameters with the `torch.nn.init`
Expand Down Expand Up @@ -399,12 +415,14 @@ def __init__(
n_agent_inputs: int | None,
n_agent_outputs: int,
n_agents: int,
*,
centralized: bool | None = None,
share_params: bool | None = None,
device: Optional[DEVICE_TYPING] = None,
depth: Optional[int] = None,
num_cells: Optional[Union[Sequence, int]] = None,
activation_class: Optional[Type[nn.Module]] = nn.Tanh,
use_td_params: bool = True,
**kwargs,
):
self.n_agents = n_agents
Expand All @@ -422,6 +440,7 @@ def __init__(
share_params=share_params,
device=device,
agent_dim=-2,
use_td_params=use_td_params,
**kwargs,
)

Expand Down Expand Up @@ -483,6 +502,11 @@ class MultiAgentConvNet(MultiAgentNetBase):
Defaults to ``2``.
activation_class (Type[nn.Module]): activation class to be used.
Default to :class:`torch.nn.ELU`.
use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a
:class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
``use_td_params=True`` cannot be used when ``use_td_params=False``.
**kwargs: for :class:`~torchrl.modules.models.ConvNet` can be passed to customize the ConvNet.
Expand Down Expand Up @@ -611,6 +635,7 @@ def __init__(
strides: Union[Sequence, int] = 2,
paddings: Union[Sequence, int] = 0,
activation_class: Type[nn.Module] = nn.ELU,
use_td_params: bool = True,
**kwargs,
):
self.in_features = in_features
Expand All @@ -625,6 +650,7 @@ def __init__(
share_params=share_params,
device=device,
agent_dim=-4,
use_td_params=use_td_params,
**kwargs,
)

Expand Down

0 comments on commit 3267533

Please sign in to comment.