
[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged: 19 commits (Sep 4, 2024)
account for composite distribution
albertbou92 committed Aug 10, 2024
commit 627b673637bebe9d9fe9f30642979ba63bf93def
19 changes: 16 additions & 3 deletions torchrl/objectives/a2c.py
@@ -11,7 +11,12 @@

 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+    CompositeDistribution,
+    dispatch,
+    ProbabilisticTensorDictSequential,
+    TensorDictModule,
+)
 from tensordict.utils import NestedKey
 from torch import distributions as d

@@ -383,7 +388,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             entropy = dist.entropy()
         except NotImplementedError:
             x = dist.rsample((self.samples_mc_entropy,))
-            entropy = -dist.log_prob(x).mean(0)
+            if isinstance(dist, CompositeDistribution):
+                log_prob = dist.log_prob(x).get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = dist.log_prob(x)
+            entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)

     def _log_probs(
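For a CompositeDistribution, log_prob returns a TensorDict rather than a plain tensor, which is why the Monte-Carlo fallback above reads the aggregate log-probability back with .get(self.tensor_keys.sample_log_prob). A minimal sketch of that behaviour, assuming tensordict's CompositeDistribution writes an aggregate "sample_log_prob" entry; the head names and shapes below are invented:

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution
from torch import distributions as d

# Two hypothetical continuous action heads parameterized by one TensorDict.
params = TensorDict(
    {
        "move": {"loc": torch.zeros(4), "scale": torch.ones(4)},
        "grip": {"loc": torch.zeros(4), "scale": torch.ones(4)},
    },
    batch_size=[4],
)
dist = CompositeDistribution(
    params, distribution_map={"move": d.Normal, "grip": d.Normal}
)

samples_mc_entropy = 16
x = dist.rsample((samples_mc_entropy,))      # a TensorDict of per-head samples
x = dist.log_prob(x)                         # writes per-head and aggregate log-probs
log_prob = x.get("sample_log_prob")          # the tensor the loss actually uses
entropy = -log_prob.mean(0)                  # Monte-Carlo entropy estimate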
@@ -402,7 +411,11 @@ def _log_probs(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict_clone)
-            log_prob = dist.log_prob(action)
+            if isinstance(dist, CompositeDistribution):
+                tensordict = dist.log_prob(tensordict)
+                log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = dist.log_prob(action)
         log_prob = log_prob.unsqueeze(-1)
         return log_prob, dist

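The same isinstance branch appears in get_entropy_bonus and _log_probs above, and again in ppo.py below. Factored out, the pattern is just the following; this is a standalone sketch with a hypothetical helper name, not code that exists in torchrl:

from tensordict.nn import CompositeDistribution

def _maybe_composite_log_prob(dist, data, sample_log_prob_key="sample_log_prob"):
    """Return a single log-prob tensor whether or not the distribution is composite."""
    if isinstance(dist, CompositeDistribution):
        # CompositeDistribution.log_prob writes per-head and aggregate
        # log-probabilities into the TensorDict it is given and returns it.
        return dist.log_prob(data).get(sample_log_prob_key)
    # Regular torch.distributions objects return a plain tensor.
    return dist.log_prob(data)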
26 changes: 23 additions & 3 deletions torchrl/objectives/ppo.py
@@ -14,6 +14,7 @@
 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 from tensordict.nn import (
+    CompositeDistribution,
     dispatch,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
@@ -449,7 +450,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             entropy = dist.entropy()
         except NotImplementedError:
             x = dist.rsample((self.samples_mc_entropy,))
-            entropy = -dist.log_prob(x).mean(0)
+            if isinstance(dist, CompositeDistribution):
+                log_prob = dist.log_prob(x).get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = dist.log_prob(x)
+            entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)

     def _log_weight(
@@ -466,12 +471,17 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-            log_prob = dist.log_prob(action)

         prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
         if prev_log_prob.requires_grad:
             raise RuntimeError("tensordict prev_log_prob requires grad.")

+        if isinstance(dist, CompositeDistribution):
+            tensordict = dist.log_prob(tensordict)
+            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+        else:
+            log_prob = dist.log_prob(action)
+
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
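Downstream, only the aggregate log-probabilities enter the objective, so the clipped-ratio arithmetic is unchanged for composite actions. A toy numeric sketch, with made-up values, of what _log_weight feeds into the PPO loss:

import torch

# log-prob stored in the tensordict at collection time vs. the recomputed one
prev_log_prob = torch.tensor([-1.20, -0.75, -2.10])
log_prob = torch.tensor([-1.05, -0.90, -1.80])
advantage = torch.tensor([0.5, -0.3, 1.2]).unsqueeze(-1)

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)   # log of the importance ratio
ratio = log_weight.exp()
clip_epsilon = 0.2
gain = torch.min(
    ratio * advantage,
    ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantage,
)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)    # cheap first-order KL estimate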

@@ -1107,7 +1117,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
             x = previous_dist.sample((self.samples_mc_kl,))
-            kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0)
+            if isinstance(current_dist, CompositeDistribution):
+                previous_log_prob = previous_dist.log_prob(x).get(
+                    self.tensor_keys.sample_log_prob
+                )
+                current_log_prob = current_dist.log_prob(x).get(
+                    self.tensor_keys.sample_log_prob
+                )
+            else:
+                previous_log_prob = previous_dist.log_prob(x)
+                current_log_prob = current_dist.log_prob(x)
+            kl = (previous_log_prob - current_log_prob).mean(0)
         kl = kl.unsqueeze(-1)
         neg_loss = neg_loss - self.beta * kl
         if kl.mean() > self.dtarg * 1.5:
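For reference, a rough sketch of the Monte-Carlo KL fallback above applied to two composite distributions. It assumes CompositeDistribution.log_prob exposes the aggregate under "sample_log_prob"; the head name and shapes are invented:

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution
from torch import distributions as d

def make_dist(loc):
    # One Normal head over a batch of 8 environments (hypothetical layout).
    params = TensorDict(
        {"head": {"loc": loc, "scale": torch.ones_like(loc)}}, batch_size=[8]
    )
    return CompositeDistribution(params, distribution_map={"head": d.Normal})

previous_dist = make_dist(torch.zeros(8))
current_dist = make_dist(0.1 * torch.ones(8))

samples_mc_kl = 64
x = previous_dist.sample((samples_mc_kl,))                            # TensorDict of samples
previous_log_prob = previous_dist.log_prob(x).get("sample_log_prob")  # aggregate, old policy
current_log_prob = current_dist.log_prob(x).get("sample_log_prob")    # aggregate, new policy
kl = (previous_log_prob - current_log_prob).mean(0)                   # per-env MC estimate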