Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] the usage of tensordict keys in loss modules #1175

Merged
merged 43 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
83dc591
[Refactor] the usage of tensordict keys in loss modules
Blonck May 22, 2023
09ced18
Add more loss modules
Blonck May 22, 2023
bc04cae
Add more loss modules
Blonck May 23, 2023
75c8ea1
Refactor remaining loss modules
Blonck May 23, 2023
5a74a16
Remove unnecessary tests
Blonck May 23, 2023
32725b4
tensordict_keys dict is not longer overwritten from child classes
Blonck May 23, 2023
ab94848
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
802fe48
Harmonize key name for "state_value"
Blonck May 23, 2023
c6186fc
Polish refactoring
Blonck May 23, 2023
b694e8c
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
9150b74
Apply suggestions from code review
Blonck May 23, 2023
bcd8a28
Use abstract staticmethod to provide default values
Blonck May 23, 2023
6f10920
Merge branch 'main' into refactor_loss_keys
Blonck May 23, 2023
67941df
Merge branch 'main' and rename tensordict_keys to loss_keys
Blonck May 24, 2023
7f3e129
Use simple set_keys on all loss modules
Blonck May 24, 2023
427c1e8
Implement tensor_keys via _AcceptedKeys dataclass
Blonck May 24, 2023
66fb949
Extended _AcceptedKeys to all loss modules
Blonck May 25, 2023
526ab36
Refactor unit test for tensordict keys
Blonck May 25, 2023
08e20da
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 25, 2023
0d476ca
WIP
Blonck May 25, 2023
9bb616a
Fix .in_keys of ValueEstimatorBase
Blonck May 25, 2023
5d00ca0
Move tensordict key logig to base class
Blonck May 25, 2023
4db47e5
Fix make_value_estimator of a2c.py
Blonck May 25, 2023
6b422f9
Remvove '_key' from keynames in ppo.py + polish
Blonck May 26, 2023
317755d
Remvove '_key' from keynames in ddpg.py + polish
Blonck May 26, 2023
fe9fba0
Fix documentation in advantages.py
Blonck May 26, 2023
34091e0
Remvove '_key' from keynames in dqn.py + polish
Blonck May 26, 2023
4baa5dc
Remvove '_key' from keynames in dreamer.py + polish
Blonck May 26, 2023
4595546
Remvove '_key' from keynames in iql.py and redq.py + polish
Blonck May 26, 2023
8ae6ad9
Remove tensor_keys from advantage ctor
Blonck May 26, 2023
a15e220
Add documentation to a2c.py
Blonck May 26, 2023
f1187f3
Change documentation of loss modules
Blonck May 26, 2023
3e09c58
Add unit test for advantages tensordict keys
Blonck May 26, 2023
e52a3f2
Merge branch 'main' into refactor_loss_key_advanced
Blonck May 26, 2023
2dc81c9
Improve wording of docstrings
Blonck May 26, 2023
655c28d
Apply suggestions from code review
Blonck May 28, 2023
226d4d3
Merge branch 'pytorch:main' into refactor_loss_keys
Blonck May 28, 2023
75d33c6
Apply code review changes
Blonck May 28, 2023
4320db6
Merge branch 'main' into refactor_loss_keys_github
Blonck May 30, 2023
cf4cd09
Change line breaking in docstrings for _AcceptedKeys
Blonck May 30, 2023
81c0413
LossModule is not longer an abstract base class.
Blonck May 31, 2023
6e753a4
Merge branch 'main' into refactor_loss_keys_github
Blonck May 31, 2023
cc784a1
Merge branch 'main' into refactor_loss_keys
vmoens May 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
LossModule is not longer an abstract base class.
  • Loading branch information
Blonck committed May 31, 2023
commit 81c041336d2f1916dff0a3e128dc9a60c6c9dfaf
35 changes: 31 additions & 4 deletions test/test_cost.py
Blonck marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import operator
import warnings
from copy import deepcopy
from dataclasses import asdict
from dataclasses import asdict, dataclass

from packaging import version as pack_version
from tensordict.nn import InteractionType
Expand Down Expand Up @@ -6080,9 +6080,6 @@ def __init__(self, compare_against, expand_dim):
expand_dim=expand_dim,
)

def _forward_value_estimator_keys(self, **kwargs) -> None:
pass

loss_module = MyLoss(compare_against=compare_against, expand_dim=expand_dim)

for key in ["module.0.bias", "module.0.weight"]:
Expand All @@ -6104,6 +6101,36 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
for key in ["module.1.bias", "module.1.weight"]:
loss_module.module_b_params.flatten_keys()[key].requires_grad

def test_tensordict_keys(self):
"""Test configurable tensordict key behavior with derived classes."""

class MyLoss(LossModule):
def __init__(self):
super().__init__()

loss_module = MyLoss()
with pytest.raises(AttributeError):
loss_module.set_keys()

class MyLoss2(MyLoss):
def _forward_value_estimator_keys(self, **kwargs) -> None:
pass

loss_module = MyLoss2()
assert loss_module.set_keys() is None
with pytest.raises(ValueError):
loss_module.set_keys(some_key="test")

class MyLoss3(MyLoss2):
@dataclass
class _AcceptedKeys:
some_key = "some_value"

loss_module = MyLoss3()
assert loss_module.tensor_keys.some_key == "some_value"
loss_module.set_keys(some_key="test")
assert loss_module.tensor_keys.some_key == "test"


class TestUtils:
@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
Expand Down
37 changes: 29 additions & 8 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple, Union
Expand Down Expand Up @@ -40,7 +39,7 @@
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."


class LossModule(nn.Module, ABC):
class LossModule(nn.Module):
"""A parent class for RL losses.

LossModule inherits from nn.Module. It is designed to read an input
Expand All @@ -59,6 +58,25 @@ class LossModule(nn.Module, ABC):

By default, the forward method is always decorated with a
gh :class:`torchrl.envs.ExplorationType.MODE`

To utilize the ability configuring the tensordict keys via
:meth:`~.set_keys()` a subclass must define an _AcceptedKeys dataclass.
This dataclass should include all keys that are intended to be configurable.
In addition, the subclass must implement the
:meth:._forward_value_estimator_keys() method. This function is crucial for
forwarding any altered tensordict keys to the underlying value_estimator.

Examples:
>>> class MyLoss(LossModule):
>>> @dataclass
>>> class _AcceptedKeys:
>>> action = "action"
>>>
>>> def _forward_value_estimator_keys(self, **kwargs) -> None:
>>> pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")
"""

@dataclass
Expand Down Expand Up @@ -91,11 +109,6 @@ def __init__(self):
self.value_type = self.default_value_estimator
# self.register_forward_pre_hook(_parameters_to_tensordict)

@abstractmethod
def _forward_value_estimator_keys(self, **kwargs) -> None:
"""Passes updated tensordict keys to the underlying value estimator."""
...

def _set_deprecated_ctor_keys(self, **kwargs) -> None:
"""Helper function to set a tensordict key from a constructor and raise a warning simultaneously."""
for key, value in kwargs.items():
Expand Down Expand Up @@ -124,7 +137,15 @@ def set_keys(self, **kwargs) -> None:
else:
setattr(self.tensor_keys, key, self.default_keys.key)

self._forward_value_estimator_keys(**kwargs)
try:
self._forward_value_estimator_keys(**kwargs)
except AttributeError:
raise AttributeError(
"To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
"must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
"Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to "
"facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*".
Expand Down