[Feature] Better advantage API for higher order derivatives #744

Merged
merged 5 commits into from Dec 14, 2022
Changes from 1 commit
init
vmoens committed Dec 14, 2022
commit 925b601d8cf2ade62900cf47a51a1391df7653e5
3 changes: 1 addition & 2 deletions examples/ppo/ppo.py
@@ -175,8 +175,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg.gamma,
cfg.lmbda,
value_network=critic_model,
average_rewards=True,
gradient_mode=False,
average_gae=True,
)
trainer.register_op(
"process_optim_batch",
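Taken together, this hunk swaps the old average_rewards=True / gradient_mode=False pair for the single average_gae=True flag. Reconstructed from the lines above (indentation and the surrounding assignment sit outside the hunk and are therefore assumed), the updated call reads:

    GAE(
        cfg.gamma,
        cfg.lmbda,
        value_network=critic_model,
        average_gae=True,  # replaces average_rewards=True; gradient_mode is dropped in favor of the estimators' differentiable flag (see test_cost.py below)
    )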
107 changes: 89 additions & 18 deletions test/test_cost.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import re
from copy import deepcopy

_has_functorch = True
@@ -21,7 +22,7 @@
import torch
from _utils_internal import dtype_fixture, get_available_devices # noqa
from mocking_classes import ContinuousActionConvMockEnv
from tensordict.nn import get_functional
from tensordict.nn import get_functional, TensorDictModule

# from torchrl.data.postprocs.utils import expand_as_right
from tensordict.tensordict import assert_allclose_td, TensorDict
@@ -1597,23 +1598,25 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage):
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "td":
advantage = TDEstimate(
gamma=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimate(
gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
else:
raise NotImplementedError

loss_fn = loss_class(
actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2"
)

loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2")
with pytest.raises(
KeyError, match=re.escape('key "advantage" not found in TensorDict with')
):
_ = loss_fn(td)
advantage(td)
loss = loss_fn(td)
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
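This rewritten test captures the core API change of the PR: the PPO losses no longer accept an advantage_module, so the estimator has to be run on the tensordict before the loss is evaluated, and calling the loss first now raises a KeyError for the missing "advantage" entry. A condensed sketch of the flow, reusing the names from the test above:

    advantage = GAE(gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode)
    loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2")

    # loss_fn(td) here raises KeyError: 'key "advantage" not found in TensorDict ...'
    advantage(td)        # writes "advantage" (and "value_target") into td in place
    loss = loss_fn(td)   # now succeeds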
@@ -1659,20 +1662,17 @@ def test_ppo_shared(self, loss_class, device, advantage):
gamma=0.9,
lmbda=0.9,
value_network=value,
gradient_mode=False,
)
elif advantage == "td":
advantage = TDEstimate(
gamma=0.9,
value_network=value,
gradient_mode=False,
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimate(
gamma=0.9,
lmbda=0.9,
value_network=value,
gradient_mode=False,
)
else:
raise NotImplementedError
@@ -1681,9 +1681,13 @@ def test_ppo_shared(self, loss_class, device, advantage):
value,
gamma=0.9,
loss_critic_type="l2",
advantage_module=advantage,
)

with pytest.raises(
KeyError, match=re.escape('key "advantage" not found in TensorDict with')
):
_ = loss_fn(td)
advantage(td)
loss = loss_fn(td)
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -1731,29 +1735,33 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "td":
advantage = TDEstimate(
gamma=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimate(
gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
else:
raise NotImplementedError

loss_fn = loss_class(
actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2"
)
loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2")

floss_fn, params, buffers = make_functional_with_buffers(loss_fn)
# fill params with zero
for p in params:
p.data.zero_()
# assert len(list(floss_fn.parameters())) == 0
with pytest.raises(
KeyError, match=re.escape('key "advantage" not found in TensorDict with')
):
_ = floss_fn(params, buffers, td)
advantage(td)
loss = floss_fn(params, buffers, td)

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
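Since the estimator is no longer an attribute of the loss, functionalizing the loss (as test_ppo_diff does) only captures the actor and critic parameters; the advantage is still computed eagerly on the tensordict beforehand. A sketch of that pattern, assuming functorch's make_functional_with_buffers, which the test above already relies on:

    from functorch import make_functional_with_buffers

    floss_fn, params, buffers = make_functional_with_buffers(loss_fn)
    advantage(td)                          # estimator runs outside the functional call
    loss = floss_fn(params, buffers, td)   # only actor/critic params flow through here
    loss["loss_objective"].backward()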
@@ -2948,6 +2956,69 @@ def __init__(self, actor_network, qvalue_network):
break


class TestAdv:
@pytest.mark.parametrize(
"adv,kwargs",
[[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]],
)
def test_diff_reward(
self,
adv,
kwargs,
):
value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
module = adv(
gamma=0.98,
value_network=value_net,
differentiable=True,
**kwargs,
)
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"reward": torch.randn(1, 10, 1, requires_grad=True),
"done": torch.zeros(1, 10, 1, dtype=torch.bool),
"next": {"obs": torch.randn(1, 10, 3)},
},
[1, 10],
)
td = module(td.clone(False))
# check that the advantage can't backprop to the value params
td["advantage"].sum().backward()
for p in value_net.parameters():
assert p.grad is None or (p.grad == 0).all()
# check that rewards have a grad
assert td["reward"].grad.norm() > 0

@pytest.mark.parametrize(
"adv,kwargs",
[[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]],
)
def test_non_differentiable(self, adv, kwargs):
value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
module = adv(
gamma=0.98,
value_network=value_net,
differentiable=False,
**kwargs,
)
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"reward": torch.randn(1, 10, 1, requires_grad=True),
"done": torch.zeros(1, 10, 1, dtype=torch.bool),
"next": {"obs": torch.randn(1, 10, 3)},
},
[1, 10],
)
td = module(td.clone(False))
assert td["advantage"].is_leaf
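
The two TestAdv tests pin down the contract of the new differentiable flag: with differentiable=True, gradients flow through the estimator's inputs (here the rewards) but never into the value network's parameters, and with differentiable=False the advantage is fully detached. That is what enables higher-order-derivative setups such as backpropagating an advantage-based objective into a learned reward model. A hedged sketch of that use case (reward_net below is illustrative and not part of the PR; import paths may vary between torchrl versions):

    import torch
    from torch import nn
    from tensordict.nn import TensorDictModule
    from tensordict.tensordict import TensorDict
    from torchrl.objectives.value import GAE

    value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
    reward_net = nn.Linear(3, 1)  # hypothetical learned reward model
    obs, next_obs = torch.randn(1, 10, 3), torch.randn(1, 10, 3)
    td = TensorDict(
        {
            "obs": obs,
            "reward": reward_net(obs),  # rewards carry a graph back to reward_net
            "done": torch.zeros(1, 10, 1, dtype=torch.bool),
            "next": {"obs": next_obs},
        },
        [1, 10],
    )
    GAE(gamma=0.98, lmbda=0.95, value_network=value_net, differentiable=True)(td)
    td["advantage"].sum().backward()
    # gradients reach reward_net, while value_net's parameters stay untouched
    assert all(p.grad is not None for p in reward_net.parameters())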


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
48 changes: 17 additions & 31 deletions torchrl/objectives/ppo.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Callable, Optional, Tuple
from typing import Tuple

import torch
from tensordict.tensordict import TensorDict, TensorDictBase
@@ -55,14 +55,13 @@ def __init__(
actor: SafeProbabilisticSequential,
critic: SafeModule,
advantage_key: str = "advantage",
advantage_diff_key: str = "value_error",
value_target_key: str = "value_target",
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
critic_coef: float = 1.0,
gamma: float = 0.99,
loss_critic_type: str = "smooth_l1",
advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
):
super().__init__()
self.convert_to_functional(
@@ -72,7 +71,7 @@
# params of critic must be refs to actor if they're shared
self.convert_to_functional(critic, "critic", compare_against=self.actor_params)
self.advantage_key = advantage_key
self.advantage_diff_key = advantage_diff_key
self.value_target_key = value_target_key
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef
self.register_buffer(
@@ -83,9 +82,6 @@
)
self.register_buffer("gamma", torch.tensor(gamma, device=self.device))
self.loss_critic_type = loss_critic_type
self.advantage_module = advantage_module
if self.advantage_module is not None:
self.advantage_module = advantage_module.to(self.device)

def reset(self) -> None:
pass
@@ -119,35 +115,29 @@ def _log_weight(
return log_weight, dist

def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
if self.advantage_diff_key in tensordict.keys():
advantage_diff = tensordict.get(self.advantage_diff_key)
if not advantage_diff.requires_grad:
raise RuntimeError(
"value_target retrieved from tensordict does not requires grad."
)
loss_value = distance_loss(
advantage_diff,
torch.zeros_like(advantage_diff),
loss_function=self.loss_critic_type,
)
else:
advantage = tensordict.get(self.advantage_key)
try:
target_return = tensordict.get(self.value_target_key)
tensordict_select = tensordict.select(*self.critic.in_keys)
value = self.critic(
state_value = self.critic(
tensordict_select,
params=self.critic_params,
).get("state_value")
value_target = advantage + value.detach()
loss_value = distance_loss(
value, value_target, loss_function=self.loss_critic_type
target_return,
state_value,
loss_function=self.loss_critic_type,
)
except KeyError:
raise KeyError(
f"the key {self.value_target_key} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
f"return) has been retrieved accordingly. Advantage classes such as GAE, "
f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
f"can be used for the value loss."
)
return self.critic_coef * loss_value

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.advantage_module is not None:
tensordict = self.advantage_module(
tensordict,
)
tensordict = tensordict.clone()
advantage = tensordict.get(self.advantage_key)
log_weight, dist = self._log_weight(tensordict)
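Because removed and added lines are interleaved in the hunk above, the end state of loss_critic is easy to misread. Reassembled from the added lines only (indentation reconstructed, error message abridged), the new method reduces to:

    def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
        try:
            target_return = tensordict.get(self.value_target_key)
            tensordict_select = tensordict.select(*self.critic.in_keys)
            state_value = self.critic(
                tensordict_select,
                params=self.critic_params,
            ).get("state_value")
            loss_value = distance_loss(
                target_return,
                state_value,
                loss_function=self.loss_critic_type,
            )
        except KeyError:
            raise KeyError(
                f"the key {self.value_target_key} was not found in the input tensordict. "
                "Advantage classes such as GAE, TDLambdaEstimate and TDEstimate all "
                "return a 'value_target' entry that can be used for the value loss."
            )
        return self.critic_coef * loss_value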
@@ -226,8 +216,6 @@ def _clip_bounds(self):
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.advantage_module is not None:
tensordict = self.advantage_module(tensordict)
tensordict = tensordict.clone()
advantage = tensordict.get(self.advantage_key)
log_weight, dist = self._log_weight(tensordict)
@@ -349,8 +337,6 @@ def __init__(
self.samples_mc_kl = samples_mc_kl

def forward(self, tensordict: TensorDictBase) -> TensorDict:
if self.advantage_module is not None:
tensordict = self.advantage_module(tensordict)
tensordict = tensordict.clone()
advantage = tensordict.get(self.advantage_key)
log_weight, dist = self._log_weight(tensordict)
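Overall, the objective-side changes mean the PPO losses now consume precomputed "advantage" and "value_target" entries instead of owning an advantage module. A minimal end-to-end sketch of the resulting workflow (class and argument names are taken from this diff; make_actor, make_critic and collect_batch are hypothetical placeholders, and import paths may differ between torchrl versions):

    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value import GAE

    actor, critic = make_actor(), make_critic()   # hypothetical model builders
    adv = GAE(gamma=0.99, lmbda=0.95, value_network=critic, differentiable=False)
    loss_fn = ClipPPOLoss(actor, critic, gamma=0.99, loss_critic_type="smooth_l1")

    td = collect_batch()   # hypothetical: a TensorDict of rollouts with "reward", "done", "next"
    adv(td)                # adds "advantage" and "value_target" to td
    losses = loss_fn(td)
    total = losses["loss_objective"] + losses["loss_critic"] + losses.get("loss_entropy", 0.0)
    total.backward()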