Skip to content

Commit

Permalink
[Feature] Add reduction parameter to Off-Policy losses. (pytorch#1956)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 authored Feb 26, 2024
1 parent 1eb6305 commit 6274b27
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 66 deletions.
231 changes: 229 additions & 2 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,10 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(
actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
actor,
loss_function="l2",
delay_value=delay_value,
double_dqn=double_dqn,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -699,7 +702,11 @@ def test_distributional_dqn(
td = self._create_mock_data_dqn(
action_spec_type=action_spec_type, atoms=atoms
).to(device)
loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=delay_value)
loss_fn = DistributionalDQNLoss(
actor,
gamma=gamma,
delay_value=delay_value,
)

if td_est not in (None, ValueEstimators.TD0):
with pytest.raises(NotImplementedError):
Expand All @@ -717,6 +724,7 @@ def test_distributional_dqn(
else contextlib.nullcontext()
):
loss = loss_fn(td)

assert loss_fn.tensor_keys.priority in td.keys()

sum([item for _, item in loss.items()]).backward()
Expand Down Expand Up @@ -843,6 +851,58 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est):
_ = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys()

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dqn_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
actor = self._create_mock_actor(action_spec_type="categorical", device=device)
td = self._create_mock_data_dqn(action_spec_type="categorical", device=device)
loss_fn = DQNLoss(
actor,
loss_function="l2",
delay_value=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])

@pytest.mark.parametrize("atoms", range(4, 10))
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_distributional_dqn_reduction(self, reduction, atoms):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
actor = self._create_mock_distributional_actor(
action_spec_type="categorical", atoms=atoms
).to(device)
td = self._create_mock_data_dqn(action_spec_type="categorical", device=device)
loss_fn = DistributionalDQNLoss(
actor, gamma=0.9, delay_value=False, reduction=reduction
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


class TestQMixer(LossModuleTestBase):
seed = 0
Expand Down Expand Up @@ -1884,6 +1944,35 @@ def test_ddpg_notensordict(self):
assert loss_actor == loss_val_td["loss_actor"]
assert (target_value == loss_val_td["target_value"]).all()

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_ddpg_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
actor = self._create_mock_actor(device=device)
value = self._create_mock_value(device=device)
td = self._create_mock_data_ddpg(device=device)
loss_fn = DDPGLoss(
actor,
value,
loss_function="l2",
delay_actor=False,
delay_value=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
Expand Down Expand Up @@ -2553,6 +2642,39 @@ def test_td3_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_qvalue == loss_val_td["loss_qvalue"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_td3_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
actor = self._create_mock_actor(device=device)
value = self._create_mock_value(device=device)
td = self._create_mock_data_td3(device=device)
action_spec = actor.spec
bounds = None
loss_fn = TD3Loss(
actor,
value,
action_spec=action_spec,
bounds=bounds,
loss_function="l2",
delay_qvalue=False,
delay_actor=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
Expand Down Expand Up @@ -2820,6 +2942,7 @@ def test_sac(
UserWarning, match="No target network updater"
):
loss = loss_fn(td)

assert loss_fn.tensor_keys.priority in td.keys()

# check that losses are independent
Expand Down Expand Up @@ -3420,6 +3543,41 @@ def test_state_dict(self, version):
)
loss.load_state_dict(state)

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_sac_reduction(self, reduction, version):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_sac(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
else:
value = None
loss_fn = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
loss_function="l2",
delay_qvalue=False,
delay_actor=False,
delay_value=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
Expand Down Expand Up @@ -3609,6 +3767,7 @@ def test_discrete_sac(
UserWarning, match="No target network updater"
):
loss = loss_fn(td)

assert loss_fn.tensor_keys.priority in td.keys()

# check that losses are independent
Expand Down Expand Up @@ -3968,6 +4127,36 @@ def test_discrete_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_discrete_sac_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_sac(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec["action"].space.n,
loss_function="l2",
action_space="one-hot",
delay_qvalue=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape == td.shape
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
Expand Down Expand Up @@ -4874,6 +5063,44 @@ def test_redq_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
@pytest.mark.parametrize("deprecated_loss", [True, False])
def test_redq_reduction(self, reduction, deprecated_loss):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_redq(device=device)
actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
if deprecated_loss:
loss_fn = REDQLoss_deprecated(
actor_network=actor,
qvalue_network=qvalue,
loss_function="l2",
delay_qvalue=False,
reduction=reduction,
)
else:
loss_fn = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
loss_function="l2",
delay_qvalue=False,
reduction=reduction,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss"):
assert loss[key].shape[-1] == td.shape[0]
else:
for key in loss.keys():
assert loss[key].shape == torch.Size([])


class TestCQL(LossModuleTestBase):
seed = 0
Expand Down
Loading

0 comments on commit 6274b27

Please sign in to comment.