[Refactor] Graduate Replay Buffer prototype #794

Merged: 8 commits into pytorch:main on Jan 7, 2023

Conversation

KamilPiechowiak (Contributor)

Description

This change replaces the ReplayBuffer implementations in replay_buffers.py with the composable implementations from rb_prototype.py.
This is a breaking change: the ReplayBuffer constructor no longer accepts a size parameter; the capacity is now provided to the underlying Storage instead. The change also makes the return type of sample consistent across ReplayBuffer and its derived classes, which is breaking as well. That is why many files needed to be updated.

Motivation and Context

Why is this change required? What problem does it solve?
This change makes it possible to create new types of replay buffers by composing samplers and writers, as sketched below.
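
For illustration, here is a minimal sketch of what that composition looks like after this PR. The module paths, class names, and constructor arguments are assumptions based on the prototype's design rather than taken from this diff:

```python
# Sketch: build a replay buffer by composing a storage, a sampler, and a writer.
# All names below are assumed from the rb_prototype design, not from this diff.
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import ListStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

rb = ReplayBuffer(
    storage=ListStorage(max_size=1000),  # capacity now lives on the Storage
    sampler=RandomSampler(),             # swap in another sampler to change sampling
    writer=RoundRobinWriter(),           # controls how new entries replace old ones
)
rb.extend(list(range(100)))              # add a batch of items
sampled = rb.sample(32)                  # as of this PR, returns (data, info)
```

Swapping the sampler or writer is then enough to obtain a new buffer type (e.g. a prioritized buffer) without subclassing.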

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Jan 4, 2023
@vmoens added the enhancement and bc breaking labels on Jan 5, 2023
README.md Outdated
@@ -335,7 +334,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
```diff
 from torchrl.objectives import DQNLoss
 loss_module = DQNLoss(value_network=value_network, gamma=0.99)
-tensordict = replay_buffer.sample(batch_size)
+tensordict = replay_buffer.sample(batch_size)[0]
```
Contributor

What's your view on this?
I feel that the [0] index is bothersome. With TensorDictReplayBuffer we read the second output from sample (info) and write its content into the tensordict.
Maybe we should just have a def sample(self, ..., return_info=False) flag that can be turned on when needed.
Wdyt?

Contributor Author

I also don't like the [0] index. However, the version of TensorDictReplayBuffer proposed in rb_prototype.py also returns 2 values from sample(), and while migrating I decided to keep that logic. If we change it, we break consistency with the parent class ReplayBuffer. I don't know whether there are cases where users would want to use ReplayBuffer and TensorDictReplayBuffer interchangeably in a single piece of code. If not, we can stop returning info (or return it only when the flag is set).

Contributor

Ok, I see two ways of going about this:

  1. Branch off your branch with a sample() function that has a single output, merge it into your branch after review, and ship the whole thing. Advantage: we don't bc-break nightly twice.
  2. Ship this PR as it is and solve it in a follow-up.

Any pref?

Contributor

> I don't know whether there are cases where users would want to use ReplayBuffer and TensorDictReplayBuffer interchangeably

It's better if they're consistent, even if they're not interchangeable.

Contributor Author

Sorry for the late reply (one needs to refresh the page to see new comments).
So which solution do we choose? You said they should be consistent (both TensorDictReplayBuffer and ReplayBuffer returning 2 values by default), but also that we should change sample() in TensorDictReplayBuffer to return 1 value, which breaks that consistency. Or shall we remove info from the return value of sample() in ReplayBuffer as well and return it conditionally there too? In that option sample() would always return one value, and two values only if return_info=True.

Contributor Author

I think that the last approach (1 value by default in base and derived classes) is best.
If you agree, I'll make a new PR with this change.

Contributor

yeah, one value by default unless return_info = True
Thanks for taking care of this!
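
For reference, a short sketch of the behaviour agreed on here (shipped in the follow-up PR discussed above), assuming the flag keeps the proposed name:

```python
batch = rb.sample(32)                          # one value by default
batch, info = rb.sample(32, return_info=True)  # (data, info) only on request
```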

Contributor Author

I've updated this PR with the required changes.

@vmoens changed the title from "Graduate Replay Buffer prototype" to "[Refactor] Graduate Replay Buffer prototype" on Jan 5, 2023
@vmoens (Contributor) left a comment

Amazing! Can you look at the 2 comments I left?
Brilliant work, I love it!

@vmoens (Contributor) commented Jan 6, 2023

the lint seems to be failing, can you fix that?

@KamilPiechowiak (Contributor Author)

> the lint seems to be failing, can you fix that?

Linting fails because of two things:

  • There are type annotations written as the string 'Transform', and the linter doesn't like it. I did it this way to avoid circular imports from envs.transforms. It is possible to remove this linter error by making a global variable Transform = 'Transform' and using it in the annotations, but that doesn't look good either. Maybe you have other ideas on how to solve it? (One common alternative is sketched after this list.)
  • ReplayBuffer is imported in trainers.py. It is not used there, but another module imports it from there. I'll find that import and point it at torchrl.data instead.
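
For context, one common idiom for the first problem is a typing-only import: the name then exists for the linter and type checker, while the runtime circular import never happens. This is a sketch of that idiom, not what this PR does, and the append_transform method name is illustrative:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static analyzers; at runtime this import never runs,
    # so the circular dependency on torchrl.envs.transforms is avoided.
    from torchrl.envs.transforms import Transform


class ReplayBuffer:
    def append_transform(self, transform: "Transform") -> None:
        # The quoted annotation is never evaluated at runtime.
        ...
```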

@vmoens (Contributor) commented Jan 6, 2023

> there are type annotations written as the string 'Transform', and the linter doesn't like it. I did it this way to avoid circular imports from envs.transforms. [...]

if that's the problem you can just put a # noqa: F821 comment at the end of the lines where the lint complains
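
Applied to the annotations in question, the suggested fix looks roughly like this (method name illustrative):

```python
def append_transform(self, transform: "Transform") -> None:  # noqa: F821
    ...
```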

@KamilPiechowiak (Contributor Author) commented Jan 6, 2023

Reopening. The merge conflicts are now resolved.

@vmoens merged commit 569161e into pytorch:main on Jan 7, 2023