[Feature Request] Keep intermediate keys when calling advantage modules and loss modules. #1299

Open
@skandermoalla

Description

Motivation

An advantage of TensorDict modules, and of TensorDictSequential in particular, is that they can write intermediate keys during their forward pass, which can then be reused.
Say, for example, that you split your actor (or critic) into two parts, a feature extractor and an action (or value) head, and combine them with a TensorDictSequential. A single forward pass on an input tensordict then produces both the feature vectors and the actions (values): you can use the actions to interact with the environment or compute gradients, and the feature vectors for some analysis, as sketched below.
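
For concreteness, a minimal sketch of this setup (the layer sizes and the "hidden"/"action" key names are illustrative):

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# Illustrative two-part actor: a feature extractor followed by an action head.
feature_extractor = TensorDictModule(
    nn.Linear(4, 64), in_keys=["observation"], out_keys=["hidden"]
)
action_head = TensorDictModule(
    nn.Linear(64, 2), in_keys=["hidden"], out_keys=["action"]
)
actor = TensorDictSequential(feature_extractor, action_head)

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
actor(td)  # writes both "hidden" and "action" into td in place
assert "hidden" in td.keys() and "action" in td.keys()
```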

Unfortunately, the forward passes performed on the actor/critic by a loss module, or on the critic by an advantage module, do not preserve these intermediate keys, so you would not get the feature vectors when computing the loss or the advantage.

Current State

Advantage modules

Advantage estimators used to write the intermediate keys to the input tensordict, since all they did was call the value estimator on the input tensordict. Since #1263, however, this behavior has changed with the introduction of _call_value_nets(), which extracts only the value key from the value network's output and writes it back to the input tensordict (see the sketch below).
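
A simplified sketch of that extraction pattern (illustrative only, this mimics the behavior described above rather than the actual _call_value_nets() code):

```python
# Illustrative only: not the real _call_value_nets() implementation.
def call_value_net(value_net, tensordict, value_key="state_value"):
    td_out = value_net(tensordict.clone(False))  # forward pass on a shallow copy
    # Only the value key is copied back; intermediate keys written by
    # value_net (e.g. "hidden") are discarded along with td_out.
    tensordict.set(value_key, td_out.get(value_key))
    return tensordict
```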

It would be great to have the old behavior back, where all the new keys generated by the value estimator are written to the input tensordict.

Loss modules

Most of the losses seem to clone the input tensordict before calling the networks, so no intermediate keys are written to it by a call to the loss module; the sketch below illustrates this.
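
A simplified illustration, reusing the actor from the sketch above (the loss body is a placeholder, not an actual TorchRL loss):

```python
# Illustrative only: mimics the cloning behavior of current loss modules.
def loss_forward(actor, tensordict):
    td = tensordict.clone(False)       # new keys written below stay in the clone
    actor(td)                          # writes "hidden" and "action" into td
    return td["action"].pow(2).mean()  # placeholder loss computation

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
loss_forward(actor, td)
assert "hidden" not in td.keys()  # the intermediate key never reached the caller
```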

Proposed solution

The desired behavior is not to waste the forward pass performed by composite modules (actor/critic): the intermediate keys they compute should be recoverable whenever a loss or advantage module calls them.
After all, the default expectation when calling TensorDict modules is that they edit the input tensordict in place.
For losses that should not overwrite the input tensordict, like PPO, which needs to keep the old log-probs to recompute the ratio later, the loss could return the newly computed tensordict instead of the old one, and the user would keep it.

I think that, following the newly introduced notion of _AcceptedKeys from the refactor of tensordict keys in loss modules (#1175), we could have an entry that tracks which (additional) keys should be written back to the input tensordict or returned in the output tensordict.
Alternatively, when the loss literally clones the input tensordict, as PPO does, it could simply return the newly computed tensordict, which would not add any storage footprint (sketched below).
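
A sketch of this alternative (compute_ppo_loss is a hypothetical placeholder):

```python
# Illustrative only: same clone as today, so no extra storage footprint.
def loss_forward(actor, tensordict):
    td = tensordict.clone(False)
    actor(td)  # writes "hidden", "action" and fresh log-probs into the clone
    loss = compute_ppo_loss(td, tensordict)  # hypothetical helper
    return loss, td  # the caller keeps the new keys; the input stays pristine
```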

That way, back to the motivating example, you could register your feature_vector key and have it written to the input tensordict, or returned in the output tensordict, when calling an advantage or loss module, as in the hypothetical sketch below.
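
A hypothetical usage sketch, assuming an extra_out_keys argument to set_keys() that does not exist today (actor, critic and td are composite modules and a tensordict as in the sketches above, with "feature_vector" playing the role of "hidden"):

```python
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

advantage = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
advantage.set_keys(extra_out_keys=["feature_vector"])  # hypothetical argument

loss_module = ClipPPOLoss(actor, critic)
loss_module.set_keys(extra_out_keys=["feature_vector"])  # hypothetical argument

td_out = loss_module(td)  # td_out would carry "feature_vector" with the losses
features = td_out["feature_vector"]
```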

Happy to discuss this.

Limitations

I'm mainly thinking of straightforward actor-critic algorithms when describing these features, so I may be missing some limitations or undesired behaviors.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)
