[Refactor] the usage of tensordict keys in loss modules #1175
Conversation
Thanks a mil for this!
Is there a way for tensordict_keys to be defined at the class level (and not instance), such that one can do PPOLoss.tensordict_keys without instantiating an object?
Also, what about making tensordict_keys a property using @abc.abstractmethod in the LossModule parent class, such that we force all new losses to have that attribute?
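A plain class attribute would be one minimal way to get the class-level access asked about here; a rough sketch only (the key names are illustrative, not the actual torchrl API):

```python
from torchrl.objectives import LossModule

class PPOLoss(LossModule):
    # class-level defaults, readable without instantiating the loss:
    #   PPOLoss.tensordict_keys -> {"action_key": "action", ...}
    tensordict_keys = {
        "action_key": "action",
        "advantage_key": "advantage",
    }
```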
torchrl/objectives/common.py
if key not in self.tensordict_keys.keys():
    raise ValueError(f"{key} not a valid tensordict key")
set_value = value if value is not None else self.tensordict_keys[key]
setattr(self, key, set_value)
I'd rather have these keys contained in a separate container, like a dictionary or similar.
I'm afraid that as the number of keys increases, we'll end up with a class with many attributes and no easy access to the list of them all.
No, the container is already there. However, using the keys would always involve one extra step:
tensordict.get(self.action_key) -> tensordict.get(self.tensordict_keys["action_key"])
From a data-organisation point of view it would be the better solution, but the code will be a little noisier.
I agree it's a bit clunky.
@matteobettini do you have an opinion on the matter?
I prefer tensordict.get(self.tensordict_keys["action_key"]).
Also, would loss_keys make sense instead of tensordict_keys?
What about an Enum instead of a dict (to make sure only a finite set of keys is present)?
tensordict.get(self.loss_keys.action_key)
What about an Enum instead of a dict (to make sure only a finite set of keys is present)?
This would mean constructing the enum in the base class from data provided by the child class. I guess this is possible, since almost everything is possible in Python. However, the syntax would not look like
tensordict.get(self.loss_keys.action_key)
since the key must be converted to str (or a pair of strings). StrEnum would be a solution, but it was only introduced in Python 3.11.
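One possible workaround on older Python versions is a str-mixin Enum; a minimal sketch (key names are illustrative):

```python
from enum import Enum

class LossKeys(str, Enum):
    # values are the actual tensordict keys
    action = "action"
    value = "state_value"

# thanks to the str mixin, members compare equal to plain strings,
# so tensordict.get(LossKeys.action) behaves like tensordict.get("action")
assert LossKeys.action == "action"
```

This still would not cover nested keys (pairs of strings), which is part of the concern above.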
Thanks a mil for this!
Is there a way for tensordict_keys to be defined at the class level (and not instance), such that one can do PPOLoss.tensordict_keys without instantiating an object?
Also, what about making tensordict_keys a property using @abc.abstractmethod in the LossModule parent class, such that we force all new losses to have that attribute?
That would mean the behavior of one instance of a loss module could change if the class-level keys were modified by another instance.
That is a good idea, I will try to implement this. It should lead to an easier interface.
Not necessarily.
Got it. I've implemented this idea and it looks good. Although, I had to use an abstract staticmethod. Afaik, there is no static property decorator.
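For reference, a minimal sketch of the abstract-staticmethod pattern being referred to (class and method names are assumptions, not the final API):

```python
from abc import ABC, abstractmethod

class LossModuleBase(ABC):
    @staticmethod
    @abstractmethod
    def default_tensordict_keys() -> dict:
        """Each loss declares its default tensordict keys at the class level."""
        ...

class MyLoss(LossModuleBase):
    @staticmethod
    def default_tensordict_keys() -> dict:
        return {"action_key": "action", "value_key": "state_value"}

# accessible without instantiating the loss:
print(MyLoss.default_tensordict_keys())
```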
Also, one thing to mention is that the new keys have to be transparently reflected in the value_estimators
and
Should both work with the keys reflected in the value estimators.
Ideally, keys of the value estimator should only be set through the value estimator, no?
But the value estimator is created by the loss. Do we want to make users call set_keys twice? Already calling it once is added complexity for MARL users.
That is a good point. This problem is introduced by allowing the tensordict keys to be configured after constructing the losses.
The solution with the ctor would roughly look like:

class MyLoss(LossModule):
    def __init__(self, ...., tensordict_keys={}):
        super().__init__(tensordict_keys=tensordict_keys)
        ...


class LossModule(nn.Module, ABC):
    def __init__(self, tensordict_keys={}):
        # merge default tensordict keys with provided keys
        default_keys = self.default_tensordict_keys()
        for key, value in tensordict_keys.items():
            if key not in default_keys:
                raise ValueError(...)
            default_keys[key] = value
        # create attributes
        ...
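With that constructor-based variant, configuring a non-default key would presumably look like this (the arguments are illustrative only):

```python
# e.g. a multi-agent setup remapping the action key to a nested key
loss = MyLoss(actor_network, tensordict_keys={"action_key": ("agents", "action")})
```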
Forwarding the new keys via make_value_estimator works for cases like this:

# from ppo.py
def make_value_estimator(...):
    ...
    value_key = self.value_key
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(
            value_network=self.critic, value_key=value_key, **hp
        )

However, there is also this case:

# from ddpg.py
def make_value_estimator(...):
    ...
    value_key = "state_action_value"  # <- would correspond to self.state_action_value_key
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(
            value_network=self.actor_critic, value_key=value_key, **hp
        )
Btw, this code will raise an exception in case the value key name has a non-default value, because the value estimator checks that the value key is in the value_network.out_keys.
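As an illustration of the kind of check being described (not the actual torchrl code; a rough sketch only):

```python
# estimator-side validation that trips on a renamed key
if value_key not in value_network.out_keys:
    raise KeyError(
        f"value_key={value_key!r} is not among the value network's out_keys "
        f"({value_network.out_keys})"
    )
```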
Great work! Must have been quite a headache to come about!
See my comments in the code
Great work!
I'm happy to merge this but I have one more question: we have a tutorial about losses, and this feature seems pretty advanced for a regular user who would like to code up a new loss with hard-coded keys.
Is it mandatory to code _AcceptedKeys?
It's a great tool but maybe we can reserve it for internal usage.
Have you checked that this tutorial is running under this PR?
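For context on the question above, coding _AcceptedKeys in a custom loss amounts to roughly the following sketch (key names, defaults, and the estimator call are assumptions based on the pattern discussed in this PR, not the exact API):

```python
from dataclasses import dataclass

from torchrl.objectives import LossModule

class MyCustomLoss(LossModule):
    @dataclass
    class _AcceptedKeys:
        """Default tensordict keys read/written by this loss; overridable via set_keys()."""
        action: str = "action"
        value: str = "state_value"
        priority: str = "td_error"

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # propagate renamed keys to the underlying value estimator, if any
        if getattr(self, "_value_estimator", None) is not None:
            self._value_estimator.set_keys(value=self._tensor_keys.value)

    # network definitions and forward() omitted in this sketch
```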
def __new__(cls, *args, **kwargs):
    cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
    cls._tensor_keys = cls._AcceptedKeys()
maybe make this optional (only if _AcceptedKeys is present)?
The current implementation does not prevent users from crafting a loss module that lacks configurable keys, as _AcceptedKeys is defined as an empty set in such cases. However, the abstract method prevents users from doing so:
@abstractmethod
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Passes updated tensordict keys to the underlying value estimator."""
    ...
In this case, the set_keys method will not function if supplied with any arguments, a behavior that aligns with my expectations.
We can remove the @abstractmethod decorator and introduce an error condition if the .set_keys method is invoked while _forward_value_estimator_keys() remains undefined by the loss module. This adjustment would ensure an exception is triggered when .set_keys() is called from the custom loss module.
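A minimal sketch of that alternative, assuming set_keys lives on the base class (names follow the snippets in this thread; the guard itself is hypothetical):

```python
def set_keys(self, **kwargs) -> None:
    # fail loudly if the child loss never defined the key-forwarding hook
    if not hasattr(self, "_forward_value_estimator_keys"):
        raise RuntimeError(
            "set_keys() was called on a loss module that does not define "
            "_forward_value_estimator_keys(); define it to support key remapping."
        )
    for key, value in kwargs.items():
        setattr(self._tensor_keys, key, value)
    self._forward_value_estimator_keys(**kwargs)
```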
Got it
Up to you for the exception. In a way, if someone writes a loss module then calls set_keys without having written a set of keys they're probably way off the road...
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
"""
for key, value in kwargs.items():
If we make _AcceptedKeys optional, we can raise an exception if it is not present?
Description
We have various loss modules in RL. They work as ...
These loss modules access the actual data by keys. Some keys are configurable via the ctor (rl/torchrl/objectives/ppo.py, line 125 in 3c8197b), others are hardcoded (rl/torchrl/objectives/ddpg.py, line 139 in 714d645).
This is refactored such that all relevant keys can be set via .set_keys().
The same is done for the advantage modules.
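Example of the resulting usage, mirroring the set_keys docstring above (the nested-key line is an illustrative assumption for multi-agent setups):

```python
from torchrl.objectives import DQNLoss

# `actor` is a Q-value actor built elsewhere
dqn_loss = DQNLoss(actor, action_space="one-hot")
# remap the tensordict keys the loss reads/writes
dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
# nested keys for MARL could presumably be passed as tuples, e.g.:
# dqn_loss.set_keys(action_value_key=("agents", "action_value"))
```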
Closes #1174
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!