[BugFix] Fix KLPENPPOLoss KL computation (#1922)
vmoens authored Feb 17, 2024
1 parent 31bea14 commit e538fdc
Showing 2 changed files with 72 additions and 10 deletions.
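Context for the fix: `KLPENPPOLoss` penalizes the PPO objective with the KL divergence between the behavior (data-collecting) policy and the current policy. Before this commit, the loss rebuilt the "previous" distribution from the same tensordict the current actor had already written into, so the stored distribution parameters were overwritten and the KL term was silently zero. A minimal sketch of the intended objective follows; the function and argument names are illustrative, not TorchRL's API:

    from torch import distributions as d

    def kl_pen_ppo_loss(log_prob, sample_log_prob, advantage, prev_dist, curr_dist, beta):
        # Importance weight of the current policy against the behavior policy.
        log_weight = log_prob - sample_log_prob
        # The penalty compares the *behavior* distribution with the current one.
        # If both are built from the same (overwritten) parameters, kl == 0 and
        # the penalty silently disappears -- the bug this commit fixes.
        kl = d.kl.kl_divergence(prev_dist, curr_dist)
        return -(log_weight.exp() * advantage).mean() + beta * kl.mean()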
24 changes: 23 additions & 1 deletion test/test_cost.py
@@ -5824,6 +5824,10 @@ def _create_mock_data_ppo(
         reward = torch.randn(batch, 1, device=device)
         done = torch.zeros(batch, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device)
+        loc_key = "loc"
+        scale_key = "scale"
+        loc = torch.randn(batch, 4, device=device)
+        scale = torch.rand(batch, 4, device=device)
         td = TensorDict(
             batch_size=(batch,),
             source={
@@ -5836,6 +5840,8 @@ def _create_mock_data_ppo(
                 },
                 action_key: action,
                 sample_log_prob_key: torch.randn_like(action[..., 1]) / 10,
+                loc_key: loc,
+                scale_key: scale,
             },
             device=device,
         )
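The mock rollout now carries the policy's distribution parameters ("loc" and "scale" for a Gaussian head) alongside the action and its log-probability, because `KLPENPPOLoss` needs them to rebuild the behavior distribution at loss time. A hedged illustration of what that enables, assuming a plain Normal policy:

    # With the parameters stored in the rollout, the behavior distribution can
    # be reconstructed without re-running the (since updated) policy network:
    prev_dist = torch.distributions.Normal(td["loc"], td["scale"])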
@@ -5951,6 +5957,10 @@ def test_ppo(
         loss_fn.make_value_estimator(td_est)

         loss = loss_fn(td)
+        if isinstance(loss_fn, KLPENPPOLoss):
+            kl = loss.pop("kl")
+            assert (kl != 0).any()
+
         if reduction == "none":

             def func(x):
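This is the regression check: before the fix, the "kl" entry reported by the loss was identically zero because the "previous" and current distributions were built from the same parameters. A standalone sketch of the symptom and the post-fix expectation:

    import torch
    from torch import distributions as d

    loc, scale = torch.zeros(4), torch.ones(4)
    # Identical parameters -> KL is exactly zero: the pre-fix symptom.
    assert (d.kl.kl_divergence(d.Normal(loc, scale), d.Normal(loc, scale)) == 0).all()
    # Distinct parameters -> nonzero KL: what the test now asserts via (kl != 0).any().
    assert (d.kl.kl_divergence(d.Normal(loc, scale), d.Normal(loc + 0.1, scale)) != 0).any()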
@@ -6457,20 +6467,32 @@ def test_ppo_notensordict(
             f"next_{terminated_key}": td.get(("next", terminated_key)),
             f"next_{observation_key}": td.get(("next", observation_key)),
         }
+        if loss_class is KLPENPPOLoss:
+            kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")})
+
         td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_")

         # setting the seed for each loss so that drawing the random samples from
         # value network leads to same numbers for both runs
         torch.manual_seed(self.seed)
+        beta = getattr(loss, "beta", None)
+        if beta is not None:
+            beta = beta.clone()
         loss_val = loss(**kwargs)
         torch.manual_seed(self.seed)
+        if beta is not None:
+            loss.beta = beta.clone()
         loss_val_td = loss(td)

         for i, out_key in enumerate(loss.out_keys):
-            torch.testing.assert_close(loss_val_td.get(out_key), loss_val[i])
+            torch.testing.assert_close(
+                loss_val_td.get(out_key), loss_val[i], msg=out_key
+            )

         # test select
         torch.manual_seed(self.seed)
+        if beta is not None:
+            loss.beta = beta.clone()
         loss.select_out_keys("loss_objective", "loss_critic")
         if torch.__version__ >= "2.0.0":
             loss_obj, loss_crit = loss(**kwargs)
58 changes: 49 additions & 9 deletions torchrl/objectives/ppo.py
@@ -15,7 +15,12 @@

 import torch
 from tensordict import TensorDict, TensorDictBase
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+    dispatch,
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+    TensorDictModule,
+)
 from tensordict.utils import NestedKey
 from torch import distributions as d

@@ -962,6 +967,35 @@ def __init__(
         self.decrement = decrement
         self.samples_mc_kl = samples_mc_kl

+    def _set_in_keys(self):
+        keys = [
+            self.tensor_keys.action,
+            self.tensor_keys.sample_log_prob,
+            ("next", self.tensor_keys.reward),
+            ("next", self.tensor_keys.done),
+            ("next", self.tensor_keys.terminated),
+            *self.actor_network.in_keys,
+            *[("next", key) for key in self.actor_network.in_keys],
+            *self.critic_network.in_keys,
+        ]
+        # Get the parameter keys from the actor dist
+        actor_dist_module = None
+        for module in self.actor_network.modules():
+            # Ideally we should combine them if there is more than one
+            if isinstance(module, ProbabilisticTensorDictModule):
+                if actor_dist_module is not None:
+                    raise RuntimeError(
+                        "Actors with one and only one distribution are currently supported "
+                        f"in {type(self).__name__}. If you need to use more than one "
+                        f"distribution over the action space please submit an issue "
+                        f"on GitHub."
+                    )
+                actor_dist_module = module
+        if actor_dist_module is None:
+            raise RuntimeError("Could not find the probabilistic module in the actor.")
+        keys += list(actor_dist_module.in_keys)
+        self._in_keys = list(set(keys))
+
     @property
     def out_keys(self):
         if self._out_keys is None:
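The new `_set_in_keys` walks the actor's submodules to find its single `ProbabilisticTensorDictModule` and adds that module's `in_keys` (e.g. "loc" and "scale") to the loss's own `in_keys`. This is what lets the `@dispatch`-decorated `forward` accept the distribution parameters as plain keyword arguments, as exercised in `test_ppo_notensordict` above. A hedged sketch of the lookup on a toy actor (the module layout is illustrative):

    import torch
    from torch import nn
    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )

    # Toy single-distribution actor: a linear backbone emitting "loc",
    # feeding a Normal head with a fixed scale.
    backbone = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["loc"])
    head = ProbabilisticTensorDictModule(
        in_keys=["loc"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        distribution_kwargs={"scale": 1.0},
    )
    actor = ProbabilisticTensorDictSequential(backbone, head)

    for module in actor.modules():
        if isinstance(module, ProbabilisticTensorDictModule):
            print(module.in_keys)  # ["loc"] -- the keys _set_in_keys appends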
@@ -979,27 +1013,33 @@ def out_keys(self, values):

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
-        tensordict = tensordict.clone(False)
-        advantage = tensordict.get(self.tensor_keys.advantage, None)
+        tensordict_copy = tensordict.copy()
+        try:
+            previous_dist = self.actor_network.build_dist_from_params(tensordict)
+        except KeyError as err:
+            raise KeyError(
+                "The parameters of the distribution were not found. "
+                f"Make sure they are provided to {type(self).__name__}."
+            ) from err
+        advantage = tensordict_copy.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
-                tensordict,
+                tensordict_copy,
                 params=self._cached_critic_network_params_detached,
                 target_params=self.target_critic_network_params,
             )
-            advantage = tensordict.get(self.tensor_keys.advantage)
+            advantage = tensordict_copy.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
             loc = advantage.mean()
             scale = advantage.std().clamp_min(1e-6)
             advantage = (advantage - loc) / scale
-        log_weight, dist = self._log_weight(tensordict)
+        log_weight, dist = self._log_weight(tensordict_copy)
         neg_loss = log_weight.exp() * advantage

-        previous_dist = self.actor_network.build_dist_from_params(tensordict)
         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
-            current_dist = self.actor_network.get_dist(tensordict)
+            current_dist = self.actor_network.get_dist(tensordict_copy)
         try:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
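The core of the fix: `forward` now builds `previous_dist` from the distribution parameters stored in the incoming tensordict before anything else touches them, and runs the value estimator, `_log_weight`, and `get_dist` on a copy, since `get_dist` writes the current network's parameters into the tensordict it is given. In the old code the "previous" distribution was built after that write, so it was reconstructed from the current parameters and the KL collapsed to zero. A sketch of the failure mode, as a hypothetical minimal repro assuming a TorchRL-style probabilistic actor:

    td = rollout_data.clone()                         # holds the behavior policy's "loc"/"scale"
    current_dist = actor.get_dist(td)                 # side effect: overwrites td["loc"], td["scale"]
    previous_dist = actor.build_dist_from_params(td)  # rebuilt from the *current* params!
    kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)  # == 0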
@@ -1024,7 +1064,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             td_out.set("entropy", entropy.detach())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
-            loss_critic = self.loss_critic(tensordict)
+            loss_critic = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.apply(
             functools.partial(_reduce, reduction=self.reduction), batch_size=[]