Skip to content

Commit

Permalink
[Feature] Preproc for datasets (pytorch#1989)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 18, 2024
1 parent 87f3437 commit 29d9a5b
Show file tree
Hide file tree
Showing 14 changed files with 824 additions and 60 deletions.
41 changes: 41 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -658,12 +658,53 @@ Here's an example:
the latest wheels are not published on PyPI. For OpenML, `scikit-learn <https://pypi.org/project/scikit-learn/>`_ and
`pandas <https://pypi.org/project/pandas>`_ are required.

Transforming datasets
~~~~~~~~~~~~~~~~~~~~~

In many instances, the raw data isn't going to be used as-is.
The natural solution could be to pass a :class:`~torchrl.envs.transforms.Transform`
instance to the dataset constructor and modify the sample on-the-fly. This will
work but it will incur an extra runtime for the transform.
If the transformations can be (at least a part) pre-applied to the dataset,
a conisderable disk space and some incurred overhead at sampling time can be
saved. To do this, the
:meth:`~torchrl.data.datasets.BaseDatasetExperienceReplay.preprocess` can be
used. This method will run a per-sample preprocessing pipeline on each element
of the dataset, and replace the existing dataset by its transformed version.

Once transformed, re-creating the same dataset will produce another object with
the same transformed storage (unless ``download="force"`` is being used):

>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32, download="force"
... )
>>>
>>> def func(data):
... return data.set("obs_norm", data.get("observation").norm(dim=-1))
...
>>> dataset.preprocess(
... func,
... num_workers=max(1, os.cpu_count() - 2),
... num_chunks=1000,
... mp_start_method="fork",
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()
>>> # re-recreating the dataset gives us the transformed version back.
>>> dataset = RobosetExperienceReplay(
... "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32
... )
>>> sample = dataset.sample()
>>> assert "obs_norm" in sample.keys()


.. currentmodule:: torchrl.data.datasets

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BaseDatasetExperienceReplay
AtariDQNExperienceReplay
D4RLExperienceReplay
GenDGRLExperienceReplay
Expand Down
Loading

0 comments on commit 29d9a5b

Please sign in to comment.