
[Refactor] the usage of tensordict keys in loss modules #1175

Merged · 43 commits · May 31, 2023

Conversation

@Blonck (Contributor) commented May 22, 2023

Description

We have various loss modules in RL.
They work as

loss_module = LossModule(network, …)
loss_module(data)

These loss modules access the actual data by keys. Some keys are configurable via ctor,

advantage_key: str = "advantage",

others are hardcoded,

return -td_copy.get("state_action_value")

This is refactored such that all relevant keys can be set via

loss_module.set_keys(sample_log_prob="some_other_key")

The same is done for the advantage modules.

advantage_module.set_keys(advantage="other_advantage")
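
To make the intended workflow concrete, here is a hedged, self-contained sketch in plain Python (ExampleLoss and the key names are made up for illustration; they are not the actual torchrl classes):

class ExampleLoss:
    def __init__(self):
        # Default key names; overridable via set_keys().
        self._keys = {"advantage": "advantage"}

    def set_keys(self, **kwargs):
        for name, key in kwargs.items():
            if name not in self._keys:
                raise ValueError(f"{name} is not a configurable key")
            self._keys[name] = key

    def __call__(self, data):
        # Read the advantage under whatever key is currently configured.
        return -data[self._keys["advantage"]]

loss_module = ExampleLoss()
loss_module.set_keys(advantage="custom_advantage")
print(loss_module({"custom_advantage": 1.0}))  # -1.0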

Closes #1174

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label May 22, 2023
@Blonck force-pushed the refactor_loss_keys branch from e1b2350 to c6186fc May 23, 2023 13:26
@Blonck added the Refactoring label May 23, 2023
@Blonck self-assigned this May 23, 2023
@Blonck requested a review from vmoens May 23, 2023 13:52
@vmoens (Contributor) left a comment

Thanks a mil for this!

Is there a way for tensordict_keys to be defined at the class level (and not instance)? Such that one can do PPOLoss.tensordict_keys without instantiating an object?

Also, what about making tensordict_keys a property using @abc.abstractmethod in LossModule parent class, such that we force all new losses to have that attribute?

Inline review comments (outdated, resolved) on: torchrl/objectives/ddpg.py (×4), test/test_cost.py
if key not in self.tensordict_keys.keys():
    raise ValueError(f"{key} not a valid tensordict key")
set_value = value if value is not None else self.tensordict_keys[key]
setattr(self, key, set_value)
Contributor

I'd rather have these keys contained in a separate container, like a dictionary or similar.
I'm afraid that as the number of keys increases, we'll end up with a class with many attributes and no easy access to the list of them all.

Contributor Author

No, the container is already there. However, using the keys would always involve one extra step.

tensordict.get(self.action_key) -> tensordict.get(self.tensordict_keys["action_key"])

From a data-organization standpoint it would be the better solution, but the code will be a little noisier.
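
For concreteness, a minimal sketch of the container approach, using plain dicts and made-up names rather than the actual torchrl code:

class ExampleDDPGLoss:
    def __init__(self):
        # All configurable keys live in one container:
        # easy to enumerate, a single place to look.
        self.tensordict_keys = {
            "action_key": "action",
            "state_action_value_key": "state_action_value",
        }

    def forward(self, tensordict):
        # The extra indirection discussed above: dict lookup, then data lookup.
        return -tensordict[self.tensordict_keys["state_action_value_key"]]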

Contributor

I agree it's a bit clunky.
@matteobettini do you have an opinion on the matter?

@matteobettini (Contributor) May 23, 2023

I prefer tensordict.get(self.tensordict_keys["action_key"]).

Also, would loss_keys make sense instead of tensordict_keys?

Contributor

What about an Enum instead of a dict (to make sure only a finite set of keys is present)?

tensordict.get(self.loss_keys.action_key)

@Blonck (Contributor Author) May 24, 2023

What about an Enum instead of a dict (to make sure only a finite set of keys is present)?

This would mean constructing the enum in the base class from data provided by the child class. I guess this is possible, since almost everything is possible in Python. However, the syntax would not look like

tensordict.get(self.loss_keys.action_key)

since the key must be converted to a str (or a pair of strings). StrEnum would be a solution, but it was only introduced in Python 3.11.
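
For what it's worth, a common pre-3.11 workaround is to mix str into the Enum so that members are themselves strings; a sketch with made-up key names:

from enum import Enum

class LossKeys(str, Enum):
    action = "action"
    advantage = "advantage"

# Members are str instances and compare equal to plain strings,
# so they can be passed wherever a key string is expected.
assert LossKeys.action == "action"
assert isinstance(LossKeys.advantage, str)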

@Blonck (Contributor Author) commented May 23, 2023

Is there a way for tensordict_keys to be defined at the class level (and not instance)? Such that one can do PPOLoss.tensordict_keys without instantiating an object?

That would mean that the behavior of one instance of a loss module changes if .set_keys() is called on another instance. I think that could surprise the user of torchrl. (Although usually only one loss module is used.)

Also, what about making tensordict_keys a property using @abc.abstractmethod in LossModule parent class, such that we force all new losses to have that attribute?

That is a good idea, I will try to implement this. It should lead to an easier interface.

@vmoens (Contributor) commented May 23, 2023

That would mean that the behavior of one instance of a loss module changes if .set_keys() is called on another instance. I think that could surprise the user of torchrl. (Although usually only one loss module is used.)

Not necessarily.
Here you clearly separate the defaults from the non-defaults. The defaults should be defined at the class level if they're kept separate.

@Blonck (Contributor Author) commented May 23, 2023

Not necessarily.
Here you clearly separate the defaults from the non-defaults. The defaults should be defined at the class level if they're kept separate.

Got it. I've implemented this idea and it looks good, although I had to use an abstract staticmethod; afaik, there is no static property decorator.
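
A minimal sketch of the stacked-decorator pattern described above (class and method names are illustrative, not the final torchrl API):

from abc import ABC, abstractmethod

class ExampleLossModule(ABC):
    @staticmethod
    @abstractmethod
    def default_tensordict_keys() -> dict:
        """Each loss must return its default key mapping."""
        ...

class ExamplePPOLoss(ExampleLossModule):
    @staticmethod
    def default_tensordict_keys() -> dict:
        return {"advantage": "advantage", "value": "state_value"}

# Accessible at the class level, without instantiating:
print(ExamplePPOLoss.default_tensordict_keys())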

@matteobettini (Contributor)

Also, one thing to mention is that the new keys have to be transparently reflected in the value estimators:

loss = loss()
loss.set_keys()
loss.make_value_estim()

and

loss = loss()
loss.make_value_estim()
loss.set_keys()

Both should work, with the keys reflected in the value estimators.

@vmoens (Contributor) commented May 23, 2023

Also, one thing to mention is that the new keys have to be transparently reflected in the value estimators:

loss = loss()
loss.set_keys()
loss.make_value_estim()

and

loss = loss()
loss.make_value_estim()
loss.set_keys()

Both should work, with the keys reflected in the value estimators.

Ideally, the keys of the value estimator should only be set through the value estimator, no?

@matteobettini (Contributor) commented May 23, 2023

But the value estimator is created by the loss.

Do we want to make users call set_keys twice?

Already calling it once is added complexity for MARL users.

@Blonck (Contributor Author) commented May 23, 2023

But the value estimator is created by the loss.

Do we want to make users call set_keys twice?

Already calling it once is added complexity for MARL users.

That is a good point. This problem is introduced by allowing the tensordict keys to be configured after the losses are constructed.
I could solve this by extending .set_keys(...) so that it also sets the keys of the value estimator.
Another solution would be to remove .set_keys(...) and add a generic constructor argument to all loss modules, e.g., tensordict_keys.
If there is no need to configure the keys at runtime, I would prefer the latter, although the code becomes a bit clunky.

@Blonck (Contributor Author) commented May 23, 2023

The solution with the ctor would roughly look like:

class MyLoss(LossModule):
    def __init__(self, ..., tensordict_keys=None):
        super().__init__(tensordict_keys=tensordict_keys)
        ...

class LossModule(nn.Module, ABC):
    def __init__(self, tensordict_keys=None):
        # Merge the default tensordict keys with the provided keys
        # (None instead of a mutable {} default).
        default_keys = self.default_tensordict_keys()
        for key, value in (tensordict_keys or {}).items():
            if key not in default_keys:
                raise ValueError(...)
            default_keys[key] = value
        # Create attributes.
        ...
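
Usage under that variant would presumably look like this (the key name is hypothetical):

loss = MyLoss(network, tensordict_keys={"advantage": "custom_advantage"})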

@Blonck (Contributor Author) commented May 24, 2023

Forwarding the new keys via .set_keys() to the value estimator would also require maintaining a mapping from key_name_loss to key_name_value_estimator. In the default case both key names are identical, see ppo.py:

# from ppo.py
def make_value_estimator(...):
    ...
    value_key = self.value_key
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(
            value_network=self.critic, value_key=value_key, **hp
        )

However, there is also this case in ddpg.py:

# from ddpg.py
def make_value_estimator(...):
    ...
    value_key = "state_action_value" # <- would correspond to self.state_action_value_key
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(
            value_network=self.actor_critic, value_key=value_key, **hp
        )
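
A minimal sketch of such a mapping, with hypothetical names; each loss would translate its own key names into the estimator's, defaulting to the identity:

# Per-loss translation table: loss key name -> value-estimator key name.
_LOSS_TO_ESTIMATOR_KEY = {
    "state_action_value_key": "value_key",  # e.g. the ddpg.py case above
}

def estimator_kwargs(loss_keys):
    return {
        _LOSS_TO_ESTIMATOR_KEY.get(name, name): key
        for name, key in loss_keys.items()
    }

print(estimator_kwargs({"state_action_value_key": "state_action_value"}))
# {'value_key': 'state_action_value'}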

@Blonck (Contributor Author) commented May 26, 2023

Btw, this code will raise an exception in case the value key has a non-default name:

loss = loss()
loss.make_value_estim()
loss.set_keys()

because

actor = ...  # the actor writes its value under a non-default key
loss = loss()  # constructing the loss with default key names works, because the ctor doesn't use any key names
loss.make_value_estim()  # raises an exception, because:

The value estimator checks that the value key is in value_network.out_keys.
Since the actor uses a non-default value key, while the loss (and hence the value estimator) is instantiated with the default key names, this raises an exception.
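
Under that check, setting the keys before building the estimator should avoid the error; continuing the pseudocode above with a hypothetical key name:

loss = loss()
loss.set_keys(value="custom_value_key")  # matches the actor's out_keys
loss.make_value_estim()  # the estimator is built with the updated key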

@vmoens (Contributor) left a comment

Great work! Must have been quite a headache to come about!
See my comments in the code

Inline review comments (outdated, resolved) on: torchrl/objectives/ppo.py, torchrl/objectives/a2c.py, torchrl/objectives/common.py (×2), torchrl/objectives/ddpg.py, torchrl/objectives/value/advantages.py (×5)
@vmoens (Contributor) left a comment

Great work!
I'm happy to merge this, but I have one more question: we have a tutorial about losses, and this feature seems pretty advanced for a regular user who would like to code up a new loss with hard-coded keys.
Is it mandatory to code _AcceptedKeys?
It's a great tool, but maybe we can reserve it for internal usage.

Have you checked that the tutorial runs under this PR?

def __new__(cls, *args, **kwargs):
    cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
    cls._tensor_keys = cls._AcceptedKeys()
Contributor

maybe make this optional (only if _AcceptedKeys is present)?

Contributor Author

The current implementation does not prevent users from crafting a loss module that lacks configurable keys, as _AcceptedKeys is simply defined as an empty set in such cases. However, the abstract method does prevent users from ignoring the key machinery entirely:

@abstractmethod
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Passes updated tensordict keys to the underlying value estimator."""
    ...

In this case, the set_keys method will not function if supplied with any arguments, a behavior that aligns with my expectations.

We could remove the @abstractmethod decorator and introduce an error condition for when .set_keys() is invoked while _forward_value_estimator_keys() remains undefined by the loss module. This adjustment would ensure an exception is raised when .set_keys() is called on such a custom loss module.
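
A minimal sketch of that alternative, i.e. a non-abstract default that fails loudly when .set_keys() is used (illustrative only, not the actual implementation):

import torch.nn as nn

class ExampleLossModule(nn.Module):
    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # Non-abstract default: raise instead of forcing an override.
        raise NotImplementedError(
            "set_keys() requires the loss module to implement "
            "_forward_value_estimator_keys()."
        )

    def set_keys(self, **kwargs):
        # ... validate kwargs against _AcceptedKeys, then forward them:
        self._forward_value_estimator_keys(**kwargs)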

Contributor

Got it.
Up to you for the exception. In a way, if someone writes a loss module and then calls set_keys without having written a set of keys, they're probably way off the road...

>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
"""
for key, value in kwargs.items():
Contributor

if we make _AcceptedKeys optional, we can raise an exception if it is not present?

Labels: CLA Signed, Refactoring
Successfully merging this pull request may close these issues: [Feature Request] Refactor key usage of loss modules
4 participants