Commit

Merge remote-tracking branch 'origin/main' into move-examples
vmoens committed Mar 18, 2024
2 parents b9ad0e2 + 29d9a5b commit 9ffeea2
Showing 16 changed files with 830 additions and 63 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test-linux.yml
@@ -51,7 +51,7 @@ jobs:
   tests-gpu:
     strategy:
       matrix:
-        python_version: ["3.8"]
+        python_version: ["3.10"]
         cuda_arch_version: ["12.1"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -119,7 +119,7 @@ jobs:
   tests-optdeps:
     strategy:
       matrix:
-        python_version: ["3.9"] # "3.8", "3.9", "3.10", "3.11"
+        python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11"
         cuda_arch_version: ["12.1"] # "11.6", "11.7"
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -156,7 +156,7 @@ jobs:
   tests-stable-gpu:
     strategy:
       matrix:
-        python_version: ["3.8"] # "3.8", "3.9", "3.10", "3.11"
+        python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11"
         cuda_arch_version: ["11.8"] # "11.6", "11.7"
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
41 changes: 41 additions & 0 deletions docs/source/reference/data.rst
@@ -658,12 +658,53 @@ Here's an example:
the latest wheels are not published on PyPI. For OpenML, `scikit-learn <https://pypi.org/project/scikit-learn/>`_ and
`pandas <https://pypi.org/project/pandas>`_ are required.

Transforming datasets
~~~~~~~~~~~~~~~~~~~~~

In many instances, the raw data isn't going to be used as-is.
The natural solution could be to pass a :class:`~torchrl.envs.transforms.Transform`
instance to the dataset constructor and modify the sample on-the-fly. This
works, but it incurs the extra runtime cost of the transform at every
sampling call. If the transformations can be (at least partly) pre-applied
to the dataset, considerable disk space and sampling-time overhead can be
saved. To do this, the
:meth:`~torchrl.data.datasets.BaseDatasetExperienceReplay.preprocess` method
can be used. This method runs a per-sample preprocessing pipeline on each
element of the dataset, and replaces the existing dataset with its
transformed version.
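
As a point of comparison, here is a minimal sketch of the on-the-fly
alternative. It assumes the dataset constructor forwards a ``transform``
keyword argument to the underlying replay buffer; the
:class:`~torchrl.envs.transforms.ObservationNorm` parameters below
(``loc``/``scale``) are purely illustrative:

>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.envs.transforms import ObservationNorm
>>>
>>> # the transform is assumed to be forwarded to the replay buffer and is
>>> # applied to every sampled batch, so its cost is paid at each call to
>>> # ``sample()`` rather than once at preprocessing time
>>> dataset = RobosetExperienceReplay(
...     "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4",
...     batch_size=32,
...     transform=ObservationNorm(loc=0.0, scale=1.0, in_keys=["observation"]),
... )
>>> sample = dataset.sample()  # normalization happens here, on-the-fly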

Once transformed, re-creating the same dataset will produce another object with
the same transformed storage (unless ``download="force"`` is used):

>>> import os
>>>
>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force"
... )
>>>
>>> def func(data):
... return data.set("obs_norm", data.get("observation").norm(dim=-1))
...
>>> dataset.preprocess(
... func,
... num_workers=max(1, os.cpu_count() - 2),
... num_chunks=1000,
... mp_start_method="fork",
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()
>>> # re-creating the dataset gives us the transformed version back.
>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()


.. currentmodule:: torchrl.data.datasets

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BaseDatasetExperienceReplay
AtariDQNExperienceReplay
D4RLExperienceReplay
GenDGRLExperienceReplay
