[Feature Request] Keep intermediate keys when calling advantage modules and loss modules. #1299
Description
Motivation
An advantage of TensorDict modules, and in particular of TensorDictSequential, is that they can write intermediate keys during their forward pass, which can then be reused.
Say, for example, you divide your actor (or critic) into two parts, a feature extractor and an action (or value) head, and combine these two parts with a TensorDictSequential. You do this to get the feature vectors and the actions (values) with a single forward pass on an input tensordict, and then use the actions to interact with the environment / compute gradients, and the feature vectors for some analysis.
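A minimal sketch of such a composed actor (the layer sizes and the "observation" / "hidden" key names are placeholders for illustration):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# Two-part actor: a feature extractor followed by an action head.
feature_extractor = TensorDictModule(
    torch.nn.Linear(4, 64), in_keys=["observation"], out_keys=["hidden"]
)
action_head = TensorDictModule(
    torch.nn.Linear(64, 2), in_keys=["hidden"], out_keys=["action"]
)
actor = TensorDictSequential(feature_extractor, action_head)

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
actor(td)  # single forward pass, executed in place
assert "action" in td.keys() and "hidden" in td.keys()  # intermediate key is kept
```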
Unfortunately, the forward passes performed on the actor/critic by a loss module, or on the critic by an advantage module, do not preserve these intermediate keys. So you wouldn't get the feature vectors when you compute the loss or the advantage.
Current State
Advantage modules
Advantage estimators used to write the intermediate keys to the input tensordict, since all they did was call the value network on it. However, after #1263 this behavior has changed with the introduction of _call_value_nets(), which extracts only the value key from the value network's output and writes it back to the input tensordict.
It would be great to have the old behavior back, where all the new keys generated by the value network are written to the input tensordict.
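For illustration, here is roughly what the current behavior looks like with GAE and a critic split like the actor above. This is only a sketch: key names like "hidden" are mine, and the exact keys required in the rollout data (e.g. "terminated") may differ across TorchRL versions.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.objectives.value import GAE

# Critic split into a feature extractor and a value head.
backbone = TensorDictModule(
    torch.nn.Linear(4, 64), in_keys=["observation"], out_keys=["hidden"]
)
value_head = TensorDictModule(
    torch.nn.Linear(64, 1), in_keys=["hidden"], out_keys=["state_value"]
)
critic = TensorDictSequential(backbone, value_head)

gae = GAE(gamma=0.99, lmbda=0.95, value_network=critic)

# Fake rollout data with a [batch, time] shape and the usual "next" entries.
td = TensorDict(
    {
        "observation": torch.randn(8, 10, 4),
        "next": {
            "observation": torch.randn(8, 10, 4),
            "reward": torch.randn(8, 10, 1),
            "done": torch.zeros(8, 10, 1, dtype=torch.bool),
            "terminated": torch.zeros(8, 10, 1, dtype=torch.bool),
        },
    },
    batch_size=[8, 10],
)

gae(td)
# "advantage" and "value_target" are written back to td, but since #1263 the
# intermediate "hidden" computed by the critic is not.
print("advantage" in td.keys(), "hidden" in td.keys())
```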
Loss modules
Most of the losses seem to clone the input tensordict before calling the networks, so no intermediate key will be written to it after calling the loss module.
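The same thing happens with, e.g., ClipPPOLoss: the actor is run inside the loss, but its intermediate keys never make it back to the input tensordict. A rough sketch, assuming a composed actor like the one above (default key names such as "sample_log_prob" and the exact constructor arguments may vary with the TorchRL version):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss

# Actor with an intermediate "hidden" key.
backbone = TensorDictModule(
    torch.nn.Linear(4, 64), in_keys=["observation"], out_keys=["hidden"]
)
policy_head = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(64, 4), NormalParamExtractor()),
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    TensorDictSequential(backbone, policy_head),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(torch.nn.Linear(4, 1), in_keys=["observation"])
loss_fn = ClipPPOLoss(actor, critic)

# Minimal fake batch with the keys the loss expects.
td = TensorDict(
    {
        "observation": torch.randn(8, 4),
        "action": torch.rand(8, 2) * 2 - 1,
        "sample_log_prob": torch.randn(8),
        "advantage": torch.randn(8, 1),
        "value_target": torch.randn(8, 1),
    },
    batch_size=[8],
)

loss_td = loss_fn(td)
# The actor's forward pass happened on a clone, so "hidden" is lost:
print("hidden" in td.keys(), "hidden" in loss_td.keys())  # False False
```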
Proposed solution
The desired behavior is to not waste the forward pass performed by composite modules (actor/critic) and recover the intermediate keys they compute whenever a loss or advantage function calls them.
After all, the default expectation when calling modules is that they edit the input tensordict in place.
Now, for losses that shouldn't overwrite the input tensordict, like PPO, which needs to keep the old log-probs to recompute the ratio later, the newly computed tensordict could be returned as the output instead of replacing the old one, and the user could keep it.
I think that, following the newly introduced notion of _AcceptedKeys from the refactor of tensordict keys in loss modules (#1175), we could have a key that tracks which (additional) keys should be written back to the input tensordict or returned in the output tensordict.
Or, when the loss literally clones the input tensordict (as PPO does), simply return the newly computed tensordict, since this wouldn't add any storage footprint.
That way, going back to the motivating example, you could pick your feature_vector key and have it written to the input tensordict, or returned in the output, when calling an advantage module or a loss module.
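To make this concrete, here is one possible shape the API could take, reusing the modules from the sketches above. Everything below is hypothetical: it shows the desired behavior, not something that exists in TorchRL today, and the exact mechanism (a constructor argument, an entry in _AcceptedKeys, set_keys(), ...) is open for discussion.

```python
# Hypothetical API sketch -- none of these arguments exist today.

# Advantage module: ask it to copy extra keys produced by the critic back
# to the input tensordict.
gae = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    keep_keys=["hidden"],  # hypothetical argument
)
gae(td)
assert "hidden" in td.keys()  # desired behavior

# Loss module: expose the same choice through the _AcceptedKeys mechanism, and
# return the extra keys in the output tensordict when the input is cloned
# (as in PPO) rather than writing them back.
loss_fn = ClipPPOLoss(actor, critic)
loss_fn.set_keys(keep=["hidden"])  # hypothetical extension of set_keys()
loss_td = loss_fn(td)
assert "hidden" in loss_td.keys()  # desired behavior
```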
Happy to discuss this.
Limitations
I'm mainly thinking of straightforward actor-critic algorithms when describing these features so I may miss some limitations or undesired behaviors.
Checklist
- I have checked that there is no similar issue in the repo (required)