Skip to content

Commit

Permalink
[BugFix] Dedicated tests for on policy losses reduction parameter (#1974)
Browse files Browse the repository at this point in the history

Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
3 people authored Feb 27, 2024
1 parent db4ad23 commit 3d65083
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 41 deletions.
124 changes: 91 additions & 33 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6201,7 +6201,6 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_ppo(
self,
loss_class,
Expand All @@ -6210,7 +6209,6 @@ def test_ppo(
advantage,
td_est,
functional,
reduction,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
Expand Down Expand Up @@ -6246,7 +6244,6 @@ def test_ppo(
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)
if advantage is not None:
advantage(td)
Expand All @@ -6259,15 +6256,6 @@ def test_ppo(
kl = loss.pop("kl")
assert (kl != 0).any()

if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
Expand Down Expand Up @@ -6804,6 +6792,41 @@ def test_ppo_notensordict(
assert loss_obj == loss_val_td.get("loss_objective")
assert loss_crit == loss_val_td.get("loss_critic")

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_ppo_reduction(self, reduction, loss_class):
    """Check the shape of every ``loss_*`` output for each ``reduction`` value.

    With ``reduction="none"`` each ``loss_``-prefixed entry must have the
    shape of the input tensordict; for ``None``, ``"mean"`` and ``"sum"``
    each such entry must be a 0-dim tensor. Entries whose key does not
    start with ``loss_`` are not checked.
    """
    torch.manual_seed(self.seed)
    # Use CUDA whenever at least one CUDA device is visible, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.device_count() else "cpu")
    data = self._create_seq_mock_data_ppo(device=device)
    actor_net = self._create_mock_actor(device=device)
    value_net = self._create_mock_value(device=device)
    gae = GAE(gamma=0.9, lmbda=0.9, value_network=value_net)
    module = loss_class(
        actor_net,
        value_net,
        loss_critic_type="l2",
        reduction=reduction,
    )
    # The estimator is applied to the data before the loss is evaluated;
    # its return value is not used.
    gae(data)
    out = module(data)
    # Unreduced losses keep the data shape, reduced ones are scalars.
    expected_shape = data.shape if reduction == "none" else torch.Size([])
    for name in out.keys():
        if name.startswith("loss_"):
            assert out[name].shape == expected_shape


class TestA2C(LossModuleTestBase):
seed = 0
Expand Down Expand Up @@ -6969,8 +6992,7 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)

Expand Down Expand Up @@ -7005,7 +7027,6 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)

# Check error is raised when actions require grads
Expand All @@ -7023,14 +7044,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
Expand Down Expand Up @@ -7413,6 +7427,40 @@ def test_a2c_notensordict(
assert loss_objective == loss_val_td["loss_objective"]
assert loss_critic == loss_val_td["loss_critic"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_a2c_reduction(self, reduction):
    """Check the shape of every ``loss_*`` output of ``A2CLoss`` per ``reduction``.

    With ``reduction="none"`` each ``loss_``-prefixed entry must have the
    shape of the input tensordict; for ``None``, ``"mean"`` and ``"sum"``
    each such entry must be a 0-dim tensor. Entries whose key does not
    start with ``loss_`` are not checked.
    """
    torch.manual_seed(self.seed)
    # Use CUDA whenever at least one CUDA device is visible, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.device_count() else "cpu")
    data = self._create_seq_mock_data_a2c(device=device)
    actor_net = self._create_mock_actor(device=device)
    value_net = self._create_mock_value(device=device)
    gae = GAE(gamma=0.9, lmbda=0.9, value_network=value_net)
    module = A2CLoss(
        actor_net,
        value_net,
        loss_critic_type="l2",
        reduction=reduction,
    )
    # The estimator is applied to the data before the loss is evaluated;
    # its return value is not used.
    gae(data)
    out = module(data)
    # Unreduced losses keep the data shape, reduced ones are scalars.
    expected_shape = data.shape if reduction == "none" else torch.Size([])
    for name in out.keys():
        if name.startswith("loss_"):
            assert out[name].shape == expected_shape


class TestReinforce(LossModuleTestBase):
seed = 0
Expand Down Expand Up @@ -7659,26 +7707,16 @@ def _create_mock_common_layer_setup(
return actor, critic, common, td

@pytest.mark.parametrize("separate_losses", [False, True])
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
def test_reinforce_tensordict_separate_losses(self, separate_losses):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = ReinforceLoss(
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

loss = loss_fn(td)
if reduction == "none":

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

assert all(
(p.grad is None) or (p.grad == 0).all()
Expand Down Expand Up @@ -7807,6 +7845,26 @@ def test_reinforce_notensordict(
return
assert loss_actor == loss_val_td["loss_actor"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_reinforce_reduction(self, reduction):
    """Check the shape of every ``loss_*`` output of ``ReinforceLoss`` per ``reduction``.

    With ``reduction="none"`` each ``loss_``-prefixed entry must have the
    shape of the input tensordict; for ``None``, ``"mean"`` and ``"sum"``
    each such entry must be a 0-dim tensor. Entries whose key does not
    start with ``loss_`` are not checked.
    """
    torch.manual_seed(self.seed)
    # The shared ("common") module returned by the setup helper is not needed here.
    actor_net, critic_net, _common, data = self._create_mock_common_layer_setup()
    module = ReinforceLoss(
        actor_network=actor_net,
        critic_network=critic_net,
        reduction=reduction,
    )
    out = module(data)
    # Unreduced losses keep the data shape, reduced ones are scalars.
    expected_shape = data.shape if reduction == "none" else torch.Size([])
    for name in out.keys():
        if name.startswith("loss_"):
            assert out[name].shape == expected_shape


@pytest.mark.parametrize("device", get_default_devices())
class TestDreamer(LossModuleTestBase):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
else value,
batch_size=[],
Expand Down
13 changes: 7 additions & 6 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import contextlib
import functools

import math
import warnings
Expand Down Expand Up @@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
else value,
batch_size=[],
)

return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
Expand Down Expand Up @@ -807,7 +808,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

td_out.set("ESS", _reduce(ess, self.reduction) / batch)
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
else value,
batch_size=[],
Expand Down Expand Up @@ -1070,7 +1071,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
loss_critic = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic)
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
else value,
batch_size=[],
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

td_out.set("loss_value", self.loss_critic(tensordict))
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
else value,
batch_size=[],
Expand Down

0 comments on commit 3d65083

Please sign in to comment.