-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Exclude "_"
out_keys in tensordictmodel
#589
[Feature] Exclude "_"
out_keys in tensordictmodel
#589
Conversation
"_"
out_keys in tensordictmodel
Codecov Report
@@ Coverage Diff @@
## main #589 +/- ##
==========================================
+ Coverage 86.96% 86.98% +0.01%
==========================================
Files 120 120
Lines 21812 21844 +32
==========================================
+ Hits 18968 19000 +32
Misses 2844 2844
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a couple of tests for the two missing bits?
@@ -104,7 +104,7 @@ def _inverse(self, y: torch.Tensor) -> torch.Tensor: | |||
|
|||
|
|||
class NormalParamWrapper(nn.Module): | |||
"""A wrapper for normal distirbution parameters. | |||
"""A wrapper for normal distribution parameters. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch!
@vmoens I added the two tests as suggested 👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also check that the warnings we capture are the ones we expect
test/test_tensordictmodules.py
Outdated
} | ||
|
||
# warning due to "_" in spec keys | ||
with pytest.warns(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we check the warning message too (just to make sure we capture the right warning)
test/test_tensordictmodules.py
Outdated
@@ -773,6 +836,16 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): | |||
|
|||
|
|||
class TestTDSequence: | |||
def test_in_key_warning(self): | |||
with pytest.warns(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same thing here
test/test_tensordictmodules.py
Outdated
tensordict_module = TensorDictModule( | ||
nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] | ||
) | ||
with pytest.warns(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same thing here
…suffleur/rl into exclude_keys_tensordictmodel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work thanks for this
Description
Added the possibility to avoid writing some tensors in the output tensordict in
TensorDictModule
by setting the output key to"_"
.Motivation and Context
New feature, closes #564
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!