From ca42794ac0b5aadd26f28601d500dc6b8aead553 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 20 Feb 2024 21:28:39 +0000 Subject: [PATCH] [Refactor] Faster and more generic multi-agent nets (#1921) --- test/test_helpers.py | 2 +- test/test_modules.py | 110 ++++++++-- torchrl/modules/models/multiagent.py | 311 ++++++++++++++++----------- 3 files changed, 279 insertions(+), 144 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index eb9620001c7..e39e6cc6082 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -512,7 +512,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial with pytest.raises( ValueError, match="Attempted to use an uninitialized parameter" ): - pre_init_state_dict = t_env.transform.state_dict() + t_env.transform.state_dict() return pre_init_state_dict = t_env.transform.state_dict() initialize_observation_norm_transforms( diff --git a/test/test_modules.py b/test/test_modules.py index 68917a10d16..3d01fd04768 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -858,24 +858,15 @@ def _get_mock_input_td( @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) - @pytest.mark.parametrize( - "batch", - [ - (10,), - ( - 10, - 3, - ), - (), - ], - ) - def test_mlp( + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) + def test_multiagent_mlp( self, n_agents, centralised, share_params, batch, - n_agent_inputs=6, + n_agent_inputs, n_agent_outputs=2, ): torch.manual_seed(0) @@ -887,6 +878,8 @@ def test_mlp( share_params=share_params, depth=2, ) + if n_agent_inputs is None: + n_agent_inputs = 6 td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) obs = td.get(("agents", "observation")) @@ -921,17 +914,63 @@ def test_mlp( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + def test_multiagent_mlp_lazy(self): + mlp = MultiAgentMLP( + n_agent_inputs=None, + n_agent_outputs=6, + n_agents=3, + centralised=True, + share_params=False, + depth=2, + ) + optim = torch.optim.Adam(mlp.parameters()) + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = self._get_mock_input_td(3, 4, batch=(10,)) + obs = td.get(("agents", "observation")) + out = mlp(obs) + out.mean().backward() + optim.step() + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize("share_params", [True, False]) @pytest.mark.parametrize("centralised", [True, False]) + @pytest.mark.parametrize("channels", [3, None]) @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) - def test_cnn( - self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3 + def test_multiagent_cnn( + self, + n_agents, + centralised, + share_params, + batch, + channels, + x=50, + y=50, ): torch.manual_seed(0) cnn = MultiAgentConvNet( - n_agents=n_agents, centralised=centralised, share_params=share_params + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + in_features=channels, ) + if channels is None: + channels = 3 td = TensorDict( { "agents": TensorDict( @@ -973,6 +1012,45 @@ def test_cnn( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + def test_multiagent_cnn_lazy(self): + cnn = MultiAgentConvNet( + n_agents=5, + centralised=False, + share_params=False, + in_features=None, + ) + optim = torch.optim.Adam(cnn.parameters()) + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.randn(10, 5, 3, 50, 50)}, + [10, 5], + ) + }, + batch_size=[10], + ) + obs = td[("agents", "observation")] + out = cnn(obs) + out.mean().backward() + optim.step() + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + @pytest.mark.parametrize("n_agents", [1, 3]) @pytest.mark.parametrize( "batch", diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index f6b80ead12c..6229aa30fe3 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -2,20 +2,131 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations +import abc from typing import Optional, Sequence, Tuple, Type, Union import numpy as np import torch -from torch import nn +from tensordict import TensorDict +from torch import nn from torchrl.data.utils import DEVICE_TYPING from torchrl.modules.models import ConvNet, MLP -class MultiAgentMLP(nn.Module): +class MultiAgentNetBase(nn.Module): + """A base class for multi-agent networks.""" + + _empty_net: nn.Module + + def __init__( + self, + *, + n_agents: int, + centralised: bool, + share_params: bool, + agent_dim: int, + **kwargs, + ): + super().__init__() + + self.n_agents = n_agents + self.share_params = share_params + self.centralised = centralised + self.agent_dim = agent_dim + + agent_networks = [ + self._build_single_net(**kwargs) + for _ in range(self.n_agents if not self.share_params else 1) + ] + initialized = True + for p in agent_networks[0].parameters(): + if isinstance(p, torch.nn.UninitializedParameter): + initialized = False + break + self.initialized = initialized + self._make_params(agent_networks) + kwargs["device"] = "meta" + self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + + @property + def _vmap_randomness(self): + if self.initialized: + return "error" + return "same" + + def _make_params(self, agent_networks): + if self.share_params: + self.params = TensorDict.from_module(agent_networks[0], as_module=True) + else: + self.params = TensorDict.from_modules(*agent_networks, as_module=True) + + @abc.abstractmethod + def _build_single_net(self, *, device, **kwargs): + ... + + @abc.abstractmethod + def _pre_forward_check(self, inputs): + ... + + @staticmethod + def vmap_func_module(module, *args, **kwargs): + def exec_module(params, *input): + with params.to_module(module): + return module(*input) + + return torch.vmap(exec_module, *args, **kwargs) + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = torch.cat([*inputs], -1) + else: + inputs = inputs[0] + + inputs = self._pre_forward_check(inputs) + # If parameters are not shared, each agent has its own network + if not self.share_params: + if self.centralised: + output = self.vmap_func_module( + self._empty_net, (0, None), (-2,), randomness=self._vmap_randomness + )(self.params, inputs) + else: + output = self.vmap_func_module( + self._empty_net, + (0, self.agent_dim), + (-2,), + randomness=self._vmap_randomness, + )(self.params, inputs) + + # If parameters are shared, agents use the same network + else: + with self.params.to_module(self._empty_net): + output = self._empty_net(inputs) + + if self.centralised: + # If the parameters are shared, and it is centralised, all agents will have the same output + # We expand it to maintain the agent dimension, but values will be the same for all agents + n_agent_outputs = output.shape[-1] + output = output.view(*output.shape[:-1], n_agent_outputs) + output = output.unsqueeze(-2) + output = output.expand( + *output.shape[:-2], self.n_agents, n_agent_outputs + ) + + if output.shape[-2] != (self.n_agents): + raise ValueError( + f"Multi-agent network expected output with shape[-2]={self.n_agents}" + f" but got {output.shape}" + ) + + return output + + +class MultiAgentMLP(MultiAgentNetBase): """Mult-agent MLP. This is an MLP that can be used in multi-agent contexts. @@ -33,7 +144,8 @@ class MultiAgentMLP(nn.Module): Otherwise, each agent will only use its data as input. Args: - n_agent_inputs (int): number of inputs for each agent. + n_agent_inputs (int or None): number of inputs for each agent. If ``None``, + the number of inputs is lazily instantiated during the first call. n_agent_outputs (int): number of outputs for each agent. n_agents (int): number of agents. centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output @@ -139,7 +251,7 @@ class MultiAgentMLP(nn.Module): def __init__( self, - n_agent_inputs: int, + n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, centralised: bool, @@ -150,87 +262,52 @@ def __init__( activation_class: Optional[Type[nn.Module]] = nn.Tanh, **kwargs, ): - super().__init__() self.n_agents = n_agents self.n_agent_inputs = n_agent_inputs self.n_agent_outputs = n_agent_outputs self.share_params = share_params self.centralised = centralised + self.num_cells = num_cells + self.activation_class = activation_class + self.depth = depth - self.agent_networks = nn.ModuleList( - [ - MLP( - in_features=n_agent_inputs - if not centralised - else n_agent_inputs * n_agents, - out_features=n_agent_outputs, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, - device=device, - **kwargs, - ) - for _ in range(self.n_agents if not self.share_params else 1) - ] + super().__init__( + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + device=device, + agent_dim=-2, + **kwargs, ) - def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: - if len(inputs) > 1: - inputs = torch.cat([*inputs], -1) - else: - inputs = inputs[0] - - if inputs.shape[-2:] != (self.n_agents, self.n_agent_inputs): + def _pre_forward_check(self, inputs): + if inputs.shape[-2] != self.n_agents: raise ValueError( - f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]}," + f"Multi-agent network expected input with shape[-2]={self.n_agents}," f" but got {inputs.shape}" ) - # If the model is centralized, agents have full observability if self.centralised: - inputs = inputs.reshape( - *inputs.shape[:-2], self.n_agents * self.n_agent_inputs - ) - - # If parameters are not shared, each agent has its own network - if not self.share_params: - if self.centralised: - output = torch.stack( - [net(inputs) for i, net in enumerate(self.agent_networks)], - dim=-2, - ) - else: - output = torch.stack( - [ - net(inputs[..., i, :]) - for i, net in enumerate(self.agent_networks) - ], - dim=-2, - ) - # If parameters are shared, agents use the same network - else: - output = self.agent_networks[0](inputs) - - if self.centralised: - # If the parameters are shared, and it is centralised, all agents will have the same output - # We expand it to maintain the agent dimension, but values will be the same for all agents - output = output.view(*output.shape[:-1], self.n_agent_outputs) - output = output.unsqueeze(-2) - output = output.expand( - *output.shape[:-2], self.n_agents, self.n_agent_outputs - ) - - if output.shape[-2:] != (self.n_agents, self.n_agent_outputs): - raise ValueError( - f"Multi-agent network expected output with last 2 dimensions {[self.n_agents, self.n_agent_outputs]}," - f" but got {output.shape}" - ) - - return output + inputs = inputs.flatten(-2, -1) + return inputs + + def _build_single_net(self, *, device, **kwargs): + n_agent_inputs = self.n_agent_inputs + if self.centralised and n_agent_inputs is not None: + n_agent_inputs = self.n_agent_inputs * self.n_agents + return MLP( + in_features=n_agent_inputs, + out_features=self.n_agent_outputs, + depth=self.depth, + num_cells=self.num_cells, + activation_class=self.activation_class, + device=device, + **kwargs, + ) -class MultiAgentConvNet(nn.Module): +class MultiAgentConvNet(MultiAgentNetBase): """Multi-agent CNN. In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may take the entire observation space (across agents) or on a per-agent basis to compute its output, which we refer to as "centralized" and "non-centralized", respectively. @@ -243,6 +320,10 @@ class MultiAgentConvNet(nn.Module): share_params (bool): If ``True``, the same :class:`~torchrl.modules.ConvNet` will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different :class:`~torchrl.modules.ConvNet` to process its input (heterogeneous policies). + + Keyword Args: + in_features (int, optional): the input feature dimension. If left to ``None``, + a lazy module is used. device (str or torch.device, optional): device to create the module on. num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, @@ -374,36 +455,46 @@ def __init__( n_agents: int, centralised: bool, share_params: bool, - device: Optional[DEVICE_TYPING] = None, - num_cells: Optional[Sequence[int]] = None, + *, + in_features: int | None = None, + device: DEVICE_TYPING | None = None, + num_cells: Sequence[int] | None = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, strides: Union[Sequence, int] = 2, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = nn.ELU, **kwargs, ): - super().__init__() - - self.n_agents = n_agents - self.centralised = centralised - self.share_params = share_params + self.in_features = in_features + self.num_cells = num_cells + self.strides = strides + self.kernel_sizes = kernel_sizes + self.paddings = paddings + self.activation_class = activation_class + super().__init__( + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + device=device, + agent_dim=-4, + ) - self.agent_networks = nn.ModuleList( - [ - ConvNet( - num_cells=num_cells, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - activation_class=activation_class, - device=device, - **kwargs, - ) - for _ in range(self.n_agents if not self.share_params else 1) - ] + def _build_single_net(self, *, device, **kwargs): + in_features = self.in_features + if self.centralised and in_features is not None: + in_features = in_features * self.n_agents + return ConvNet( + in_features=in_features, + num_cells=self.num_cells, + kernel_sizes=self.kernel_sizes, + strides=self.strides, + paddings=self.paddings, + activation_class=self.activation_class, + device=device, + **kwargs, ) - def forward(self, inputs: torch.Tensor): + def _pre_forward_check(self, inputs): if len(inputs.shape) < 4: raise ValueError( """Multi-agent network expects (*batch_size, agent_index, x, y, channels)""" @@ -412,44 +503,10 @@ def forward(self, inputs: torch.Tensor): raise ValueError( f"""Multi-agent network expects {self.n_agents} but got {inputs.shape[-4]}""" ) - # If the model is centralized, agents have full observability if self.centralised: - shape = ( - *inputs.shape[:-4], - self.n_agents * inputs.shape[-3], - inputs.shape[-2], - inputs.shape[-1], - ) - inputs = torch.reshape(inputs, shape) - - # If the parameters are not shared, each agent has its own network - if not self.share_params: - if self.centralised: - output = torch.stack( - [net(inputs) for net in self.agent_networks], dim=-2 - ) - else: - output = torch.stack( - [ - net(inp) - for i, (net, inp) in enumerate( - zip(self.agent_networks, inputs.unbind(-4)) - ) - ], - dim=-2, - ) - else: - output = self.agent_networks[0](inputs) - if self.centralised: - # If the parameters are shared, and it is centralised all agents will have the same output. - # We expand it to maintain the agent dimension, but values will be the same for all agents - n_agent_outputs = output.shape[-1] - output = output.view(*output.shape[:-1], n_agent_outputs) - output = output.unsqueeze(-2) - output = output.expand( - *output.shape[:-2], self.n_agents, n_agent_outputs - ) - return output + # If the model is centralized, agents have full observability + inputs = torch.flatten(inputs, -4, -3) + return inputs class Mixer(nn.Module):