From 90c8e40f64bb76601d93a9416fa8723cd607ffe2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 16:24:13 +0000 Subject: [PATCH] [BugFix] Better account of composite distributions in PPO ghstack-source-id: 3d86f99bc5b20a53e4092d786e96a5f7e83405ac Pull Request resolved: https://github.com/pytorch/rl/pull/2622 --- torchrl/objectives/ppo.py | 53 +++++++++++++++++--------- torchrl/objectives/utils.py | 5 +++ torchrl/objectives/value/advantages.py | 22 ++++++++--- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 8c64c1ba539..eb9a916dfc1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -18,6 +18,7 @@ TensorDictParams, ) from tensordict.nn import ( + CompositeDistribution, dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, @@ -33,6 +34,7 @@ _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, + _sum_td_features, default_value_kwargs, distance_loss, ValueEstimators, @@ -462,9 +464,13 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: - entropy = dist.entropy() + if isinstance(dist, CompositeDistribution): + kwargs = {"aggregate_probabilities": False, "include_sum": False} + else: + kwargs = {} + entropy = dist.entropy(**kwargs) if is_tensor_collection(entropy): - entropy = entropy.get(dist.entropy_key) + entropy = _sum_td_features(entropy) except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) @@ -497,13 +503,20 @@ def _log_weight( if isinstance(action, torch.Tensor): log_prob = dist.log_prob(action) else: - maybe_log_prob = dist.log_prob(tensordict) - if not isinstance(maybe_log_prob, torch.Tensor): - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not - # be a tensor - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) + if isinstance(dist, CompositeDistribution): + is_composite = True + kwargs = { + "inplace": False, + "aggregate_probabilities": False, + "include_sum": False, + } else: - log_prob = maybe_log_prob + is_composite = False + kwargs = {} + log_prob = dist.log_prob(tensordict, **kwargs) + if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = (advantage - loc) / scale log_weight, dist, kl_approx = self._log_weight(tensordict) + if is_tensor_collection(log_weight): + log_weight = _sum_td_features(log_weight) + log_weight = log_weight.view(advantage.shape) neg_loss = log_weight.exp() * advantage td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[]) if self.entropy_bonus: @@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - previous_log_prob = previous_dist.log_prob(x) - current_log_prob = current_dist.log_prob(x) + if isinstance(previous_dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": False, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + previous_log_prob = previous_dist.log_prob(x, **kwargs) + current_log_prob = current_dist.log_prob(x, **kwargs) if is_tensor_collection(current_log_prob): - previous_log_prob = previous_log_prob.get( - self.tensor_keys.sample_log_prob - ) - current_log_prob = current_log_prob.get( - self.tensor_keys.sample_log_prob - ) - + previous_log_prob = _sum_td_features(previous_log_prob) + current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4dfed60e5a9..9c46fc98262 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -615,3 +615,8 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize raise ValueError("Cannot group optimizers of different type.") params.extend(optimizer.param_groups) return cls(params) + + +def _sum_td_features(data: TensorDictBase) -> torch.Tensor: + # Sum all features and return a tensor + return data.sum(dim="feature", reduce=True) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index fadfe932c50..bbd6a23bfdd 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -15,11 +15,14 @@ import torch from tensordict import TensorDictBase from tensordict.nn import ( + CompositeDistribution, dispatch, + ProbabilisticTensorDictModule, set_skip_existing, TensorDictModule, TensorDictModuleBase, ) +from tensordict.nn.probabilistic import interaction_type from tensordict.utils import NestedKey from torch import Tensor @@ -74,14 +77,22 @@ def new_func(self, *args, **kwargs): def _call_actor_net( - actor_net: TensorDictModuleBase, + actor_net: ProbabilisticTensorDictModule, data: TensorDictBase, params: TensorDictBase, log_prob_key: NestedKey, ): - # TODO: extend to handle time dimension (and vmap?) - log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key) - return log_pi + dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False)) + if isinstance(dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": True, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + s = actor_net._dist_sample(dist, interaction_type=interaction_type()) + return dist.log_prob(s, **kwargs) class ValueEstimatorBase(TensorDictModuleBase): @@ -1771,7 +1782,8 @@ def forward( data=tensordict, params=None, log_prob_key=self.tensor_keys.sample_log_prob, - ).view_as(value) + ) + log_pi = log_pi.view_as(value) # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done))