[Feature] Extend TensorDictPrimer default_value options #2071
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2071
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures, 7 Unrelated Failures as of commit 1625514 with merge base acf168e (the unrelated jobs likely failed due to flakiness present on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Great work, left a bunch of comments.
Thanks a million!
try:
    expanded_spec = self._try_expand_shape(spec)
except AttributeError:
    raise RuntimeError(
When will this be reached?
If, for any reason, self.parent is None.
When would transform_observation_spec be called when parent is None?
self.random = random
if isinstance(default_value, dict):
    primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)}
    default_value_keys = {unravel_key(key) for key in default_value.keys()}
What about passing a tensordict through to represent this? For instance, this format gets messy with nested keys:
default_values = {("a", "b"): 1, ("c", "d"): lambda: torch.randn(()), "e": {"f": lambda: torch.zeros(())}}
but if you use the tensordict nightly you get a nice representation:
default_values = TensorDict(default_values, []).to_dict()
default_values
which prints
{'a': {'b': tensor(1)},
'c': {'d': <function __main__.<lambda>()>},
'e': {'f': <function __main__.<lambda>()>}}
Since tensordict accepts arbitrary values now, we could even avoid transforming it back to a dict.
That way the default value structure will be 100% identical to the CompositeSpec that we use to represent the specs.
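A minimal sketch of that idea, assuming a recent tensordict that wraps non-tensor values (such as callables) as NonTensorData; the particular specs used here are only illustrative:

# Illustrative sketch: keep the default values in a TensorDict whose nesting
# mirrors the CompositeSpec used for the primer specs.
import tensordict
import torch
from tensordict import TensorDict
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

specs = CompositeSpec(a=CompositeSpec(b=UnboundedContinuousTensorSpec(shape=(3,))))
defaults = TensorDict({("a", "b"): lambda: torch.zeros(3)}, [])

# Both objects expose the same nested key structure:
print(list(specs.keys(True, True)))  # [('a', 'b')]
print(
    list(
        defaults.keys(
            True,
            True,
            is_leaf=lambda cls: issubclass(cls, (tensordict.NonTensorData, torch.Tensor)),
        )
    )
)  # [('a', 'b')]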
Ah, very cool option @vmoens! At the moment I need to transform it back to a dict, because I get the following behaviour for non-tensor data:
import torch
from tensordict import TensorDict

default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
keys = default_value.keys(True, True)
print(keys)
output:
_TensorDictKeysView([],
    include_nested=True,
    leaves_only=True)
So non-tensor data are not considered leaves.
Are you using the latest nightly?
You can always define your own is_leaf for keys:
import tensordict
import torch
from tensordict import TensorDict

default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
print(default_value)
keys = list(
    default_value.keys(
        True,
        True,
        is_leaf=lambda x: issubclass(x, (tensordict.NonTensorData, torch.Tensor)),
    )
)
print(keys)
I was using the latest commit, yes. But this solution works fine :)
LGTM!
TestgSDE is failing because we patched the behaviour for wrong primers, can you fix that?
done!
LGTM
Description
This PR extends the possible values taken by the tensors added by the TensorDictPrimer transform, allowing callables to be used to create them.
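A hedged usage sketch of that feature; the environment name and the per-key dict form of default_value are illustrative and may not match the merged API exactly:

# Illustrative sketch: prime an extra entry whose value is produced by a
# callable rather than a constant.
import torch
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer
from torchrl.data import UnboundedContinuousTensorSpec

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    TensorDictPrimer(
        mykey=UnboundedContinuousTensorSpec(shape=(3,)),
        default_value={"mykey": lambda: torch.ones(3)},  # callable factory for this key
    ),
)
td = env.reset()
print(td["mykey"])  # expected: tensor([1., 1., 1.])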
Motivation and Context
Why is this change required? What problem does it solve? If it fixes an open issue, please link to the issue here. You can use the syntax close #15213 if this solves the issue #15213.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!