[Feature] Add PrioritizedSliceSampler #1875
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1875
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (16 Unrelated Failures)
As of commit 6e68587 with merge base b34e2d2:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Amazing! Will do a proper review shortly. Any chance we can get the bugfix in a separate PR to put it in the minor release?
Wonderful! Thanks, it's a super useful feature.
I left a couple of comments, most are just suggestions so don't feel like you need to address them all.
If we can just move the bugfix to a separate PR, I'll merge that one straight away and then we can move to this.
starts = torch.from_numpy(starts).to(device=lengths.device)
index = self._tensor_slices_from_startend(seq_length, starts)
note to self: we should avoid returning numpy arrays and stick to torch
cc @albertbou92
terminated_key: terminated,
}
)
return index.to(torch.long), info
we should check if there isn't a way to enforce this dtype earlier and avoid casting
FYI I borrowed this logic from SliceSampler
rl/torchrl/data/replay_buffers/samplers.py
Line 932 in 144f547
return index.to(torch.long), {}
I checked on my side, and index is already a torch.long, so the cast is a no-op
…ioritized_slice_sampler
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Description
Add PrioritizedSliceSampler, a subclass of both PrioritizedSampler and SliceSampler. It samples slices according to priority weights.
In contrast to the PrioritizedSampler, which selects steps based on individual priority weights, the PrioritizedSliceSampler selects slices with a corresponding priority weight for each. In its implementation, it focuses solely on sampling the start index of each slice while disregarding indices at the end of each episode, as they cannot form slices of sufficient length.
Unlike the SliceSampler, which first samples trajectories and then selects slices randomly within each trajectory, the PrioritizedSliceSampler samples slices directly. This simplifies the implementation, but it also means that, under the default uniform priority weights, the PrioritizedSliceSampler tends to sample slices from longer trajectories more often than from shorter ones, because longer trajectories contain more valid start indices. The initial priority weights can be adjusted manually to encode any prior over the sampling. Also, as priority weights are updated during training, the sampler should adapt, mitigating the oversampling of slices from longer trajectories over time.
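For illustration only, the sampling scheme described above can be sketched in plain Python. This is not the code added by this PR: `valid_starts` and `sample_slices` are hypothetical helper names, the buffer is assumed to store trajectories contiguously, and `priorities` is assumed to hold one weight per stored step.

```python
import random


def valid_starts(lengths, slice_len):
    """Return the flat buffer indices that can start a full slice.

    Hypothetical helper: assumes trajectories are stored back to back,
    with `lengths[i]` steps in trajectory i.
    """
    starts, offset = [], 0
    for length in lengths:
        # The last (slice_len - 1) steps of a trajectory cannot start a full slice.
        starts.extend(range(offset, offset + max(length - slice_len + 1, 0)))
        offset += length
    return starts


def sample_slices(lengths, priorities, slice_len, num_slices, rng=random):
    """Sample `num_slices` slices, picking each start index by its priority."""
    starts = valid_starts(lengths, slice_len)
    weights = [priorities[s] for s in starts]
    chosen = rng.choices(starts, weights=weights, k=num_slices)
    return [list(range(s, s + slice_len)) for s in chosen]


lengths = [3, 6]                   # two trajectories stored back to back
priorities = [1.0] * sum(lengths)  # uniform priorities (assumed default)
starts = valid_starts(lengths, slice_len=3)
slices = sample_slices(lengths, priorities, slice_len=3, num_slices=4,
                       rng=random.Random(0))
```

With these lengths, the short trajectory contributes one valid start and the long one contributes four, which is why uniform weights favor the longer trajectory.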
Motivation and Context
This type of sampling is used in the literature:
https://github.com/fyhMer/fowm/blob/main/src/algorithm/helper.py#L504-L510
https://github.com/fyhMer/fowm/blob/main/src/algorithm/tdmpc.py#L334-L337
Addresses this feature request: #1876
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!