-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Graduate Replay Buffer prototype #794
Conversation
README.md
Outdated
@@ -335,7 +334,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) | |||
```python | |||
from torchrl.objectives import DQNLoss | |||
loss_module = DQNLoss(value_network=value_network, gamma=0.99) | |||
tensordict = replay_buffer.sample(batch_size) | |||
tensordict = replay_buffer.sample(batch_size)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's your view on this?
I feel that the [0]
index is bothersome. With TensorDictReplayBuffer
we read the second output from sample
(info
) and write the content in the tensordict.
Maybe we should just have a def sample(self, ..., return_info=False)
flag that can be turned on when needed.
Wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't like the [0]
index. However, the version of TensorDictReplayBuffer
proposed in rb_prototype.py
also returns 2 values from sample()
. While migrating I decided to keep this logic. If we change it, we break consistency with parent class ReplayBuffer
. I don't know if there are cases when users would like to use both ReplayBuffer
and TensorDictReplayBuffer
interchangeably in a single piece of code. If this is not the case, we can stop returning info (or return it only if the flag is set).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I see two ways of going about this
- branch out your branch and have a
sample()
function with a single output, merge it after review in your branch and ship the whole thing. Advantage: we don't bc-break nightly twice - ship this PR as it is and solve it in a second time.
Any pref?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if there are cases when users would like to use both ReplayBuffer and TensorDictReplayBuffer
It's better if they're consistent, even if they're not interchangeable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late reply (one needs to refresh the page to see new comments).
So which solution do we choose? You said that they should be consistent (returning 2 values by default from both TensorDictReplayBuffer
and ReplayBuffer
) but should change the sample()
function in TensorDictReplayBuffer
to return 1 value (which breaks consistency). Or shall we remove info
from returned value of sample()
also in the ReplayBuffer
and return it conditionally there too? In this option we would return always one value and two values only if return_info=True
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that the last approach (1 value by default in base and derived classes) is best.
If you agree, I'll make a new PR with this change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, one value by default unless return_info = True
Thanks for taking care of this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated this PR with the required changes.
Change returned values of ReplayBuffer sample method
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing! Can you look at the 2 comments I left?
Brilliant work I love it
the lint seems to be failing, can you fix that? |
Linting fails because of two things:
|
if that's the problem you can just put a |
Reopening. Removed merge conflicts. |
Description
This change replaces ReplayBuffers implementations in replay_buffers.py by composable implementations from rb_prototype.py.
This is a breaking change. ReplayBuffer constructor no longer accepts size parameter. Now it is only provided to the underlying Storage. This change also makes sample return type consistent across ReplayBuffer and its derived classes. It is also a breaking change. That is why many files needed to be updated.
Motivation and Context
Why is this change required? What problem does it solve?
This change allows to create new types of replay buffers by composing samplers and writers.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!