From 49d7f74a852b117afa527e4b251807c32f72f225 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 4 Sep 2024 07:13:05 -0700 Subject: [PATCH] [BugFix] Allow for composite action distributions in PPO/A2C losses (#2391) --- test/test_cost.py | 479 +++++++++++++++++++++++++++++++------- torchrl/objectives/a2c.py | 33 ++- torchrl/objectives/ppo.py | 49 +++- 3 files changed, 460 insertions(+), 101 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 2af5a88f9fa..ab95c55ef83 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -13,8 +13,8 @@ from packaging import version as pack_version from tensordict._C import unravel_keys - from tensordict.nn import ( + CompositeDistribution, InteractionType, ProbabilisticTensorDictModule, ProbabilisticTensorDictModule as ProbMod, @@ -25,7 +25,6 @@ TensorDictSequential as Seq, ) from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type - from torchrl.modules.models import QMixer _has_functorch = True @@ -7544,21 +7543,45 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": (action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, return_log_prob=True, log_prob_key=sample_log_prob_key, @@ -7582,22 +7605,51 @@ def _create_mock_value( ) return value.to(device) - def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + def _create_mock_actor_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", + ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) net = nn.Sequential( base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() ) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] + net, 
in_keys=["observation"], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7609,22 +7661,49 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu return actor.to(device), value.to(device) def _create_mock_actor_value_shared( - self, batch=2, obs_dim=3, action_dim=4, device="cpu" + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) common = TensorDictModule( base_layer, in_keys=["observation"], out_keys=["hidden"] ) net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor()) - module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"]) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + module = TensorDictModule(net, in_keys=["hidden"], out_keys=module_out_keys) actor_head = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7654,6 +7733,7 @@ def _create_mock_data_ppo( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -7679,13 +7759,17 @@ def _create_mock_data_ppo( terminated_key: terminated, reward_key: reward, }, - action_key: action, + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]) / 10, - loc_key: loc, - scale_key: scale, }, device=device, ) + if composite_action_dist: + td[("params", "action1", loc_key)] = loc + td[("params", "action1", scale_key)] = scale + else: + td[loc_key] = loc + td[scale_key] = scale return td def _create_seq_mock_data_ppo( @@ -7698,6 +7782,7 @@ def _create_seq_mock_data_ppo( device="cpu", sample_log_prob_key="sample_log_prob", action_key="action", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -7713,8 +7798,11 @@ def _create_seq_mock_data_ppo( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -7726,16 +7814,21 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, 
"collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale + return td @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @@ -7744,6 +7837,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo( self, loss_class, @@ -7752,11 +7846,16 @@ def test_ppo( advantage, td_est, functional, + composite_action_dist, ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -7796,7 +7895,10 @@ def test_ppo( loss = loss_fn(td) if isinstance(loss_fn, KLPENPPOLoss): - kl = loss.pop("kl") + if composite_action_dist: + kl = loss.pop("kl_approx") + else: + kl = loss.pop("kl") assert (kl != 0).any() loss_critic = loss["loss_critic"] @@ -7833,10 +7935,15 @@ def test_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True,)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_state_dict(self, loss_class, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_state_dict( + self, loss_class, device, gradient_mode, composite_action_dist + ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) loss_fn = loss_class(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -7846,11 +7953,16 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_shared(self, loss_class, device, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor, value = self._create_mock_actor_value(device=device) + actor, value = self._create_mock_actor_value( + device=device, composite_action_dist=composite_action_dist + ) if advantage == "gae": advantage = GAE( gamma=0.9, @@ -7932,18 +8044,24 @@ def test_ppo_shared(self, loss_class, device, advantage): 
) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo_shared_seq( self, loss_class, device, advantage, separate_losses, + composite_action_dist, ): """Tests PPO with shared module with and without passing twice across the common module.""" torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - model, actor, value = self._create_mock_actor_value_shared(device=device) + model, actor, value = self._create_mock_actor_value_shared( + device=device, composite_action_dist=composite_action_dist + ) value2 = value[-1] # prune the common module if advantage == "gae": advantage = GAE( @@ -8001,8 +8119,20 @@ def test_ppo_shared_seq( grad2 = TensorDict(dict(model.named_parameters()), []).apply( lambda x: x.grad.clone() ) - assert_allclose_td(loss, loss2) - assert_allclose_td(grad, grad2) + if composite_action_dist and loss_class is KLPENPPOLoss: + # The KL computation for a composite dist is based on randomly + # sampled data, so it will not be the same across the two passes. + # Similarly, the objective loss depends on the KL, so it will + # not be the same either. + # Finally, gradients will be different too. + loss.pop("kl", None) + loss2.pop("kl", None) + loss.pop("loss_objective", None) + loss2.pop("loss_objective", None) + assert_allclose_td(loss, loss2) + else: + assert_allclose_td(loss, loss2) + assert_allclose_td(grad, grad2) model.zero_grad() @pytest.mark.skipif( @@ -8012,11 +8142,18 @@ def test_ppo_shared_seq( @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_diff( + self, loss_class, device, gradient_mode, advantage, composite_action_dist + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8105,8 +8242,9 @@ def zero_param(p): ValueEstimators.TDLambda, ], ) - def test_ppo_tensordict_keys(self, loss_class, td_est): - actor = self._create_mock_actor() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() loss_fn = loss_class(actor, value, loss_critic_type="l2") @@ -8145,7 +8283,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_tensordict_keys_run( + self, loss_class, advantage, td_est, 
composite_action_dist + ): """Test PPO loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8160,9 +8301,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): td = self._create_seq_mock_data_ppo( sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - sample_log_prob_key=tensor_keys["sample_log_prob"] + sample_log_prob_key=tensor_keys["sample_log_prob"], + composite_action_dist=composite_action_dist, + action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8253,6 +8397,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_ppo_notensordict( self, loss_class, @@ -8262,6 +8412,7 @@ def test_ppo_notensordict( reward_key, done_key, terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -8271,10 +8422,14 @@ def test_ppo_notensordict( reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + observation_key=observation_key, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(observation_key=observation_key) @@ -8297,7 +8452,9 @@ def test_ppo_notensordict( f"next_{observation_key}": td.get(("next", observation_key)), } if loss_class is KLPENPPOLoss: - kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")}) + loc_key = "params" if composite_action_dist else "loc" + scale_key = "params" if composite_action_dist else "scale" + kwargs.update({loc_key: td.get(loc_key), scale_key: td.get(scale_key)}) td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -8310,6 +8467,7 @@ def test_ppo_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) if beta is not None: + loss.beta = beta.clone() loss_val_td = loss(td) @@ -8337,15 +8495,20 @@ def test_ppo_notensordict( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_ppo_reduction(self, reduction, loss_class): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8373,10 +8536,17 @@ def test_ppo_reduction(self, reduction, loss_class): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) 
@pytest.mark.parametrize("clip_value", [True, False, None, 0.5, torch.tensor(0.5)]) - def test_ppo_value_clipping(self, clip_value, loss_class, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_value_clipping( + self, clip_value, loss_class, device, composite_action_dist + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8435,22 +8605,46 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": (action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - in_keys=["loc", "scale"], + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, - distribution_class=TanhNormal, + distribution_class=distribution_class, return_log_prob=True, log_prob_key=sample_log_prob_key, ) @@ -8474,7 +8668,15 @@ def _create_mock_value( return value.to(device) def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 + self, + n_obs=3, + n_act=4, + ncells=4, + batch=2, + n_hidden=2, + T=10, + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): common_net = MLP( num_cells=ncells, @@ -8495,10 +8697,11 @@ def _create_mock_common_layer_setup( out_features=1, ) batch = [batch, T] + action = torch.randn(*batch, n_act) td = TensorDict( { "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), + "action": {"action1": action} if composite_action_dist else action, "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), @@ -8513,14 +8716,35 @@ def _create_mock_common_layer_setup( names=[None, "time"], ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) + + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + actor = ProbSeq( common, Mod(actor_net, 
in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys), ProbMod( - in_keys=["loc", "scale"], + in_keys=actor_in_keys, out_keys=["action"], - distribution_class=TanhNormal, + distribution_class=distribution_class, ), ) critic = Seq( @@ -8544,6 +8768,7 @@ def _create_seq_mock_data_a2c( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -8559,8 +8784,11 @@ def _create_seq_mock_data_a2c( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -8572,17 +8800,21 @@ def _create_seq_mock_data_a2c( reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale return td @pytest.mark.parametrize("gradient_mode", (True, False)) @@ -8590,11 +8822,24 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", (True, False)) - def test_a2c(self, device, gradient_mode, advantage, td_est, functional): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c( + self, + device, + gradient_mode, + advantage, + td_est, + functional, + composite_action_dist, + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8627,14 +8872,24 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): functional=functional, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + 
td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) if advantage is not None: @@ -8675,9 +8930,12 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_state_dict(self, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_state_dict(self, device, gradient_mode, composite_action_dist): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -8685,23 +8943,36 @@ def test_a2c_state_dict(self, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("separate_losses", [False, True]) - def test_a2c_separate_losses(self, separate_losses): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_separate_losses(self, separate_losses, composite_action_dist): torch.manual_seed(self.seed) - actor, critic, common, td = self._create_mock_common_layer_setup() + actor, critic, common, td = self._create_mock_common_layer_setup( + composite_action_dist=composite_action_dist + ) loss_fn = A2CLoss( actor_network=actor, critic_network=critic, separate_losses=separate_losses, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) loss = loss_fn(td) @@ -8745,13 +9016,18 @@ def test_a2c_separate_losses(self, separate_losses): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_diff(self, device, gradient_mode, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): raise pytest.skip("make_functional_with_buffers needs to be changed") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8821,8 +9097,9 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TDLambda, ], ) - def test_a2c_tensordict_keys(self, td_est): - actor = self._create_mock_actor() + 
@pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys(self, td_est, composite_action_dist): + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() loss_fn = A2CLoss(actor, value, loss_critic_type="l2") @@ -8867,7 +9144,10 @@ def test_a2c_tensordict_keys(self, td_est): ) @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device, advantage, td_est): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys_run( + self, device, advantage, td_est, composite_action_dist + ): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8887,10 +9167,14 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): done_key=done_key, terminated_key=terminated_key, sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - device=device, sample_log_prob_key=sample_log_prob_key + device=device, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(device=device, out_keys=[value_key]) if advantage == "gae": @@ -8969,12 +9253,26 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_a2c_notensordict( - self, action_key, observation_key, reward_key, done_key, terminated_key + self, + action_key, + observation_key, + reward_key, + done_key, + terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor( + observation_key=observation_key, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(observation_key=observation_key) td = self._create_seq_mock_data_a2c( action_key=action_key, @@ -8982,6 +9280,7 @@ def test_a2c_notensordict( reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) loss = A2CLoss(actor, value) @@ -9026,15 +9325,20 @@ def test_a2c_notensordict( assert loss_critic == loss_val_td["loss_critic"] @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_a2c_reduction(self, reduction): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_reduction(self, reduction, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -9061,10 +9365,15 @@ def test_a2c_reduction(self, reduction): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)]) - def 
test_a2c_value_clipping(self, clip_value, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_value_clipping(self, clip_value, device, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index d3b2b4d2ac2..ff9b5f3883e 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -10,7 +10,12 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -190,6 +195,13 @@ class A2CLoss(LossModule): ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() + + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. + If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. """ @dataclass @@ -383,7 +395,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_probs( @@ -391,10 +406,6 @@ def _log_probs( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} require grad." - ) tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False ).clone() @@ -402,7 +413,15 @@ def _log_probs( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict_clone) - log_prob = dist.log_prob(action) + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." 
+ ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + tensordict = dist.log_prob(tensordict) + log_prob = tensordict.get(self.tensor_keys.sample_log_prob) log_prob = log_prob.unsqueeze(-1) return log_prob, dist diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index b10ed5df98a..b4779a90663 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -12,7 +12,12 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import ( dispatch, ProbabilisticTensorDictModule, @@ -238,6 +243,12 @@ class PPOLoss(LossModule): ... next_observation=torch.randn(*batch, n_obs)) >>> loss_objective.backward() + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. + If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. """ @dataclass @@ -449,7 +460,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_weight( @@ -457,20 +471,27 @@ def _log_weight( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} requires grad." - ) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) if prev_log_prob.requires_grad: - raise RuntimeError("tensordict prev_log_prob requires grad.") + raise RuntimeError( + f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." + ) + + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." 
+ ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + tensordict = dist.log_prob(tensordict) + log_prob = tensordict.get(self.tensor_keys.sample_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -1107,7 +1128,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0) + previous_log_prob = previous_dist.log_prob(x) + current_log_prob = current_dist.log_prob(x) + if is_tensor_collection(x): + previous_log_prob = previous_log_prob.get( + self.tensor_keys.sample_log_prob + ) + current_log_prob = current_log_prob.get( + self.tensor_keys.sample_log_prob + ) + + kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl if kl.mean() > self.dtarg * 1.5:
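Usage sketch (editorial addition, not part of the patch): the snippet below shows how the composite-action support exercised above fits together end to end, mirroring the composite branch of _create_mock_actor and _create_seq_mock_data_ppo. It is a minimal, untested sketch: the key names ("action1", "params", "sample_log_prob"), the tensor shapes, and the plain TensorDictModule critic are illustrative assumptions, and it presumes tensordict/torchrl versions that ship CompositeDistribution (the versions this patch targets).

# Editor's sketch, not part of the patch: a CompositeDistribution actor used with PPOLoss
# on tensordict data carrying nested ("action", "action1") / ("params", "action1", ...) entries.
import functools

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.data import Bounded, Composite
from torchrl.modules import ProbabilisticActor, TanhNormal
from torchrl.objectives import PPOLoss

batch, T, obs_dim, action_dim = 2, 10, 3, 4

# Policy head writes nested distribution parameters under ("params", "action1", ...).
policy_net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
policy_module = TensorDictModule(
    policy_net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)
actor = ProbabilisticActor(
    module=policy_module,
    in_keys=["params"],
    out_keys=["action"],
    spec=Composite(
        {"action": {"action1": Bounded(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,))}}
    ),
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={"action1": TanhNormal},
        name_map={"action1": ("action", "action1")},
        log_prob_key="sample_log_prob",
    ),
    return_log_prob=True,
    log_prob_key="sample_log_prob",
)
# Stand-in critic writing "state_value" (the tests above use a ValueOperator instead).
critic = TensorDictModule(nn.Linear(obs_dim, 1), in_keys=["observation"], out_keys=["state_value"])

# Fake rollout with the nested action/params layout exercised by the tests above.
td = TensorDict(
    {
        "observation": torch.randn(batch, T, obs_dim),
        "action": {"action1": 0.9 * torch.randn(batch, T, action_dim).clamp(-1, 1)},
        "params": {
            "action1": {
                "loc": torch.randn(batch, T, action_dim) / 10,
                "scale": torch.rand(batch, T, action_dim) / 10,
            }
        },
        "sample_log_prob": torch.randn(batch, T) / 10,
        "next": {
            "observation": torch.randn(batch, T, obs_dim),
            "reward": torch.randn(batch, T, 1),
            "done": torch.zeros(batch, T, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, T, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch, T],
    names=[None, "time"],
)

# With no "advantage" entry, PPOLoss falls back to its default GAE estimator, then
# computes log-weights via dist.log_prob(tensordict), as in the patched _log_weight.
loss_fn = PPOLoss(actor, critic)
loss_vals = loss_fn(td)
print(loss_vals["loss_objective"], loss_vals["loss_critic"])

As the docstring notes added to a2c.py and ppo.py state, this path only works through the tensordict API; the keyword-argument (dispatch) entry point is exercised in the tests above only with composite_action_dist=False.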