
[BugFix] buffer __iter__ for samplers without replacement + prefetch #2185

Merged: 5 commits, Jun 11, 2024

Conversation

JulianKu
Contributor

Reopens #2178 (see #2182)

Description

Adds a termination condition to the replay buffer's `__iter__` loop so that samples still waiting in the prefetch queue are yielded, and a condition in its `sample` method so that new jobs are only added to the prefetch queue while the sampler has not yet run out.

Motivation and Context

For replay buffers with prefetch enabled and a sampler without replacement (one that can "run out"), the buffer's `__iter__` method does not yield all samples. Its while loop terminates on the sampler's `ran_out` flag, which is set as soon as the prefetch job that exhausts the sampler is executed. However, there may still be prefetch jobs left in the queue that could yield samples, so those samples are simply skipped.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

For replay buffers with prefetch enabled and a sampler without replacement (one that can "run out"), the buffer's `__iter__` method will not yield all samples. This is because its while loop terminates on the sampler's `ran_out` flag, which is set as soon as the prefetch job that exhausts the sampler is executed. However, there might still be prefetch jobs in the queue that could yield samples. Therefore, new flags are added so that jobs are only added to the prefetch queue while the sampler can still produce samples, and the yield loop terminates only once all queued prefetch jobs have completed.
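The described fix can be sketched roughly as follows. This is an illustrative, single-threaded mock of the idea, not TorchRL's actual implementation: apart from `ran_out`, `sample`, and `__iter__`, all class and attribute names here are made up for the example.

```python
from collections import deque


class SamplerWithoutReplacement:
    """Mock sampler: yields batches of indices until the storage is exhausted."""

    def __init__(self, storage_len, batch_size):
        self._indices = list(range(storage_len))
        self._batch_size = batch_size
        self.ran_out = False

    def sample(self):
        batch = self._indices[: self._batch_size]
        self._indices = self._indices[self._batch_size :]
        if not self._indices:
            self.ran_out = True
        return batch


class PrefetchingBuffer:
    """Mock buffer that keeps up to `prefetch` sample jobs queued ahead."""

    def __init__(self, sampler, prefetch):
        self._sampler = sampler
        self._queue = deque()
        self._prefetch = prefetch

    def _fill_queue(self):
        # Fix 1: only enqueue new jobs while the sampler has not run out.
        while len(self._queue) < self._prefetch and not self._sampler.ran_out:
            self._queue.append(self._sampler.sample())

    def __iter__(self):
        self._fill_queue()
        # Fix 2: terminate only once the queue is drained, not as soon as
        # ran_out flips, so queued jobs are not silently dropped.
        while self._queue:
            yield self._queue.popleft()
            self._fill_queue()


sampler = SamplerWithoutReplacement(storage_len=10, batch_size=2)
batches = list(PrefetchingBuffer(sampler, prefetch=3))
assert len(batches) == 5  # all 5 batches are yielded, none skipped
```

Without the second condition, iteration would stop as soon as `ran_out` became `True`, discarding the batches still sitting in the queue.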

pytorch-bot bot commented May 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2185

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 14 Unrelated Failures

As of commit 46ddf76 with merge base 672b50e:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 31, 2024
@JulianKu
Contributor Author

JulianKu commented May 31, 2024

@vmoens Somehow, for me the example you provided in #2182 seems to run fine (at least no apparent errors).
I set up a simple GitHub Codespace for that...

@JulianKu ➜ /workspaces/torchrl (fix/iter-replacement-prefetch) $ python sota-implementations/iql/iql_offline.py optim.gradient_steps=55 logger.backend=

/workspaces/torchrl/sota-implementations/iql/iql_offline.py:35: UserWarning: 
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="", config_name="offline_config")
/usr/local/python/3.10.13/lib/python3.10/site-packages/hydra/_internal/hydra.py:119: UserWarning: Future Hydra versions will no longer change working directory at job runtime by default.
See https://hydra.cc/docs/1.2/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
  ret = run_job(
[... the version_base UserWarning above is repeated five more times ...]
2024-05-31 08:05:06,428 [torchrl][INFO] Downloading dataset: http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium-v2.hdf5 to /tmp/tmp9rza7qdi/halfcheetah_medium-v2.hdf5
/home/codespace/.local/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [00:02<00:00, 18.35it/s]

@vmoens
Contributor

vmoens commented May 31, 2024

I still get the error.

Besides, I think that this won't really solve it, because self._sampler.ran_out will only be True once the sample is gathered, which in the case of prefetch>0 only occurs once the sample is accessible (so we can still have samples coming in after the sampler has run out).

Taking a step back, what we want to avoid is _sample being reached when the sampler has run out, not because it raises an exception but because it will just reset the ran_out attribute and restart sampling.
So, what I would suggest is this:

  • We create an auto_reset attribute in the sampler that is checked whenever a sample is drawn. If auto_reset=False and ran_out=True, the sampler returns a special value (e.g. StopIteration)
  • Whenever we get a StopIteration we tell the __iter__ to stop
  • Next time we call __iter__ we reset the sampler

I gave it an attempt here; we still have the issue that there can be race conditions between the workers, so I think we need to guard access to ran_out and auto_reset with a lock.
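The suggested design might look roughly like this. It is an illustrative sketch only: `auto_reset` and `ran_out` come from the discussion above, while the lock placement, the `StopIteration` sentinel handling, and all other names are assumptions, not the merged implementation.

```python
import threading


class Sampler:
    """Mock sampler without replacement with an auto_reset switch."""

    def __init__(self, n, batch_size, auto_reset=True):
        self._lock = threading.Lock()
        self._n = n
        self._batch_size = batch_size
        self.auto_reset = auto_reset
        self._reset()

    def _reset(self):
        self._pending = list(range(self._n))
        self.ran_out = False

    def sample(self):
        # Guard ran_out / auto_reset against concurrent prefetch workers.
        with self._lock:
            if self.ran_out:
                if not self.auto_reset:
                    # Special value telling __iter__ to stop instead of
                    # silently restarting the sampler.
                    return StopIteration
                self._reset()
            batch = self._pending[: self._batch_size]
            self._pending = self._pending[self._batch_size :]
            if not self._pending:
                self.ran_out = True
            return batch


class Buffer:
    def __init__(self, sampler):
        self._sampler = sampler

    def __iter__(self):
        # Reset the sampler each time iteration (re)starts.
        self._sampler._reset()
        while True:
            out = self._sampler.sample()
            if out is StopIteration:
                return
            yield out


buffer = Buffer(Sampler(6, 2, auto_reset=False))
assert list(buffer) == [[0, 1], [2, 3], [4, 5]]
assert list(buffer) == [[0, 1], [2, 3], [4, 5]]  # __iter__ resets the sampler
```

With `auto_reset=True` the sampler would instead restart transparently, which is the behavior this design avoids during explicit iteration.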

@vmoens
Contributor

vmoens commented Jun 11, 2024

I believe this PR now solves this problem.

@vmoens vmoens merged commit 1029f10 into pytorch:main Jun 11, 2024
28 of 43 checks passed
Labels: bug (Something isn't working), CLA Signed
3 participants