Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Exclude "_" out_keys in tensordictmodel #589

Merged
merged 8 commits into from
Oct 20, 2022

Conversation

jlesuffleur
Copy link
Contributor

Description

Added the possibility to avoid writing some tensors in the output tensordict in TensorDictModule by setting the output key to "_".

  • tests added
  • note added in the tutorial

Motivation and Context

New feature, closes #564

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 19, 2022
@vmoens vmoens changed the title [Feature] Exclude out_keys in tensordictmodel [Feature] Exclude "_" out_keys in tensordictmodel Oct 19, 2022
@vmoens vmoens added the enhancement New feature or request label Oct 19, 2022
@codecov
Copy link

codecov bot commented Oct 19, 2022

Codecov Report

Merging #589 (9ce26a3) into main (b584d78) will increase coverage by 0.01%.
The diff coverage is 94.11%.

❗ Current head 9ce26a3 differs from pull request most recent head 9f225df. Consider uploading reports for the commit 9f225df to get more accurate results

@@            Coverage Diff             @@
##             main     #589      +/-   ##
==========================================
+ Coverage   86.96%   86.98%   +0.01%     
==========================================
  Files         120      120              
  Lines       21812    21844      +32     
==========================================
+ Hits        18968    19000      +32     
  Misses       2844     2844              
Flag Coverage Δ
linux-cpu 85.30% <94.11%> (+0.01%) ⬆️
linux-gpu 86.75% <94.11%> (+0.01%) ⬆️
linux-outdeps-gpu 75.42% <94.11%> (+0.02%) ⬆️
linux-stable-cpu 85.28% <94.11%> (+0.01%) ⬆️
linux-stable-gpu 86.75% <94.11%> (+0.01%) ⬆️
macos-cpu 85.08% <94.11%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
torchrl/modules/distributions/continuous.py 85.65% <ø> (ø)
torchrl/modules/tensordict_module/common.py 79.18% <75.00%> (-0.26%) ⬇️
test/test_tensordictmodules.py 98.46% <100.00%> (+0.04%) ⬆️
torchrl/objectives/utils.py 85.10% <0.00%> (-0.11%) ⬇️
torchrl/objectives/__init__.py 100.00% <0.00%> (ø)
test/test_trainer.py 98.93% <0.00%> (+1.06%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a couple of tests for the two missing bits?

@@ -104,7 +104,7 @@ def _inverse(self, y: torch.Tensor) -> torch.Tensor:


class NormalParamWrapper(nn.Module):
"""A wrapper for normal distirbution parameters.
"""A wrapper for normal distribution parameters.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

torchrl/modules/tensordict_module/common.py Show resolved Hide resolved
torchrl/modules/tensordict_module/common.py Show resolved Hide resolved
@jlesuffleur
Copy link
Contributor Author

@vmoens I added the two tests as suggested 👍

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also check that the warnings we capture are the ones we expect

}

# warning due to "_" in spec keys
with pytest.warns():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check the warning message too (just to make sure we capture the right warning)

@@ -773,6 +836,16 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type):


class TestTDSequence:
def test_in_key_warning(self):
with pytest.warns():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing here

tensordict_module = TensorDictModule(
nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"]
)
with pytest.warns():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing here

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work thanks for this

@vmoens vmoens merged commit 3e36522 into pytorch:main Oct 20, 2022
@jlesuffleur jlesuffleur deleted the exclude_keys_tensordictmodel branch October 20, 2022 08:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Exclude "_" keys from the output tensordict in TensorDictModule
3 participants