[Feature] Offline objectives reduction parameter (#1984)
albertbou92 authored Mar 1, 2024
1 parent 3a41f40 commit 6abc9bf
Showing 4 changed files with 282 additions and 48 deletions.
170 changes: 170 additions & 0 deletions test/test_cost.py
@@ -5619,6 +5619,39 @@ def test_cql_batcher(
(p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())
)

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_cql_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_cql(device=device)

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)

loss_fn = CQLLoss(
actor_network=actor,
qvalue_network=qvalue,
loss_function="l2",
delay_actor=False,
delay_qvalue=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])
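These shape checks exercise the new `_reduce` helper imported from torchrl/objectives/utils.py. A minimal sketch of such a helper, assuming it follows the torch.nn reduction semantics that the docstrings below describe (the actual implementation may differ):

import torch

def _reduce(tensor: torch.Tensor, reduction: str) -> torch.Tensor:
    # "none" keeps the elementwise loss; "mean"/"sum" collapse it to a scalar.
    if reduction == "none":
        return tensor
    if reduction == "mean":
        return tensor.mean()
    if reduction == "sum":
        return tensor.sum()
    raise NotImplementedError(f"Unknown reduction method {reduction}")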


class TestDiscreteCQL(LossModuleTestBase):
seed = 0
@@ -5991,6 +6024,31 @@ def test_dcql_notensordict(
torch.testing.assert_close(loss_val_td.get(loss.out_keys[0]), loss_val[0])
torch.testing.assert_close(loss_val_td.get(loss.out_keys[1]), loss_val[1])

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dcql_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
actor = self._create_mock_actor(action_spec_type="one_hot", device=device)
td = self._create_mock_data_dcql(action_spec_type="one_hot", device=device)
loss_fn = DiscreteCQLLoss(
actor, loss_function="l2", delay_value=False, reduction=reduction
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])


class TestPPO(LossModuleTestBase):
seed = 0
@@ -8547,6 +8605,28 @@ def test_onlinedt_notensordict(self, device):
return
assert loss_entropy == loss_val_td["loss_entropy"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_onlinedt_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_odt(device=device)
actor = self._create_mock_actor(device=device)
loss_fn = OnlineDTLoss(actor, reduction=reduction)
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])


class TestDT(LossModuleTestBase):
seed = 0
@@ -8713,6 +8793,23 @@ def test_seq_dt(self, device):
for name, p in named_parameters:
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dt_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_dt(device=device)
actor = self._create_mock_actor(device=device)
loss_fn = DTLoss(actor, reduction=reduction)
loss = loss_fn(td)
if reduction == "none":
assert loss["loss"].shape == td["action"].shape
else:
assert loss["loss"].shape == torch.Size([])
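Unlike the other reduction tests, the "none" branch here checks against td["action"].shape rather than td.shape: the DT loss is an elementwise "l2" distance over the action dimensions, so nothing is aggregated before reduction. A hedged sketch with hypothetical shapes:

import torch

pred_action = torch.randn(4, 20, 6)          # (batch, seq, action_dim), hypothetical
target_action = torch.randn(4, 20, 6)
loss = (pred_action - target_action).pow(2)  # elementwise "l2" distance
assert loss.shape == target_action.shape     # matches td["action"].shape, not td.shape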


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
@@ -9487,6 +9584,42 @@ def test_iql_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_value == loss_val_td["loss_value"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_iql_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_iql(device=device)

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
value = self._create_mock_value(device=device)

loss_fn = IQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
loss_function="l2",
reduction=reduction,
)
loss_fn.make_value_estimator()
with _check_td_steady(td), pytest.warns(
UserWarning, match="No target network updater"
):
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
@@ -10269,6 +10402,43 @@ def test_discrete_iql_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_value == loss_val_td["loss_value"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_discrete_iql_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_discrete_iql(device=device)

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
value = self._create_mock_value(device=device)

loss_fn = DiscreteIQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
loss_function="l2",
action_space="one-hot",
reduction=reduction,
)
loss_fn.make_value_estimator()
with _check_td_steady(td), pytest.warns(
UserWarning, match="No target network updater"
):
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])


@pytest.mark.parametrize("create_target_params", [True, False])
@pytest.mark.parametrize(
54 changes: 39 additions & 15 deletions torchrl/objectives/cql.py
@@ -27,6 +27,7 @@
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
_vmap_func,
default_value_kwargs,
distance_loss,
@@ -82,6 +83,10 @@ class CQLLoss(LossModule):
with_lagrange (bool, optional): Whether to use the Lagrange multiplier.
Default is ``False``.
lagrange_thresh (float, optional): Lagrange threshold. Default is 0.0.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
Examples:
>>> import torch
@@ -271,8 +276,11 @@ def __init__(
num_random: int = 10,
with_lagrange: bool = False,
lagrange_thresh: float = 0.0,
reduction: str = None,
) -> None:
self._out_keys = None
if reduction is None:
reduction = "mean"
super().__init__()

# Actor
@@ -356,6 +364,7 @@ def __init__(
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self.reduction = reduction
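The stored reduction is applied later by each sub-loss method rather than in forward. A hedged construction sketch (actor and qvalue are assumed to be built as in the class docstring examples; data is a sampled batch):

loss_fn = CQLLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    reduction="none",                 # keep one loss value per sample
)
loss_td = loss_fn(data)
per_sample = loss_td["loss_qvalue"]   # shape == data.shape instead of a scalar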

@property
def target_entropy(self):
@@ -514,11 +523,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if shape:
tensordict.update(tensordict_reshape.view(shape))
out = {
"loss_actor": loss_actor.mean(),
"loss_actor_bc": loss_actor_bc.mean(),
"loss_qvalue": q_loss.mean(),
"loss_cql": cql_loss.mean(),
"loss_alpha": loss_alpha.mean(),
"loss_actor": loss_actor,
"loss_actor_bc": loss_actor_bc,
"loss_qvalue": q_loss,
"loss_cql": cql_loss,
"loss_alpha": loss_alpha,
"alpha": self._alpha,
"entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(),
}
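With the .mean() calls dropped, each entry arrives already reduced by _reduce inside the sub-loss methods; under reduction="none" the loss_* entries keep the batch shape, so callers can weight samples individually. A hypothetical consumption sketch (the "_weight" key is an assumption, mirroring what a prioritized sampler might attach):

import torch

loss_td = loss_fn(batch)                 # loss_fn constructed with reduction="none"
per_sample = loss_td["loss_qvalue"]      # shape == batch.shape
weights = batch.get("_weight", torch.ones_like(per_sample))  # hypothetical IS weights
(per_sample * weights).mean().backward()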
@@ -543,6 +552,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
bc_log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))

bc_actor_loss = self._alpha * log_prob - bc_log_prob
bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
return bc_actor_loss, metadata

@@ -574,6 +584,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
# write log_prob in tensordict for alpha loss
tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
actor_loss = self._alpha * log_prob - min_q_logprob
actor_loss = _reduce(actor_loss, reduction=self.reduction)

return actor_loss, {}

@@ -683,11 +694,11 @@ def q_loss(self, tensordict: TensorDictBase) -> Tensor:
q_pred,
target_value.expand_as(q_pred),
loss_function=self.loss_function,
- )
+ ).sum(0)
+ loss_qval = _reduce(loss_qval, reduction=self.reduction)
td_error = (q_pred - target_value).pow(2)
metadata = {"td_error": td_error.detach()}

- return loss_qval.sum(0).mean(), metadata
+ return loss_qval, metadata
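The .sum(0) now runs before reduction: distance_loss returns a (num_qvalue_nets, *batch) tensor, and summing over dim 0 aggregates the Q-network ensemble while leaving the batch shape intact for reduction="none". A runnable sketch with hypothetical shapes:

import torch

q_pred = torch.randn(2, 256)        # (num_qvalue_nets, batch), hypothetical
target_value = torch.randn(256)
loss_qval = (q_pred - target_value.expand_as(q_pred)).pow(2).sum(0)  # "l2", ensemble-summed
assert loss_qval.shape == torch.Size([256])   # batch shape preserved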

def cql_loss(self, tensordict: TensorDictBase) -> Tensor:
pred_q1 = tensordict.get(self.tensor_keys.pred_q1)
@@ -826,7 +837,10 @@ def filter_and_repeat(name, x):
tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
tensordict.set(self.tensor_keys.cql_q2_loss, cql_q2_loss)

- return (cql_q1_loss + cql_q2_loss).mean(), {}
+ cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
+ cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)
+
+ return cql_q_loss, {}

def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
cql_q1_loss = tensordict.get(self.tensor_keys.cql_q1_loss)
@@ -848,6 +862,7 @@ def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap)

alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
return alpha_prime_loss, {}

def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
@@ -858,6 +873,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
else:
# placeholder
alpha_loss = torch.zeros_like(log_pi)
alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
return alpha_loss, {}

@property
@@ -886,7 +902,10 @@ class DiscreteCQLLoss(LossModule):
gamma (float, optional): Discount factor. Default is ``None``.
action_space: The action space of the environment. If None, it is inferred from the value network.
Defaults to None.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
Examples:
>>> from torchrl.modules import MLP, QValueActor
@@ -1007,9 +1026,12 @@ def __init__(
delay_value: bool = True,
gamma: float = None,
action_space=None,
reduction: str = None,
) -> None:
- super().__init__()
self._in_keys = None
+ if reduction is None:
+     reduction = "mean"
+ super().__init__()
self.delay_value = delay_value
value_network = ensure_tensordict_compatible(
module=value_network,
@@ -1044,6 +1066,7 @@ def __init__(
)
action_space = "one-hot"
self.action_space = _find_action_space(action_space)
self.reduction = reduction

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
@@ -1165,9 +1188,8 @@ def value_loss(
pred_val,
inplace=True,
)
- loss = (
-     0.5 * distance_loss(pred_val_index, target_value, self.loss_function).mean()
- )
+ loss = 0.5 * distance_loss(pred_val_index, target_value, self.loss_function)
+ loss = _reduce(loss, reduction=self.reduction)

metadata = {
"td_error": td_error.mean(0).detach(),
@@ -1230,4 +1252,6 @@ def cql_loss(self, tensordict):
else:
q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)

- return (logsumexp - q_a).mean(), {}
+ loss_cql = (logsumexp - q_a).squeeze(-1)
+ loss_cql = _reduce(loss_cql, reduction=self.reduction)
+ return loss_cql, {}
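For reference, the regularizer computed above is, per sample, the logsumexp over actions of Q(s, a) minus Q(s, a_taken); squeezing the last dim leaves one value per sample for _reduce to handle. A runnable sketch with made-up values:

import torch

qvalues = torch.tensor([[1.0, 2.0, 0.5], [0.2, 0.1, 0.3]])            # (batch=2, n_actions=3)
current_action = torch.nn.functional.one_hot(torch.tensor([1, 2]), 3)  # taken actions
logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True)
q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)
loss_cql = (logsumexp - q_a).squeeze(-1)                                # shape (2,), per-sample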