[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged · 19 commits · Sep 4, 2024
a2c tests
albertbou92 committed Aug 14, 2024
commit 8bc8e03c911aa6f9ec6a5c9f89fa0e3307c1bf4f
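
For context: a composite action distribution bundles several sub-distributions (for example one per action component or per agent) into a single distribution whose samples and log-probabilities are TensorDicts rather than plain Tensors. Below is a minimal sketch of such a distribution, assuming the tensordict CompositeDistribution API; the component names "agent0"/"agent1" are purely illustrative.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution
    from torch import distributions as d

    # Distribution parameters for two independent action components,
    # nested under illustrative component names.
    params = TensorDict(
        {
            "agent0": {"loc": torch.zeros(3, 4), "scale": torch.ones(3, 4)},
            "agent1": {"logits": torch.randn(3, 6)},
        },
        batch_size=[3],
    )
    dist = CompositeDistribution(
        params,
        distribution_map={"agent0": d.Normal, "agent1": d.Categorical},
    )

    action = dist.sample()   # a TensorDict with one entry per component
    assert action["agent0"].shape == torch.Size([3, 4])
    assert action["agent1"].shape == torch.Size([3])

Because such actions and their log-probabilities are TensorDicts, the PPO/A2C losses changed below can no longer assume that dist.log_prob returns a single Tensor.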
20 changes: 7 additions & 13 deletions torchrl/objectives/a2c.py
@@ -11,12 +11,7 @@

 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import (
-    CompositeDistribution,
-    dispatch,
-    ProbabilisticTensorDictSequential,
-    TensorDictModule,
-)
+from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import distributions as d

@@ -388,10 +383,9 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             entropy = dist.entropy()
         except NotImplementedError:
             x = dist.rsample((self.samples_mc_entropy,))
-            if isinstance(dist, CompositeDistribution):
-                log_prob = dist.log_prob(x).get(self.tensor_keys.sample_log_prob)
-            else:
-                log_prob = dist.log_prob(x)
+            log_prob = dist.log_prob(x)
Contributor:
Previously, was this a bug or did we sum the log-probs automatically?

@albertbou92 (Contributor Author) commented on Aug 26, 2024:
This is simply because the log_prob() method of a composite distribution returns a TensorDict instead of a Tensor, so we compute the entropy in two steps.

This is the old version:

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            entropy = -dist.log_prob(x).mean(0)
        return entropy.unsqueeze(-1)

This is the new version. It simply retrieves the log-prob tensor from the returned TensorDict before computing the entropy.

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            log_prob = dist.log_prob(x)
            if is_tensor_collection(log_prob):
                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
            entropy = -log_prob.mean(0)
        return entropy.unsqueeze(-1)
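
For reference, a minimal standalone sketch of this pattern (component names are illustrative; "sample_log_prob" is the default key that self.tensor_keys.sample_log_prob points to, and sample is used instead of rsample because Categorical has no rsample):

    import torch
    from tensordict import TensorDict, is_tensor_collection
    from tensordict.nn import CompositeDistribution
    from torch import distributions as d

    # Composite distribution over two illustrative action components.
    params = TensorDict(
        {
            "move": {"loc": torch.zeros(3, 4), "scale": torch.ones(3, 4)},
            "pick": {"logits": torch.randn(3, 6)},
        },
        batch_size=[3],
    )
    dist = CompositeDistribution(
        params,
        distribution_map={"move": d.Normal, "pick": d.Categorical},
    )

    samples_mc_entropy = 16
    x = dist.sample((samples_mc_entropy,))   # TensorDict of Monte Carlo samples
    log_prob = dist.log_prob(x)              # TensorDict, not a Tensor
    if is_tensor_collection(log_prob):
        # aggregate log-prob, summed over the components
        log_prob = log_prob.get("sample_log_prob")
    entropy = -log_prob.mean(0)              # Monte Carlo entropy estimate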

+            if isinstance(log_prob, TensorDict):
+                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
             entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)

@@ -411,11 +405,11 @@ def _log_probs(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict_clone)
-            if isinstance(dist, CompositeDistribution):
+            if isinstance(action, torch.Tensor):
+                log_prob = dist.log_prob(action)
+            else:
                 tensordict = dist.log_prob(tensordict)
                 log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
-            else:
-                log_prob = dist.log_prob(action)
         log_prob = log_prob.unsqueeze(-1)
         return log_prob, dist

35 changes: 13 additions & 22 deletions torchrl/objectives/ppo.py
@@ -14,7 +14,6 @@

 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 from tensordict.nn import (
-    CompositeDistribution,
     dispatch,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
@@ -450,10 +449,9 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             entropy = dist.entropy()
         except NotImplementedError:
             x = dist.rsample((self.samples_mc_entropy,))
-            if isinstance(dist, CompositeDistribution):
-                log_prob = dist.log_prob(x).get(self.tensor_keys.sample_log_prob)
-            else:
-                log_prob = dist.log_prob(x)
+            log_prob = dist.log_prob(x)
+            if isinstance(log_prob, TensorDict):
Contributor:
ditto
+                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
             entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)

@@ -471,24 +469,17 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
+            # dist = TransformedDistribution(dist, ExpTransform())

         prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
         if prev_log_prob.requires_grad:
             raise RuntimeError("tensordict prev_log_prob requires grad.")

-        if isinstance(dist, CompositeDistribution):
-            if (
-                tensordict.get(self.tensor_keys.action).batch_size
-                != tensordict.batch_size
-            ):
-                # This condition can be True in notensordict usage
-                tensordict.get(
-                    self.tensor_keys.action
-                ).batch_size = tensordict.batch_size
+        if isinstance(action, torch.Tensor):
+            log_prob = dist.log_prob(action)
+        else:
             tensordict = dist.log_prob(tensordict)
             log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
-        else:
-            log_prob = dist.log_prob(action)

         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -1125,16 +1116,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
             x = previous_dist.sample((self.samples_mc_kl,))
-            if isinstance(current_dist, CompositeDistribution):
-                previous_log_prob = previous_dist.log_prob(x).get(
+            previous_log_prob = previous_dist.log_prob(x)
+            current_log_prob = current_dist.log_prob(x)
+            if isinstance(x, TensorDict):
Contributor:
ditto
+                previous_log_prob = previous_log_prob.get(
                     self.tensor_keys.sample_log_prob
                 )
-                current_log_prob = current_dist.log_prob(x).get(
+                current_log_prob = current_log_prob.get(
                     self.tensor_keys.sample_log_prob
                 )
-            else:
-                previous_log_prob = previous_dist.log_prob(x)
-                current_log_prob = current_dist.log_prob(x)

             kl = (previous_log_prob - current_log_prob).mean(0)
             kl = kl.unsqueeze(-1)
             neg_loss = neg_loss - self.beta * kl