[Feature] Add PrioritizedSliceSampler #1875

Merged: 11 commits, Feb 7, 2024

Contributor

@Cadene Cadene commented Feb 6, 2024

Description

Add PrioritizedSliceSampler, a subclass of both PrioritizedSampler and SliceSampler. It samples slices according to priority weights.

In contrast to the PrioritizedSampler, which selects individual steps based on their priority weights, the PrioritizedSliceSampler selects whole slices, each with its own priority weight. The implementation samples only the start index of each slice and excludes the indices at the end of each episode, since they cannot start a slice of sufficient length.
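The start-index idea can be illustrated with a minimal, standard-library sketch. It is not the code from this PR: the function names `valid_start_mask` and `sample_slices`, the flat per-step priority list, and the assumption that trajectories are stored back to back are all hypothetical choices made for illustration.

```python
import random


def valid_start_mask(traj_lengths, slice_len):
    """One boolean per stored step: True if a slice of `slice_len`
    starting at that step stays inside the same trajectory."""
    mask = []
    for length in traj_lengths:
        n_valid = max(length - slice_len + 1, 0)
        mask.extend([True] * n_valid + [False] * (length - n_valid))
    return mask


def sample_slices(priorities, traj_lengths, slice_len, num_slices, rng=None):
    """Draw `num_slices` slices; only the start index is sampled,
    proportionally to its priority (hypothetical helper, not the PR's API)."""
    rng = rng or random.Random()
    mask = valid_start_mask(traj_lengths, slice_len)
    # Steps too close to the end of an episode get zero weight.
    weights = [p if ok else 0.0 for p, ok in zip(priorities, mask)]
    starts = rng.choices(range(len(weights)), weights=weights, k=num_slices)
    # Expand each start index into the full run of step indices.
    return [list(range(s, s + slice_len)) for s in starts]
```

With two trajectories of lengths 5 and 3 and `slice_len=3`, only steps 0, 1, 2 and 5 can start a slice, so every sampled slice begins at one of those indices.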

Unlike the SliceSampler, which first samples trajectories and then selects slices at random within each trajectory, the PrioritizedSliceSampler samples slices directly. This simplifies the implementation, but it also means that, under the default uniform priority weights, the PrioritizedSliceSampler tends to sample slices from longer trajectories more often than from shorter ones. Nonetheless, the initial priority weights can be adjusted manually to reflect any prior over the sampling. Also, as priority weights are updated during training, the sampler should adjust, mitigating any oversampling of slices from longer trajectories over time.
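To make the length bias concrete, here is a rough sketch under the same assumptions as the description (uniform priorities, one priority per valid start index). Both helpers, `trajectory_sampling_probs` and `balanced_start_priorities`, are hypothetical and only illustrate one possible way of choosing initial weights; they are not part of the sampler.

```python
def trajectory_sampling_probs(traj_lengths, slice_len):
    """Probability that a uniformly drawn valid slice start lands in each trajectory."""
    valid = [max(n - slice_len + 1, 0) for n in traj_lengths]
    total = sum(valid)
    return [v / total for v in valid]


def balanced_start_priorities(traj_lengths, slice_len):
    """Hypothetical initial priorities giving every trajectory the same total mass."""
    priorities = []
    for n in traj_lengths:
        n_valid = max(n - slice_len + 1, 0)
        weight = 1.0 / n_valid if n_valid else 0.0
        priorities.extend([weight] * n_valid + [0.0] * (n - n_valid))
    return priorities


# A 100-step and a 10-step trajectory with slices of length 5:
# 96 valid starts against 6, so roughly 0.94 against 0.06.
probs = trajectory_sampling_probs([100, 10], slice_len=5)
```

Sampling trajectories first, as the SliceSampler does, would give each of these two trajectories probability 0.5 instead. Weights like those from `balanced_start_priorities` would approximate that behavior at initialization, before any priority update.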

Motivation and Context

This type of sampling is used in the literature:
https://github.com/fyhMer/fowm/blob/main/src/algorithm/helper.py#L504-L510
https://github.com/fyhMer/fowm/blob/main/src/algorithm/tdmpc.py#L334-L337

Addresses this feature request: #1876

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.


pytorch-bot bot commented Feb 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1875

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (16 Unrelated Failures)

As of commit 6e68587 with merge base b34e2d2:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Feb 6, 2024
test/test_rb.py (resolved)
Contributor

vmoens commented Feb 6, 2024

Amazing! Will do a proper review shortly. Any chance we can get the bugfix in a separate PR to put it in the minor release?

@vmoens added the enhancement label Feb 6, 2024
Contributor

@vmoens vmoens left a comment

Wonderful! Thanks, it's a super useful feature.
I left a couple of comments; most are just suggestions, so don't feel like you need to address them all.

If we can just move the bugfix to a separate PR, I'll merge that one straight away and then we can move to this.

torchrl/data/replay_buffers/samplers.py (outdated, resolved)
Comment on lines 1220 to 1221
starts = torch.from_numpy(starts).to(device=lengths.device)
index = self._tensor_slices_from_startend(seq_length, starts)
Contributor

note to self: we should avoid returning numpy arrays and stick to torch

cc @albertbou92

torchrl/data/replay_buffers/samplers.py (outdated, resolved)
terminated_key: terminated,
}
)
return index.to(torch.long), info
Contributor

we should check if there isn't a way to enforce this dtype earlier and avoid casting

Contributor Author

FYI I borrowed this logic from SliceSampler

return index.to(torch.long), {}

Contributor Author

I checked on my side, and index is already a torch.long, so the cast is a no-op

torchrl/data/replay_buffers/samplers.py (outdated, resolved)
@Cadene changed the title from 'Add PrioritizedSliceSampler + Small fix "traj_terminated"' to 'Add PrioritizedSliceSampler' Feb 7, 2024
Contributor Author

Cadene commented Feb 7, 2024

> Amazing! Will do a proper review shortly. Any chance we can get the bugfix in a separate PR to put it in the minor release?

@vmoens Done ;) #1884

@vmoens changed the title from 'Add PrioritizedSliceSampler' to '[Feature] Add PrioritizedSliceSampler' Feb 7, 2024
@vmoens merged commit 4d52d5f into pytorch:main Feb 7, 2024
52 of 68 checks passed
3 participants