[Refactor] the usage of tensordict keys in loss modules #1175

Merged · 43 commits · May 31, 2023

Changes from 1 commit
83dc591
[Refactor] the usage of tensordict keys in loss modules
Blonck May 22, 2023
09ced18
Add more loss modules
Blonck May 22, 2023
bc04cae
Add more loss modules
Blonck May 23, 2023
75c8ea1
Refactor remaining loss modules
Blonck May 23, 2023
5a74a16
Remove unnecessary tests
Blonck May 23, 2023
32725b4
tensordict_keys dict is no longer overwritten by child classes
Blonck May 23, 2023
ab94848
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
802fe48
Harmonize key name for "state_value"
Blonck May 23, 2023
c6186fc
Polish refactoring
Blonck May 23, 2023
b694e8c
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
9150b74
Apply suggestions from code review
Blonck May 23, 2023
bcd8a28
Use abstract staticmethod to provide default values
Blonck May 23, 2023
6f10920
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
67941df
Merge branch 'main' and rename tensordict_keys to loss_keys
Blonck May 24, 2023
7f3e129
Use simple set_keys on all loss modules
Blonck May 24, 2023
427c1e8
Implement tensor_keys via _AcceptedKeys dataclass
Blonck May 24, 2023
66fb949
Extended _AcceptedKeys to all loss modules
Blonck May 25, 2023
526ab36
Refactor unit test for tensordict keys
Blonck May 25, 2023
08e20da
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 25, 2023
0d476ca
WIP
Blonck May 25, 2023
9bb616a
Fix .in_keys of ValueEstimatorBase
Blonck May 25, 2023
5d00ca0
Move tensordict key logic to base class
Blonck May 25, 2023
4db47e5
Fix make_value_estimator of a2c.py
Blonck May 25, 2023
6b422f9
Remove '_key' from keynames in ppo.py + polish
Blonck May 26, 2023
317755d
Remove '_key' from keynames in ddpg.py + polish
Blonck May 26, 2023
fe9fba0
Fix documentation in advantages.py
Blonck May 26, 2023
34091e0
Remove '_key' from keynames in dqn.py + polish
Blonck May 26, 2023
4baa5dc
Remove '_key' from keynames in dreamer.py + polish
Blonck May 26, 2023
4595546
Remove '_key' from keynames in iql.py and redq.py + polish
Blonck May 26, 2023
8ae6ad9
Remove tensor_keys from advantage ctor
Blonck May 26, 2023
a15e220
Add documentation to a2c.py
Blonck May 26, 2023
f1187f3
Change documentation of loss modules
Blonck May 26, 2023
3e09c58
Add unit test for advantages tensordict keys
Blonck May 26, 2023
e52a3f2
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 26, 2023
2dc81c9
Improve wording of docstrings
Blonck May 26, 2023
655c28d
Apply suggestions from code review
Blonck May 28, 2023
226d4d3
Merge branch 'pytorch:main' into refactor_loss_keys
Blonck May 28, 2023
75d33c6
Apply code review changes
Blonck May 28, 2023
4320db6
Merge branch 'main' into refactor_loss_keys_github
Blonck May 30, 2023
cf4cd09
Change line breaking in docstrings for _AcceptedKeys
Blonck May 30, 2023
81c0413
LossModule is no longer an abstract base class.
Blonck May 31, 2023
6e753a4
Merge branch 'main' into refactor_loss_keys_github
Blonck May 31, 2023
cc784a1
Merge branch 'main' into refactor_loss_keys
vmoens May 31, 2023
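
Taken together, these commits converge on a single pattern: every loss module declares its configurable tensordict keys, with their defaults, in a nested `_AcceptedKeys` dataclass, and callers override them through `set_keys()`. Below is a minimal sketch of that pattern; `NestedKey` is stubbed as `str` and the real `LossModule` machinery (functional parameters, value estimators) is omitted:

import dataclasses
from dataclasses import dataclass

# Stand-in for tensordict's NestedKey (a string or tuple of strings).
NestedKey = str


class MyLoss:
    @dataclass
    class _AcceptedKeys:
        """Default values for the configurable tensordict keys of this loss."""

        state_action_value: NestedKey = "state_action_value"
        priority: NestedKey = "td_error"

    default_keys = _AcceptedKeys()

    @property
    def tensor_keys(self):
        # Fall back to the class-level defaults until set_keys() is called.
        return getattr(self, "_tensor_keys", self.default_keys)

    def set_keys(self, **kwargs: NestedKey) -> None:
        # Reject unknown keys, then overlay the overrides on the current keys.
        accepted = {f.name for f in dataclasses.fields(self._AcceptedKeys)}
        for key in kwargs:
            if key not in accepted:
                raise ValueError(f"{key!r} is not an accepted tensordict key")
        self._tensor_keys = dataclasses.replace(self.tensor_keys, **kwargs)

With this in place, `MyLoss().set_keys(priority="my_td_error")` flips a single key while the others keep their defaults, and unknown names fail loudly — the behavior the `tensordict_keys_unknown_key_test` helper in the diff below checks.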
Remove '_key' from keynames in ddpg.py + polish
Blonck committed May 26, 2023
commit 317755d4321090aa358a2fc4d69bf5784e872496
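
For users, the rename in this commit plays out as follows (a sketch; `actor` and `value` stand for the networks built in the tests, and the pre-commit spelling is reconstructed from the diff that follows):

loss_fn = DDPGLoss(actor, value, loss_function="l2")

# Before this commit, accepted keys carried a redundant "_key" suffix:
# loss_fn.set_keys(state_action_value_key="q_value", priority_key="my_td_error")

# After this commit, the attribute name is the key name itself:
loss_fn.set_keys(state_action_value="q_value", priority="my_td_error")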
126 changes: 82 additions & 44 deletions test/test_cost.py
@@ -141,14 +141,14 @@ def get_devices():

class LossModuleTestBase:
def tensordict_keys_test(
self, loss_fn, default_keys, loss_advantage_key_mapping=None
self, loss_fn, default_keys, td_est=None, loss_advantage_key_mapping=None
):
self.tensordict_keys_unknown_key_test(loss_fn)
self.tensordict_keys_default_values_test(loss_fn, default_keys)
self.tensordict_set_keys_test(loss_fn, default_keys)
if loss_advantage_key_mapping is not None:
self.set_advantage_keys_through_loss_test(
loss_fn, loss_advantage_key_mapping
loss_fn, td_est, loss_advantage_key_mapping
)

def tensordict_keys_unknown_key_test(self, loss_fn):
@@ -179,27 +179,19 @@ def tensordict_set_keys_test(self, loss_fn, default_keys):
for key, _ in default_keys.items():
assert getattr(test_fn.tensor_keys, key) == new_key

def set_advantage_keys_through_loss_test(self, loss_fn, loss_advantage_key_mapping):
def set_advantage_keys_through_loss_test(
self, loss_fn, td_est, loss_advantage_key_mapping
):
key_mapping = loss_advantage_key_mapping
test_fn = deepcopy(loss_fn)
test_fn.make_value_estimator(td_est)

key_mapping = loss_advantage_key_mapping
for loss_key, advantage_key in key_mapping.items():
test_fn.set_keys(**{loss_key: "test1"})
assert (
getattr(test_fn.value_estimator.tensor_keys, advantage_key) == "test1"
)

# TODO test
# loss = Loss(...)
# loss.set_keys()
# loss.make_value_estimator()
#
# vs
#
# loss = Loss(...)
# loss.make_value_estimator()
# loss.set_keys()


class TestDQN(LossModuleTestBase):
seed = 0
@@ -577,7 +569,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
)
return actor.to(device)

def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_value(
self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
):
# Actor
class ValueClass(nn.Module):
def __init__(self):
@@ -589,8 +583,7 @@ def forward(self, obs, act):

module = ValueClass()
value = ValueOperator(
module=module,
in_keys=["observation", "action"],
module=module, in_keys=["observation", "action"], out_keys=out_keys
)
return value.to(device)

@@ -811,17 +804,46 @@ def test_ddpg_tensordict_keys(self, td_est):
loss_function="l2",
)

loss_fn.make_value_estimator(td_est)

default_keys = {
"state_action_value_key": "state_action_value",
"priority_key": "td_error",
"state_action_value": "state_action_value",
"priority": "td_error",
}
key_mapping = {"state_action_value_key": "value_key"}
key_mapping = {"state_action_value": "value_key"}

self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)

@pytest.mark.parametrize(
"td_est",
[ValueEstimators.TD0, ValueEstimators.TD1, ValueEstimators.TDLambda, None],
)
def test_ddpg_tensordict_run(self, td_est):
"""Test DDPG loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
tensor_keys = {
"state_action_value": "state_action_value_test",
"priority": "td_error_test",
}

actor = self._create_mock_actor()
value = self._create_mock_value(out_keys=[tensor_keys["state_action_value"]])
td = self._create_mock_data_ddpg()
loss_fn = DDPGLoss(
actor,
value,
loss_function="l2",
)
loss_fn.set_keys(**tensor_keys)

if td_est is not None:
loss_fn.make_value_estimator(td_est)

with _check_td_steady(td):
_ = loss_fn(td)


class TestTD3(LossModuleTestBase):
@@ -1108,7 +1130,6 @@ def test_td3_tensordict_keys(self, td_est):
actor,
value,
)
loss_fn.make_value_estimator(td_est)

default_keys = {
"priority_key": "td_error",
@@ -1118,7 +1139,10 @@ def test_td3_tensordict_keys(self, td_est):
key_mapping = {"state_action_value_key": "value_key"}

self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


@@ -1557,8 +1581,6 @@ def test_sac_tensordict_keys(self, td_est, version):
loss_function="l2",
)

loss_fn.make_value_estimator(td_est)

default_keys = {
"priority_key": "td_error",
"value_key": "state_value",
@@ -1570,7 +1592,10 @@ def test_sac_tensordict_keys(self, td_est, version):
key_mapping = {"value_key": "value_key"}

self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


@@ -1922,16 +1947,17 @@ def test_discrete_sac_tensordict_keys(self, td_est):
loss_function="l2",
)

loss_fn.make_value_estimator(td_est)

default_keys = {
"priority_key": "td_error",
"value_key": "state_value",
"action_key": "action",
}
key_mapping = {"value_key": "value_key"}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


@@ -2420,7 +2446,6 @@ def test_redq_tensordict_keys(self, td_est):
qvalue_network=qvalue,
loss_function="l2",
)
loss_fn.make_value_estimator(td_est)

default_keys = {
"priority_key": "td_error",
@@ -2431,7 +2456,10 @@ def test_redq_tensordict_keys(self, td_est):
}
key_mapping = {"value_key": "value_key"}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


@@ -2890,7 +2918,6 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
value = self._create_mock_value()

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn.make_value_estimator(td_est)

default_keys = {
"advantage": "advantage",
@@ -2905,7 +2932,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
"value": "value_key",
}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -3233,7 +3263,6 @@ def test_a2c_tensordict_keys(self, td_est):
value = self._create_mock_value()

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn.make_value_estimator(td_est)

default_keys = {
"advantage_key": "advantage",
@@ -3247,7 +3276,10 @@ def test_a2c_tensordict_keys(self, td_est):
"value_key": "value_key",
}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)

@pytest.mark.parametrize("device", get_available_devices())
@@ -3435,7 +3467,6 @@ def test_reinforce_tensordict_keys(self, td_est):
actor_net,
critic=value_net,
)
loss_fn.make_value_estimator(td_est)

default_keys = {
"advantage_key": "advantage",
@@ -3449,7 +3480,10 @@ def test_reinforce_tensordict_keys(self, td_est):
"value_key": "value_key",
}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


@@ -3877,7 +3911,6 @@ def test_dreamer_actor_tensordict_keys(self, td_est, device):
value_model,
mb_env,
)
loss_fn.make_value_estimator(td_est)

default_keys = {
"belief_key": "belief",
@@ -3887,7 +3920,10 @@ def test_dreamer_actor_tensordict_keys(self, td_est, device):
}
key_mapping = {"value_key": "value_key"}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)

def test_dreamer_value_tensordict_keys(self, device):
@@ -4234,7 +4270,6 @@ def test_iql_tensordict_keys(self, td_est):
value_network=value,
loss_function="l2",
)
loss_fn.make_value_estimator(td_est)

default_keys = {
"priority_key": "td_error",
@@ -4245,7 +4280,10 @@ def test_iql_tensordict_keys(self, td_est):
}
key_mapping = {"value_key": "value_key"}
self.tensordict_keys_test(
loss_fn, default_keys=default_keys, loss_advantage_key_mapping=key_mapping
loss_fn,
default_keys=default_keys,
td_est=td_est,
loss_advantage_key_mapping=key_mapping,
)


29 changes: 21 additions & 8 deletions torchrl/objectives/ddpg.py
@@ -42,8 +42,22 @@ class DDPGLoss(LossModule):

@dataclass
class _AcceptedKeys:
state_action_value_key: NestedKey = "state_action_value"
priority_key: NestedKey = "td_error"
"""Stores default values for all configurable tensordict keys.

This class defines which tensordict keys are configurable via
`.set_keys(key_name=key_value)` and stores their default values.

Attributes:
------------
state_action_value : NestedKey
The input tensordict key where the state action value is expected.
Will be used for the underlying value estimator as value key. Defaults to ``"state_action_value"``.
priority : NestedKey
The input tensordict key where the target priority is written to. Defaults to ``"td_error"``.
"""

state_action_value: NestedKey = "state_action_value"
priority: NestedKey = "td_error"

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.TD0
@@ -59,7 +73,6 @@ def __init__(
gamma: float = None,
) -> None:
super().__init__()

self.delay_actor = delay_actor
self.delay_value = delay_value

@@ -94,7 +107,7 @@ def __init__(
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value_key=self._tensor_keys.state_action_value_key,
value_key=self._tensor_keys.state_action_value,
)

def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
Expand All @@ -119,7 +132,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
if input_tensordict.device is not None:
td_error = td_error.to(input_tensordict.device)
input_tensordict.set(
self.tensor_keys.priority_key,
self.tensor_keys.priority,
td_error,
inplace=True,
)
@@ -150,7 +163,7 @@ def _loss_actor(
td_copy,
params=params,
)
return -td_copy.get(self.tensor_keys.state_action_value_key)
return -td_copy.get(self.tensor_keys.state_action_value)

def _loss_value(
self,
@@ -162,7 +175,7 @@ def _loss_value(
td_copy,
params=self.value_network_params,
)
pred_val = td_copy.get(self.tensor_keys.state_action_value_key).squeeze(-1)
pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

target_params = TensorDict(
{
@@ -193,7 +206,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
hp.update(hyperparams)
tensor_keys = {"value_key": self.tensor_keys.state_action_value_key}
tensor_keys = {"value_key": self.tensor_keys.state_action_value}
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
value_network=self.actor_critic, tensor_keys=tensor_keys, **hp
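
To close, a self-contained sketch of how the refactored DDPG key plumbing fits together. It mirrors `test_ddpg_tensordict_run` above; the toy network shapes and the "q_value" key are illustrative choices, not part of the diff:

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.modules import ValueOperator
from torchrl.objectives import DDPGLoss
from torchrl.objectives.utils import ValueEstimators


class QValueNet(nn.Module):
    """Toy critic: reads observation and action, writes a scalar value."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3 + 4, 1)  # obs_dim + action_dim

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], dim=-1))


actor = TensorDictModule(
    nn.Linear(3, 4), in_keys=["observation"], out_keys=["action"]
)
value = ValueOperator(
    QValueNet(), in_keys=["observation", "action"], out_keys=["q_value"]
)

loss_fn = DDPGLoss(actor, value, loss_function="l2")
# Route the loss (and, through _forward_value_estimator_keys, its value
# estimator) to the critic's non-default output key.
loss_fn.set_keys(state_action_value="q_value", priority="my_td_error")
loss_fn.make_value_estimator(ValueEstimators.TD0)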