diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 56b1bfc7fea..0d9e23929a5 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -335,6 +335,20 @@ algorithms, such as DQN, DDPG or Dreamer. RSSMPrior RSSMPosterior +Multi-agent-specific modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +These networks implement models that can be used in +multi-agent contexts. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + MultiAgentMLP + QMixer + VDNMixer + Exploration ----------- diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ed2d5c3cff7..3325cd05fd6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -185,6 +185,21 @@ Dreamer DreamerModelLoss DreamerValueLoss +Multi-agent objectives +---------------------- +.. currentmodule:: torchrl.objectives.multiagent + +These objectives are specific to multi-agent algorithms. + +QMixer +~~~~~~ + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QMixerLoss + Returns ------- diff --git a/setup.py b/setup.py index d162ee6164e..3723c1b1981 100644 --- a/setup.py +++ b/setup.py @@ -235,6 +235,7 @@ def _main(argv): "checkpointing": [ "torchsnapshot", ], + "marl": ["vmas"], }, zip_safe=False, classifiers=[ @@ -254,5 +255,4 @@ def _main(argv): if __name__ == "__main__": - _main(sys.argv[1:]) diff --git a/test/test_cost.py b/test/test_cost.py index b1f7caf5c4c..9849fb09801 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -23,6 +23,8 @@ TensorDictSequential as Seq, ) +from torchrl.modules.models import QMixer + _has_functorch = True try: import functorch as ft # noqa @@ -46,6 +48,7 @@ # from torchrl.data.postprocs.utils import expand_as_right from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict.utils import unravel_key from torch import autograd, nn from torchrl.data import ( BoundedTensorSpec, @@ -80,6 +83,7 @@ ActorCriticOperator, ActorValueOperator, ProbabilisticActor, + QValueModule, ValueOperator, ) from torchrl.modules.utils import Buffer @@ -97,6 +101,7 @@ IQLLoss, KLPENPPOLoss, PPOLoss, + QMixerLoss, SACLoss, TD3Loss, ) @@ -719,6 +724,418 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): assert loss_fn.tensor_keys.priority in td.keys() +class TestQMixer(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + action_spec_type, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key=("agents", "observation"), + action_key=("agents", "action"), + action_value_key=("agents", "action_value"), + chosen_action_value_key=("agents", "chosen_action_value"), + ): + # Actor + if action_spec_type == "one_hot": + action_spec = OneHotDiscreteTensorSpec(action_dim) + elif action_spec_type == "categorical": + action_spec = DiscreteTensorSpec(action_dim) + else: + raise ValueError(f"Wrong {action_spec_type}") + + module = nn.Linear(obs_dim, action_dim).to(device) + + module = TensorDictModule( + module, + in_keys=[observation_key], + out_keys=[action_value_key], + ).to(device) + value_module = QValueModule( + action_value_key=action_value_key, + out_keys=[ + action_key, + action_value_key, + chosen_action_value_key, + ], + spec=action_spec, + action_space=None, + ).to(device) + actor = SafeSequential(module, value_module) + + return actor + + def _create_mock_mixer( + self, + state_shape=(64, 64, 3), + n_agents=4, + device="cpu", + chosen_action_value_key=("agents", 
"chosen_action_value"), + state_key="state", + global_chosen_action_value_key="chosen_action_value", + ): + qmixer = TensorDictModule( + module=QMixer( + state_shape=state_shape, + mixing_embed_dim=32, + n_agents=n_agents, + device=device, + ), + in_keys=[chosen_action_value_key, state_key], + out_keys=[global_chosen_action_value_key], + ).to(device) + + return qmixer + + def _create_mock_data_dqn( + self, + action_spec_type, + batch=(2,), + T=None, + n_agents=4, + obs_dim=3, + state_shape=(64, 64, 3), + action_dim=4, + device="cpu", + action_key=("agents", "action"), + action_value_key=("agents", "action_value"), + ): + if T is not None: + batch = batch + (T,) + # create a tensordict + obs = torch.randn(*batch, n_agents, obs_dim, device=device) + state = torch.randn(*batch, *state_shape, device=device) + next_obs = torch.randn(*batch, n_agents, obs_dim, device=device) + next_state = torch.randn(*batch, *state_shape, device=device) + + action_value = torch.randn(*batch, n_agents, action_dim, device=device) + if action_spec_type == "one_hot": + action = (action_value == action_value.max(dim=-1, keepdim=True)[0]).to( + torch.long + ) + elif action_spec_type == "categorical": + action = torch.argmax(action_value, dim=-1).to(torch.long) + + reward = torch.randn(*batch, 1, device=device) + done = torch.zeros(*batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + { + "agents": TensorDict( + {"observation": obs}, + [*batch, n_agents], + device=device, + ), + "state": state, + "collector": { + "mask": torch.zeros(*batch, dtype=torch.bool, device=device) + }, + "next": TensorDict( + { + "agents": TensorDict( + {"observation": next_obs}, + [*batch, n_agents], + device=device, + ), + "state": next_state, + "reward": reward, + "done": done, + }, + batch_size=batch, + device=device, + ), + }, + batch_size=batch, + device=device, + ) + td.set(action_key, action) + td.set(action_value_key, action_value) + if T is not None: + td.refine_names(None, "time") + return td + + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_qmixer(self, delay_value, device, action_spec_type, td_est): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + mixer = self._create_mock_mixer(device=device) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, device=device + ) + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + if td_est is ValueEstimators.GAE: + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + sum([item for _, item in loss.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_local_value_network_params.clone() + for p in loss_fn.parameters(): + p.data += 3 + target_value2 = loss_fn.target_local_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not 
(target_value == target_value2).any() + + # Check param update effect on targets + target_value = loss_fn.target_mixer_network_params.clone() + for p in loss_fn.parameters(): + p.data += 3 + target_value2 = loss_fn.target_mixer_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + mixer = self._create_mock_mixer(device=device) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, T=4, device=device + ) + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + ms_td = ms(td.clone()) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*td.keys(True, True))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_local_value_network_params.clone() + for p in loss_fn.parameters(): + p.data += 3 + target_value2 = loss_fn.target_local_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # Check param update effect on targets + target_value = loss_fn.target_mixer_network_params.clone() + for p in loss_fn.parameters(): + p.data += 3 + target_value2 = loss_fn.target_mixer_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_qmix_tensordict_keys(self, td_est): + torch.manual_seed(self.seed) + action_spec_type = "one_hot" + actor = self._create_mock_actor(action_spec_type=action_spec_type) + mixer = self._create_mock_mixer() + loss_fn = QMixerLoss(actor, mixer) + + default_keys = { + 
"advantage": "advantage", + "value_target": "value_target", + "local_value": ("agents", "chosen_action_value"), + "global_value": "chosen_action_value", + "priority": "td_error", + "action_value": ("agents", "action_value"), + "action": ("agents", "action"), + "reward": "reward", + "done": "done", + } + + self.tensordict_keys_test(loss_fn, default_keys=default_keys) + + loss_fn = QMixerLoss(actor, mixer) + key_mapping = { + "advantage": ("advantage", "advantage_2"), + "value_target": ("value_target", ("value_target", "nested")), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + actor = self._create_mock_actor( + action_spec_type=action_spec_type, + ) + mixer = self._create_mock_mixer( + global_chosen_action_value_key=("some", "nested") + ) + loss_fn = QMixerLoss(actor, mixer) + key_mapping = { + "global_value": ("value", ("some", "nested")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot")) + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_qmix_tensordict_run(self, action_spec_type, td_est): + torch.manual_seed(self.seed) + tensor_keys = { + "action_value": ("other", "action_value_test"), + "action": ("other", "action"), + "local_value": ("some", "local_v"), + "global_value": "global_v", + "priority": "priority_test", + } + actor = self._create_mock_actor( + action_spec_type=action_spec_type, + action_value_key=tensor_keys["action_value"], + action_key=tensor_keys["action"], + chosen_action_value_key=tensor_keys["local_value"], + ) + mixer = self._create_mock_mixer( + chosen_action_value_key=tensor_keys["local_value"], + global_chosen_action_value_key=tensor_keys["global_value"], + ) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, + action_key=tensor_keys["action"], + action_value_key=tensor_keys["action_value"], + ) + + loss_fn = QMixerLoss(actor, mixer, loss_function="l2") + loss_fn.set_keys(**tensor_keys) + + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with _check_td_steady(td): + _ = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + @pytest.mark.parametrize( + "mixer_local_chosen_action_value_key", + [("agents", "chosen_action_value"), ("other")], + ) + @pytest.mark.parametrize( + "mixer_global_chosen_action_value_key", + ["chosen_action_value", ("nested", "other")], + ) + def test_mixer_keys( + self, + mixer_local_chosen_action_value_key, + mixer_global_chosen_action_value_key, + n_agents=4, + obs_dim=3, + ): + torch.manual_seed(0) + actor = self._create_mock_actor( + action_spec_type="categorical", + ) + mixer = self._create_mock_mixer( + chosen_action_value_key=mixer_local_chosen_action_value_key, + global_chosen_action_value_key=mixer_global_chosen_action_value_key, + n_agents=n_agents, + ) + + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.zeros(32, n_agents, obs_dim)}, [32, n_agents] + ), + "state": torch.zeros(32, 64, 64, 3), + "next": TensorDict( + { + "agents": TensorDict( + {"observation": torch.zeros(32, n_agents, obs_dim)}, + [32, n_agents], + ), + "state": torch.zeros(32, 64, 64, 3), + "reward": torch.zeros(32, 1), + "done": torch.zeros(32, 1, dtype=torch.bool), + }, + [32], + ), + }, + [32], + ) + td = actor(td) + + loss = QMixerLoss(actor, mixer) + + # Wthout etting the keys + if 
mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): + with pytest.raises(RuntimeError): + loss(td) + elif unravel_key(mixer_global_chosen_action_value_key) != "chosen_action_value": + with pytest.raises( + KeyError, match='key "chosen_action_value" not found in TensorDict' + ): + loss(td) + else: + loss(td) + + loss = QMixerLoss(actor, mixer) + # When setting the key + loss.set_keys(global_value=mixer_global_chosen_action_value_key) + if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): + with pytest.raises( + RuntimeError + ): # The mixer in key still does not match the actor out_key + loss(td) + else: + loss(td) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) @@ -2699,7 +3116,6 @@ def test_discrete_sac( target_entropy, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_sac(device=device) @@ -3247,7 +3663,6 @@ def _create_seq_mock_data_redq( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -3342,7 +3757,6 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): @pytest.mark.parametrize("separate_losses", [False, True]) def test_redq_separate_losses(self, separate_losses): - torch.manual_seed(self.seed) actor, qvalue, common, td = self._create_mock_common_layer_setup() @@ -3431,7 +3845,6 @@ def test_redq_separate_losses(self, separate_losses): @pytest.mark.parametrize("separate_losses", [False, True]) def test_redq_deprecated_separate_losses(self, separate_losses): - torch.manual_seed(self.seed) actor, qvalue, common, td = self._create_mock_common_layer_setup() @@ -3520,7 +3933,6 @@ def test_redq_deprecated_separate_losses(self, separate_losses): @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) def test_redq_shared(self, delay_qvalue, num_qvalue, device): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -3585,7 +3997,6 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -4111,7 +4522,6 @@ def test_cql_batcher( with_lagrange, device, ): - torch.manual_seed(self.seed) td = self._create_seq_mock_data_cql(device=device) @@ -6454,7 +6864,6 @@ def test_iql( expectile, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_iql(device=device) @@ -7114,7 +7523,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: # total dist d0 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7130,7 +7539,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: for i in range(value_network_update_interval + 1): # test that no update is occuring until value_network_update_interval d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7145,7 +7554,7 @@ def 
_forward_value_estimator_keys(self, **kwargs) -> None: assert upd.counter == 0 # test that a new update has occured d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7158,7 +7567,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: elif mode == "soft": upd.step() d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7171,7 +7580,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: upd.init_() upd.step() d2 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) diff --git a/test/test_modules.py b/test/test_modules.py index 2481ec09f69..caa4cca1c9b 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -14,7 +14,16 @@ from tensordict import TensorDict from torch import nn from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec -from torchrl.modules import CEMPlanner, LSTMNet, SafeModule, TanhModule, ValueOperator +from torchrl.modules import ( + CEMPlanner, + LSTMNet, + MultiAgentMLP, + QMixer, + SafeModule, + TanhModule, + ValueOperator, + VDNMixer, +) from torchrl.modules.distributions.utils import safeatanh, safetanh from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear from torchrl.modules.models.model_based import ( @@ -200,7 +209,6 @@ def test_lstm_net( has_precond_hidden, double_prec_fixture, ): - torch.manual_seed(0) batch = 5 time_steps = 6 @@ -708,6 +716,213 @@ def test_multi_inputs(self, out_keys, has_spec): assert (data[out_key] >= min - eps).all() +class TestMultiAgent: + def _get_mock_input_td( + self, n_agents, n_agents_inputs, state_shape=(64, 64, 3), T=None, batch=(2,) + ): + if T is not None: + batch = batch + (T,) + obs = torch.randn(*batch, n_agents, n_agents_inputs) + state = torch.randn(*batch, *state_shape) + + td = TensorDict( + { + "agents": TensorDict( + {"observation": obs}, + [*batch, n_agents], + ), + "state": state, + }, + batch_size=batch, + ) + return td + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralised", [True, False]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + (), + ], + ) + def test_mlp( + self, + n_agents, + centralised, + share_params, + batch, + n_agent_inputs=6, + n_agent_outputs=2, + ): + torch.manual_seed(0) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + depth=2, + ) + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", "observation")) + + out = mlp(obs) + assert out.shape == (*batch, n_agents, n_agent_outputs) + for i in range(n_agents): + if centralised and share_params: + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + obs[..., 0, 0] += 1 + out2 = mlp(obs) + for i in range(n_agents): + if centralised: + # a modification to the input of agent 0 will impact all agents + assert not torch.allclose(out[..., i, :], out2[..., i, :]) 
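+            # decentralised case: changing agent 0's input should leave the other agents' outputs unchanged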
+ elif i > 0: + assert torch.allclose(out[..., i, :], out2[..., i, :]) + + obs = torch.randn(*batch, 1, n_agent_inputs).expand( + *batch, n_agents, n_agent_inputs + ) + out = mlp(obs) + for i in range(n_agents): + if share_params: + # same input same output + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + # same input different output + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + (), + ], + ) + def test_vdn(self, n_agents, batch): + torch.manual_seed(0) + mixer = VDNMixer(n_agents=n_agents, device="cpu") + + td = self._get_mock_input_td(n_agents, batch=batch, n_agents_inputs=1) + obs = td.get(("agents", "observation")) + assert obs.shape == (*batch, n_agents, 1) + out = mixer(obs) + assert out.shape == (*batch, 1) + assert torch.equal(obs.sum(-2), out) + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + (), + ], + ) + @pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)]) + def test_qmix(self, n_agents, batch, state_shape): + torch.manual_seed(0) + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + assert obs.shape == (*batch, n_agents, 1) + assert state.shape == (*batch, *state_shape) + out = mixer(obs, state) + assert out.shape == (*batch, 1) + + @pytest.mark.parametrize("mixer", ["qmix", "vdn"]) + def test_mixer_malformed_input( + self, mixer, n_agents=3, batch=(32,), state_shape=(64, 64, 3) + ): + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=3, state_shape=state_shape + ) + if mixer == "qmix": + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + else: + mixer = VDNMixer(n_agents=n_agents, device="cpu") + obs = td.get(("agents", "observation")) + state = td.get("state") + + if mixer.needs_state: + with pytest.raises( + ValueError, + match="Mixer that needs state was passed more than 2 inputs", + ): + mixer(obs) + else: + with pytest.raises( + ValueError, + match="Mixer that doesn't need state was passed more than 1 input", + ): + mixer(obs, state) + + in_put = [obs, state] if mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + if mixer.needs_state: + state_diff = state.unsqueeze(-1) + with pytest.raises( + ValueError, + match="Mixer network expected state with ending shape", + ): + mixer(obs, state_diff) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + obs = obs.sum(-2) + in_put = [obs, state] if mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + + obs = td.get(("agents", "observation")) + state = td.get("state") + in_put = [obs, state] if mixer.needs_state else [obs] + mixer(*in_put) + + @pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") @pytest.mark.parametrize("use_vmap", [False, True]) @pytest.mark.parametrize("scale", range(10)) diff 
--git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ebc913a1214..0f17c37b5f4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3460,6 +3460,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec """Transforms the observation spec, adding the new keys generated by RewardSum.""" # Retrieve parent reward spec reward_spec = self.parent.reward_spec + reward_key = self.parent.reward_key if self.parent else "reward" episode_specs = {} if isinstance(reward_spec, CompositeSpec): @@ -3478,7 +3479,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec else: # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´ - if set(self.in_keys) != {"reward"}: + if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}: raise KeyError( "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´" ) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0e6fe606fa3..ebb73bcedf6 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -25,15 +25,18 @@ DuelingCnnDQNet, LSTMNet, MLP, + MultiAgentMLP, NoisyLazyLinear, NoisyLinear, ObsDecoder, ObsEncoder, + QMixer, reset_noise, RSSMPosterior, RSSMPrior, Squeeze2dLayer, SqueezeLayer, + VDNMixer, ) from .tensordict_module import ( Actor, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 8654d338c18..8e5d0c2f9c9 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -17,4 +17,5 @@ LSTMNet, MLP, ) +from .multiagent import MultiAgentMLP, QMixer, VDNMixer from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py new file mode 100644 index 00000000000..de565b336d2 --- /dev/null +++ b/torchrl/modules/models/multiagent.py @@ -0,0 +1,600 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence, Tuple, Type, Union + +import numpy as np + +import torch +from torch import nn + +from ...data import DEVICE_TYPING + +from .models import MLP + + +class MultiAgentMLP(nn.Module): + """Mult-agent MLP. + + This is an MLP that can be used in multi-agent contexts. + For example, as a policy or as a value function. + See `examples/multiagent` for examples. + + It expects inputs with shape (*B, n_agents, n_agent_inputs) + It returns outputs with shape (*B, n_agents, n_agent_outputs) + + If `share_params` is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). + Otherwise, each agent will use a different MLP to process its input (heterogeneous policies). + + If `centralised` is True, each agent will use the inputs of all agents to compute its output + (n_agent_inputs * n_agents will be the number of inputs for one agent). + Otherwise, each agent will only use its data as input. + + Args: + n_agent_inputs (int): number of inputs for each agent. + n_agent_outputs (int): number of outputs for each agent. + n_agents (int): number of agents. + centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output + (n_agent_inputs * n_agents will be the number of inputs for one agent). + Otherwise, each agent will only use its data as input. 
+ share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass + for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process + its input (heterogeneous policies). + device (str or toech.device, optional): device to create the module on. + depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the + desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated, + the depth information should be contained in the num_cells argument (see below). If num_cells is an + iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. + default: 3. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers out_features will match the content of num_cells. + default: 32. + activation_class (Type[nn.Module]): activation class to be used. + default: nn.Tanh. + **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. + + Examples: + >>> from torchrl.modules import MultiAgentMLP + >>> import torch + >>> n_agents = 6 + >>> n_agent_inputs=3 + >>> n_agent_outputs=2 + >>> batch = 64 + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs + First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=False, + ... share_params=True, + ... depth=2, + ... ) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0): MLP( + (0): Linear(in_features=3, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) + Now let's instantiate a centralised network shared by all agents (e.g. a centalised value function) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=True, + ... share_params=True, + ... depth=2, + ... ) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0): MLP( + (0): Linear(in_features=18, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + We can see that the input to the first layer is n_agents * n_agent_inputs, + this is because in the case the net acts as a centralised mlp (like a single huge agent) + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) + Outputs will be identical for all agents. + Now we can do both examples just shown but with an independent set of parameters for each agent + Let's show the centralised=False case. + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=False, + ... share_params=False, + ... depth=2, + ... 
) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0-5): 6 x MLP( + (0): Linear(in_features=3, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent! + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) + """ + + def __init__( + self, + n_agent_inputs: int, + n_agent_outputs: int, + n_agents: int, + centralised: bool, + share_params: bool, + device: Optional[DEVICE_TYPING] = None, + depth: Optional[int] = None, + num_cells: Optional[Union[Sequence, int]] = None, + activation_class: Optional[Type[nn.Module]] = nn.Tanh, + **kwargs, + ): + super().__init__() + + self.n_agents = n_agents + self.n_agent_inputs = n_agent_inputs + self.n_agent_outputs = n_agent_outputs + self.share_params = share_params + self.centralised = centralised + + self.agent_networks = nn.ModuleList( + [ + MLP( + in_features=n_agent_inputs + if not centralised + else n_agent_inputs * n_agents, + out_features=n_agent_outputs, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + device=device, + **kwargs, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = torch.cat([*inputs], -1) + else: + inputs = inputs[0] + + if inputs.shape[-2:] != (self.n_agents, self.n_agent_inputs): + raise ValueError( + f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]}," + f" but got {inputs.shape}" + ) + + # If the model is centralized, agents have full observability + if self.centralised: + inputs = inputs.reshape( + *inputs.shape[:-2], self.n_agents * self.n_agent_inputs + ) + + # If parameters are not shared, each agent has its own network + if not self.share_params: + if self.centralised: + output = torch.stack( + [net(inputs) for i, net in enumerate(self.agent_networks)], + dim=-2, + ) + else: + output = torch.stack( + [ + net(inputs[..., i, :]) + for i, net in enumerate(self.agent_networks) + ], + dim=-2, + ) + # If parameters are shared, agents use the same network + else: + output = self.agent_networks[0](inputs) + + if self.centralised: + # If the parameters are shared, and it is centralised, all agents will have the same output + # We expand it to maintain the agent dimension, but values will be the same for all agents + output = ( + output.view(*output.shape[:-1], self.n_agent_outputs) + .unsqueeze(-2) + .expand(*output.shape[:-1], self.n_agents, self.n_agent_outputs) + ) + + if output.shape[-2:] != (self.n_agents, self.n_agent_outputs): + raise ValueError( + f"Multi-agent network expected output with last 2 dimensions {[self.n_agents, self.n_agent_outputs]}," + f" but got {output.shape}" + ) + + return output + + +class Mixer(nn.Module): + """A multi-agent value mixer. + + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + + Args: + n_agents (int): number of agents. + needs_state (bool): whether the mixer takes a global state as input. + state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions). 
+ device (str or torch.Device): torch device for the network. + + Examples: + Creating a VDN mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import VDNMixer + >>> n_agents = 4 + >>> vdn = TensorDictModule( + ... module=VDNMixer( + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents","chosen_action_value")], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + Creating a QMix mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import QMixer + >>> n_agents = 4 + >>> qmix = TensorDictModule( + ... module=QMixer( + ... state_shape=(64, 64, 3), + ... mixing_embed_dim=32, + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents", "chosen_action_value"), "state"], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + n_agents: int, + needs_state: bool, + state_shape: Union[Tuple[int, ...], torch.Size], + device: DEVICE_TYPING, + ): + super().__init__() + + self.n_agents = n_agents + self.device = device + self.needs_state = needs_state + self.state_shape = state_shape + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + """Forward pass of the mixer. + + Args: + *inputs: The first input should be the value of the chosen action of shape (*B, self.n_agents, 1), + representing the local q value of each agent. 
+ The second input (optional, used only in some mixers) + is the shared state of all agents of shape (*B, *self.state_shape). + + Returns: + The global value of the chosen actions obtained after mixing, with shape (*B, 1) + + """ + if not self.needs_state: + if len(inputs) > 1: + raise ValueError( + "Mixer that doesn't need state was passed more than 1 input" + ) + chosen_action_value = inputs[0] + else: + if len(inputs) != 2: + raise ValueError("Mixer that needs state was passed more than 2 inputs") + + chosen_action_value, state = inputs + + if state.shape[-len(self.state_shape) :] != self.state_shape: + raise ValueError( + f"Mixer network expected state with ending shape {self.state_shape}," + f" but got state shape {state.shape}" + ) + + if chosen_action_value.shape[-2:] != (self.n_agents, 1): + raise ValueError( + f"Mixer network expected chosen_action_value with last 2 dimensions {(self.n_agents,1)}," + f" but got {chosen_action_value.shape}" + ) + batch_dims = chosen_action_value.shape[:-2] + + if not self.needs_state: + output = self.mix(chosen_action_value, None) + else: + output = self.mix(chosen_action_value, state) + + if output.shape != (*batch_dims, 1): + raise ValueError( + f"Mixer network expected output with same shape as input minus the multi-agent dimension," + f" but got {output.shape}" + ) + + return output + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + """Forward pass for the mixer. + + Args: + chosen_action_value: Tensor of shape [*B, n_agents] + + Returns: + chosen_action_value: Tensor of shape [*B] + """ + raise NotImplementedError + + +class VDNMixer(Mixer): + """Value-Decomposition Network mixer. + + Mixes the local Q values of the agents into a global Q value by summing them together. + From the paper https://arxiv.org/abs/1706.05296 . + + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + + Args: + n_agents (int): number of agents. + device (str or torch.Device): torch device for the network. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import VDNMixer + >>> n_agents = 4 + >>> vdn = TensorDictModule( + ... module=VDNMixer( + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents","chosen_action_value")], + ... out_keys=["chosen_action_value"], + ... 
) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + n_agents: int, + device: DEVICE_TYPING, + ): + super().__init__( + needs_state=False, + state_shape=torch.Size([]), + n_agents=n_agents, + device=device, + ) + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + return chosen_action_value.sum(dim=-2) + + +class QMixer(Mixer): + """QMix mixer. + + Mixes the local Q values of the agents into a global Q value through a monotonic + hyper-network whose parameters are obtained from a global state. + From the paper https://arxiv.org/abs/1803.11485 . + + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + + Args: + state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions). + mixing_embed_dim (int): the size of the mixing embedded dimension. + n_agents (int): number of agents. + device (str or torch.Device): torch device for the network. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import QMixer + >>> n_agents = 4 + >>> qmix = TensorDictModule( + ... module=QMixer( + ... state_shape=(64, 64, 3), + ... mixing_embed_dim=32, + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents", "chosen_action_value"), "state"], + ... out_keys=["chosen_action_value"], + ... 
) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + state_shape: Union[Tuple[int, ...], torch.Size], + mixing_embed_dim: int, + n_agents: int, + device: DEVICE_TYPING, + ): + super().__init__( + needs_state=True, state_shape=state_shape, n_agents=n_agents, device=device + ) + + self.embed_dim = mixing_embed_dim + self.state_dim = int(np.prod(state_shape)) + + self.hyper_w_1 = nn.Linear( + self.state_dim, self.embed_dim * self.n_agents, device=self.device + ) + self.hyper_w_final = nn.Linear( + self.state_dim, self.embed_dim, device=self.device + ) + + # State dependent bias for hidden layer + self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim, device=self.device) + + # V(s) instead of a bias for the last layers + self.V = nn.Sequential( + nn.Linear(self.state_dim, self.embed_dim, device=self.device), + nn.ReLU(), + nn.Linear(self.embed_dim, 1, device=self.device), + ) + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + bs = chosen_action_value.shape[:-2] + state = state.view(-1, self.state_dim) + chosen_action_value = chosen_action_value.view(-1, 1, self.n_agents) + # First layer + w1 = torch.abs(self.hyper_w_1(state)) + b1 = self.hyper_b_1(state) + w1 = w1.view(-1, self.n_agents, self.embed_dim) + b1 = b1.view(-1, 1, self.embed_dim) + hidden = nn.functional.elu( + torch.bmm(chosen_action_value, w1) + b1 + ) # [-1, 1, self.embed_dim] + # Second layer + w_final = torch.abs(self.hyper_w_final(state)) + w_final = w_final.view(-1, self.embed_dim, 1) + # State-dependent bias + v = self.V(state).view(-1, 1, 1) + # Compute final output + y = torch.bmm(hidden, w_final) + v # [-1, 1, 1] + # Reshape and return + q_tot = y.view(*bs, 1) + return q_tot diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 56c185272b9..c5f34a7774d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,7 +9,7 @@ import inspect import re import warnings -from typing import Iterable, Optional, Type, Union +from typing import Iterable, List, Optional, Type, Union import torch @@ -17,6 +17,7 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.tensordict import TensorDictBase +from tensordict.utils import NestedKey from torch import nn @@ -364,12 +365,16 @@ def ensure_tensordict_compatible( module: Union[ FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module ], - 
in_keys: Optional[Iterable[str]] = None, - out_keys: Optional[Iterable[str]] = None, + in_keys: Optional[List[NestedKey]] = None, + out_keys: Optional[List[NestedKey]] = None, safe: bool = False, wrapper_type: Optional[Type] = TensorDictModule, **kwargs, ): + """Ensures module is compatible with TensorDictModule and, if not, it wraps it.""" + in_keys = unravel_key_list(in_keys) if in_keys else in_keys + out_keys = unravel_key_list(out_keys) if out_keys else out_keys + """Checks and ensures an object with forward method is TensorDict compatible.""" if is_tensordict_compatible(module): if in_keys is not None and set(in_keys) != set(module.in_keys): diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 5755fc2a27c..163365bdc75 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -10,6 +10,7 @@ from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss +from .multiagent import QMixerLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq import REDQLoss from .reinforce import ReinforceLoss diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index c8a1ccbb390..d740d45507e 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import warnings from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase @@ -40,7 +40,8 @@ class DQNLoss(LossModule): value_network (QValueActor or nn.Module): a Q value operator. Keyword Args: - loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + Defaults to "l2". delay_value (bool, optional): whether to duplicate the value network into a new target value network to create a double DQN. Default is ``False``. @@ -51,7 +52,7 @@ class DQNLoss(LossModule): :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). If not provided, an attempt to retrieve it from the value network will be made. - priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. @@ -123,10 +124,10 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"advantage"``. value_target (NestedKey): The input tensordict key where the target state value is expected. Will be used for the underlying value estimator Defaults to ``"value_target"``. - value (NestedKey): The input tensordict key where the state value is expected. - Will be used for the underlying value estimator. Defaults to ``"state_value"``. - state_action_value (NestedKey): The input tensordict key where the state action value is expected. - Defaults to ``"state_action_value"``. + value (NestedKey): The input tensordict key where the chosen action value is expected. + Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``. 
+ action_value (NestedKey): The input tensordict key where the action value is expected. + Defaults to ``"action_value"``. action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. priority (NestedKey): The input tensordict key where the target priority is written to. @@ -155,13 +156,12 @@ def __init__( self, value_network: Union[QValueActor, nn.Module], *, - loss_function: str = "l2", + loss_function: Optional[str] = "l2", delay_value: bool = False, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, ) -> None: - super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -231,10 +231,6 @@ def in_keys(self): self._set_in_keys() return self._in_keys - @in_keys.setter - def in_keys(self, values): - self._in_keys = values - def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator @@ -282,7 +278,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a tensor containing the DQN loss. """ - device = self.device if self.device is not None else tensordict.device + if self.device is not None: + warnings.warn( + "The use of a device for the objective function will soon be deprecated", + category=DeprecationWarning, + ) + device = self.device + else: + device = tensordict.device tddevice = tensordict.to(device) td_copy = tddevice.clone(False) @@ -304,13 +307,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: pred_val_index = (pred_val * action).sum(-1) target_value = self.value_estimator.value_estimate( - tddevice.clone(False), target_params=self.target_value_network_params + td_copy, target_params=self.target_value_network_params ).squeeze(-1) - priority_tensor = (pred_val_index - target_value).pow(2) - priority_tensor = priority_tensor.detach().unsqueeze(-1) - if tddevice.device is not None: - priority_tensor = priority_tensor.to(tddevice.device) + with torch.no_grad(): + priority_tensor = (pred_val_index - target_value).pow(2) + priority_tensor = priority_tensor.unsqueeze(-1) + if tensordict.device is not None: + priority_tensor = priority_tensor.to(tensordict.device) tensordict.set( self.tensor_keys.priority, diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py new file mode 100644 index 00000000000..7340cffd841 --- /dev/null +++ b/torchrl/objectives/multiagent/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .qmixer import QMixerLoss diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py new file mode 100644 index 00000000000..e9eca7ce293 --- /dev/null +++ b/torchrl/objectives/multiagent/qmixer.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
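+# QMixerLoss: local per-agent Q-values are mixed into a single global Q-value
+# (e.g. by a QMixer or VDNMixer module) on which DQN-style TD updates are performed.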
+
+from __future__ import annotations
+
+import warnings
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
+from tensordict.utils import NestedKey
+from torch import nn
+
+from torchrl.data.tensor_specs import TensorSpec
+
+from torchrl.modules import SafeSequential
+from torchrl.modules.tensordict_module.actors import QValueActor
+from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
+
+from torchrl.modules.utils.utils import _find_action_space
+
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+    _cache_values,
+    _GAMMA_LMBDA_DEPREC_WARNING,
+    default_value_kwargs,
+    distance_loss,
+    ValueEstimators,
+)
+from torchrl.objectives.value import TDLambdaEstimator
+from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator
+
+
+class QMixerLoss(LossModule):
+    """The QMixer loss class.
+
+    Mixes local agent q values into a global q value according to a mixing network and then
+    uses DQN updates on the global value.
+    This loss is for multi-agent applications.
+    Therefore, it expects the 'local_value', 'action_value' and 'action' keys
+    to have an agent dimension (this is visible in the default AcceptedKeys).
+    This dimension will be mixed by the mixer, which will compute a 'global_value' key, used for a DQN objective.
+    The premade mixers of type :class:`torchrl.modules.models.multiagent.Mixer` will expect the multi-agent
+    dimension to be the penultimate one.
+
+    Args:
+        local_value_network (QValueActor or nn.Module): a local Q value operator.
+        mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values
+            and an optional state to the global Q value. It is suggested to provide a TensorDictModule
+            wrapping a mixer from :class:`torchrl.modules.models.multiagent.Mixer`.
+
+    Keyword Args:
+        loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+            Defaults to "l2".
+        delay_value (bool, optional): whether to duplicate the value network
+            into a new target value network to
+            create a double DQN. Default is ``False``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            If not provided, an attempt to retrieve it from the value network
+            will be made.
+        priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            The key at which priority is assumed to be stored within TensorDicts added
+            to this ReplayBuffer. This is to be used when the sampler is of type
+            :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from tensordict import TensorDict
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torchrl.modules import QValueModule, SafeSequential
+        >>> from torchrl.modules.models.multiagent import QMixer
+        >>> from torchrl.objectives.multiagent import QMixerLoss
+        >>> n_agents = 4
+        >>> module = TensorDictModule(
+        ...     nn.Linear(10,3), in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
+        ... )
+        >>> value_module = QValueModule(
+        ...     action_value_key=("agents", "action_value"),
+        ...     out_keys=[
+        ...         ("agents", "action"),
+        ...         ("agents", "action_value"),
+        ...         ("agents", "chosen_action_value"),
+        ...     ],
+        ...     action_space="categorical",
+        ... )
+        >>> qnet = SafeSequential(module, value_module)
+        >>> qmixer = TensorDictModule(
+        ...     module=QMixer(
+        ...         state_shape=(64, 64, 3),
+        ...         mixing_embed_dim=32,
+        ...         n_agents=n_agents,
+        ...         device="cpu",
+        ...     ),
+        ...     in_keys=[("agents", "chosen_action_value"), "state"],
+        ...     out_keys=["chosen_action_value"],
+        ... )
+        >>> loss = QMixerLoss(qnet, qmixer, action_space="categorical")
+        >>> td = TensorDict(
+        ...     {
+        ...         "agents": TensorDict(
+        ...             {"observation": torch.zeros(32, n_agents, 10)}, [32, n_agents]
+        ...         ),
+        ...         "state": torch.zeros(32, 64, 64, 3),
+        ...         "next": TensorDict(
+        ...             {
+        ...                 "agents": TensorDict(
+        ...                     {"observation": torch.zeros(32, n_agents, 10)}, [32, n_agents]
+        ...                 ),
+        ...                 "state": torch.zeros(32, 64, 64, 3),
+        ...                 "reward": torch.zeros(32, 1),
+        ...                 "done": torch.zeros(32, 1, dtype=torch.bool),
+        ...             },
+        ...             [32],
+        ...         ),
+        ...     },
+        ...     [32],
+        ... )
+        >>> loss(qnet(td))
+        TensorDict(
+            fields={
+                loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is expected.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            local_value (NestedKey): The input tensordict key where the local chosen action value is expected.
+                Will be used for the underlying value estimator. Defaults to ``("agents", "chosen_action_value")``.
+            global_value (NestedKey): The input tensordict key where the global chosen action value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``("agents", "action")``.
+            action_value (NestedKey): The input tensordict key where the action value is expected.
+                Defaults to ``("agents", "action_value")``.
+            priority (NestedKey): The input tensordict key where the target priority is written to.
+                Defaults to ``"td_error"``.
+            reward (NestedKey): The input tensordict key where the reward is expected.
+                Will be used for the underlying value estimator. Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Will be used for the underlying value estimator.
+                Defaults to ``"done"``.
+ """ + + advantage: NestedKey = "advantage" + value_target: NestedKey = "value_target" + local_value: NestedKey = ("agents", "chosen_action_value") + global_value: NestedKey = "chosen_action_value" + action_value: NestedKey = ("agents", "action_value") + action: NestedKey = ("agents", "action") + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = ["loss"] + + def __init__( + self, + local_value_network: Union[QValueActor, nn.Module], + mixer_network: Union[TensorDictModule, nn.Module], + *, + loss_function: Optional[str] = "l2", + delay_value: bool = False, + gamma: float = None, + action_space: Union[str, TensorSpec] = None, + priority_key: str = None, + ) -> None: + super().__init__() + self._in_keys = None + self._set_deprecated_ctor_keys(priority=priority_key) + self.delay_value = delay_value + local_value_network = ensure_tensordict_compatible( + module=local_value_network, + wrapper_type=QValueActor, + action_space=action_space, + ) + if not isinstance(mixer_network, TensorDictModule): + # If it is not a TensorDictModule we make it one with default keys + mixer_network = ensure_tensordict_compatible( + module=mixer_network, + in_keys=[self.tensor_keys.local_value], + out_keys=[self.tensor_keys.global_value], + ) + + global_value_network = SafeSequential(local_value_network, mixer_network) + params = make_functional(global_value_network) + self.global_value_network = deepcopy(global_value_network) + repopulate_module(local_value_network, params["module", "0"]) + repopulate_module(mixer_network, params["module", "1"]) + + self.convert_to_functional( + local_value_network, + "local_value_network", + create_target_params=self.delay_value, + ) + self.convert_to_functional( + mixer_network, + "mixer_network", + create_target_params=self.delay_value, + ) + self.global_value_network.module[0] = self.local_value_network + self.global_value_network.module[1] = self.mixer_network + + self.global_value_network_in_keys = global_value_network.in_keys + + self.loss_function = loss_function + if action_space is None: + # infer from value net + try: + action_space = local_value_network.spec + except AttributeError: + # let's try with action_space then + try: + action_space = local_value_network.action_space + except AttributeError: + raise ValueError(self.ACTION_SPEC_ERROR) + if action_space is None: + warnings.warn( + "action_space was not specified. QMixerLoss will default to 'one-hot'." + "This behaviour will be deprecated soon and a space will have to be passed." + "Check the QMixerLoss documentation to see how to pass the action space. 
" + ) + action_space = "one-hot" + + self.action_space = _find_action_space(action_space) + + if gamma is not None: + warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) + self.gamma = gamma + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + advantage=self.tensor_keys.advantage, + value_target=self.tensor_keys.value_target, + value=self.tensor_keys.global_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + ) + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + *self.global_value_network.in_keys, + *[("next", key) for key in self.global_value_network.in_keys], + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + if value_type is ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + **hp, value_network=self.global_value_network + ) + elif value_type is ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + **hp, value_network=self.global_value_network + ) + elif value_type is ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type is ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + **hp, value_network=self.global_value_network + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "advantage": self.tensor_keys.advantage, + "value_target": self.tensor_keys.value_target, + "value": self.tensor_keys.global_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + } + self._value_estimator.set_keys(**tensor_keys) + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDict: + td_copy = tensordict.clone(False) + self.local_value_network( + td_copy, + params=self.local_value_network_params, + ) + + action = tensordict.get(self.tensor_keys.action) + pred_val = td_copy.get( + self.tensor_keys.action_value + ) # [*B, n_agents, n_actions] + + if self.action_space == "categorical": + if action.shape != pred_val.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + pred_val_index = torch.gather(pred_val, -1, index=action) + else: + action = action.to(torch.float) + pred_val_index = (pred_val * action).sum(-1, keepdim=True) + + td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1] + self.mixer_network(td_copy, params=self.mixer_network_params) + pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1) + # [*B] this is global and shared among the agents as will be the target + + target_value = self.value_estimator.value_estimate( + td_copy, + target_params=self._cached_target_params, + ).squeeze(-1) + + with torch.no_grad(): + priority_tensor = (pred_val_index - target_value).pow(2) + priority_tensor = priority_tensor.unsqueeze(-1) + if tensordict.device is not None: + priority_tensor = priority_tensor.to(tensordict.device) + + tensordict.set( + self.tensor_keys.priority, + 
priority_tensor, + inplace=True, + ) + loss = distance_loss(pred_val_index, target_value, self.loss_function) + return TensorDict({"loss": loss.mean()}, []) + + @property + @_cache_values + def _cached_target_params(self): + target_params = TensorDict( + { + "module": { + "0": self.target_local_value_network_params, + "1": self.target_mixer_network_params, + } + }, + batch_size=self.target_local_value_network_params.batch_size, + device=self.target_local_value_network_params.device, + ) + return target_params diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 30342910d4e..31c8c291c5b 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -389,7 +389,7 @@ def is_stateless(self): return self.value_network._is_stateless def _next_value(self, tensordict, target_params, kwargs): - step_td = step_mdp(tensordict) + step_td = step_mdp(tensordict, keep_other=False) if self.value_network is not None: if target_params is not None: kwargs["params"] = target_params
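A minimal usage sketch for the new ``QMixerLoss`` (not part of the patch; ``qnet`` and ``qmixer`` are assumed to be built as in the docstring example above, ``replay_buffer`` and ``num_updates`` are placeholder names, and ``SoftUpdate`` is just one possible target updater when ``delay_value=True``):

    import torch
    from torchrl.objectives import SoftUpdate
    from torchrl.objectives.multiagent import QMixerLoss

    # Build the loss with a target network on the mixed (global) value.
    loss_module = QMixerLoss(qnet, qmixer, action_space="categorical", delay_value=True)
    loss_module.make_value_estimator(gamma=0.99)  # TD0 is the default estimator
    target_updater = SoftUpdate(loss_module, eps=0.99)
    optimizer = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

    for _ in range(num_updates):
        # The sampled batch is assumed to carry ("agents", "observation"),
        # ("agents", "action"), "state" and the "next" entries from the collector.
        batch = replay_buffer.sample()
        loss_vals = loss_module(batch)
        loss_vals["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()
        target_updater.step()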