From 0eabb789739a5e9a2a9f244076c7a9bf8bc7b48e Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Wed, 6 Nov 2024 14:05:21 +0100 Subject: [PATCH] [BugFix] Support for tensor collection in the `PPOLoss` (#2543) Co-authored-by: Pau Riba --- test/test_cost.py | 93 ++++++++++++++++++++++++++++++++++++++- torchrl/objectives/ppo.py | 2 + 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 0066c024776..0b36f5b8961 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -7650,6 +7650,7 @@ def _create_mock_actor( observation_key="observation", sample_log_prob_key="sample_log_prob", composite_action_dist=False, + aggregate_probabilities=True, ): # Actor action_spec = Bounded( @@ -7668,7 +7669,7 @@ def _create_mock_actor( "action1": (action_key, "action1"), }, log_prob_key=sample_log_prob_key, - aggregate_probabilities=True, + aggregate_probabilities=aggregate_probabilities, ) module_out_keys = [ ("params", "action1", "loc"), @@ -8038,6 +8039,96 @@ def test_ppo( assert counter == 2 actor.zero_grad() + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("gradient_mode", (True, False)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("functional", [True, False]) + def test_ppo_composite_no_aggregate( + self, loss_class, device, gradient_mode, advantage, td_est, functional + ): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_ppo(device=device, composite_action_dist=True) + + actor = self._create_mock_actor( + device=device, + composite_action_dist=True, + aggregate_probabilities=False, + ) + value = self._create_mock_value(device=device) + if advantage == "gae": + advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) + elif advantage == "td": + advantage = TD1Estimator( + gamma=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimator( + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage is None: + pass + else: + raise NotImplementedError + + loss_fn = loss_class( + actor, + value, + loss_critic_type="l2", + functional=functional, + ) + if advantage is not None: + advantage(td) + else: + if td_est is not None: + loss_fn.make_value_estimator(td_est) + + loss = loss_fn(td) + if isinstance(loss_fn, KLPENPPOLoss): + kl = loss.pop("kl_approx") + assert (kl != 0).any() + + loss_critic = loss["loss_critic"] + loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + loss_critic.backward(retain_graph=True) + # check that grads are independent and non null + named_parameters = loss_fn.named_parameters() + counter = 0 + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert ("actor" in name) or ("target_" in name) + assert ("critic" not in name) or ("target_" in name) + assert counter == 2 + + value.zero_grad() + loss_objective.backward() + counter = 0 + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert ("actor" not in name) or ("target_" in name) + assert ("critic" in name) or ("target_" in name) + assert counter == 2 + actor.zero_grad() + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True,)) @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index efc951b3999..ef78bc4bb0f 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -463,6 +463,8 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: entropy = dist.entropy() + if is_tensor_collection(entropy): + entropy = entropy.get(dist.entropy_key) except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x)