From 81945657a0b8677f7258ff0cbe2f648617501a75 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 9 Jan 2024 20:51:26 +0000
Subject: [PATCH] [Feature] Immutable writer for datasets (#1781)

---
 examples/cql/utils.py                  |  2 +-
 examples/decision_transformer/utils.py | 39 +++++++++------
 examples/iql/utils.py                  |  2 +-
 test/test_libs.py                      | 27 +++++-----
 torchrl/data/datasets/d4rl.py          | 40 +++++++--------
 torchrl/data/datasets/minari_data.py   | 11 ++++-
 torchrl/data/datasets/openml.py        | 54 ++++++++++++++------
 torchrl/data/datasets/openx.py         | 68 ++++++++++++++++++++++----
 torchrl/data/datasets/roboset.py       |  6 ++-
 torchrl/data/datasets/vd4rl.py         |  6 ++-
 10 files changed, 173 insertions(+), 82 deletions(-)

diff --git a/examples/cql/utils.py b/examples/cql/utils.py
index f14d3784577..828b370a559 100644
--- a/examples/cql/utils.py
+++ b/examples/cql/utils.py
@@ -152,7 +152,7 @@ def make_replay_buffer(
 
 def make_offline_replay_buffer(rb_cfg):
     data = D4RLExperienceReplay(
-        rb_cfg.dataset,
+        dataset_id=rb_cfg.dataset,
         split_trajs=False,
         batch_size=rb_cfg.batch_size,
         sampler=SamplerWithoutReplacement(drop_last=False),
diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py
index 5232901a114..7cb5b52b6ea 100644
--- a/examples/decision_transformer/utils.py
+++ b/examples/decision_transformer/utils.py
@@ -10,7 +10,12 @@
 from tensordict.nn import TensorDictModule
 
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import (
+    LazyMemmapStorage,
+    RoundRobinWriter,
+    TensorDictReplayBuffer,
+    TensorStorage,
+)
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
 from torchrl.data.replay_buffers import RandomSampler
 from torchrl.envs import (
@@ -234,33 +239,35 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
         exclude,
     )
     data = D4RLExperienceReplay(
-        rb_cfg.dataset,
+        dataset_id=rb_cfg.dataset,
         split_trajs=True,
         batch_size=rb_cfg.batch_size,
         sampler=RandomSampler(),  # SamplerWithoutReplacement(drop_last=False),
-        transform=transforms,
+        transform=None,
         use_truncated_as_done=True,
         direct_download=True,
         prefetch=4,
+        writer=RoundRobinWriter(),
     )
-    loc = (
-        data._storage._storage.get(("_data", "observation"))
-        .flatten(0, -2)
-        .mean(axis=0)
-        .float()
-    )
-    std = (
-        data._storage._storage.get(("_data", "observation"))
-        .flatten(0, -2)
-        .std(axis=0)
-        .float()
-    )
+
+    # since we're not extending the data, adding keys can only be done via
+    # the creation of a new storage
+    data_memmap = data[:]
+    with data_memmap.unlock_():
+        data_memmap = r2g.inv(data_memmap)
+        data._storage = TensorStorage(data_memmap)
+
+    loc = data[:]["observation"].flatten(0, -2).mean(axis=0).float()
+    std = data[:]["observation"].flatten(0, -2).std(axis=0).float()
+
     obsnorm = ObservationNorm(
         loc=loc,
         scale=std,
         in_keys=["observation_cat", ("next", "observation_cat")],
         standard_normal=True,
     )
+    for t in transforms:
+        data.append_transform(t)
     data.append_transform(obsnorm)
     return data, loc, std
 
@@ -300,7 +307,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001):
         batch_size=rb_cfg.batch_size,
     )
     # init buffer with offline data
-    offline_data = offline_buffer.sample(100000)
+    offline_data = offline_buffer[:100000]
     offline_data.del_("index")
     replay_buffer.extend(offline_data.clone().detach().to_tensordict())
     # add transforms after offline data extension to not trigger reward-to-go calculation
diff --git a/examples/iql/utils.py b/examples/iql/utils.py
index 1dff1c7bd34..e69cf45a0cd 100644
--- a/examples/iql/utils.py
+++ b/examples/iql/utils.py
@@ -133,7 +133,7 @@ def make_replay_buffer(
 
 def make_offline_replay_buffer(rb_cfg):
     data = D4RLExperienceReplay(
-        name=rb_cfg.dataset,
+        dataset_id=rb_cfg.dataset,
         split_trajs=False,
         batch_size=rb_cfg.batch_size,
         sampler=SamplerWithoutReplacement(drop_last=False),
diff --git a/test/test_libs.py b/test/test_libs.py
index 4c6adf19180..51092ca6618 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -1873,28 +1873,25 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir
             root=root3,
         )
         if not use_truncated_as_done:
-            keys = set(data_from_env._storage._storage.keys(True, True))
-            keys = keys.intersection(data_true._storage._storage.keys(True, True))
-            assert (
-                data_true._storage._storage.shape
-                == data_from_env._storage._storage.shape
-            )
+            keys = set(data_from_env[:].keys(True, True))
+            keys = keys.intersection(data_true[:].keys(True, True))
+            assert data_true[:].shape == data_from_env[:].shape
             # for some reason, qlearning_dataset overwrites the next obs that is contained in the buffer,
             # resulting in tiny changes in the value contained for that key. Over 99.99% of the values
             # match, but the test still fails because of this.
             # We exclude that entry from the comparison.
-            keys.discard(("_data", "next", "observation"))
+            keys.discard(("next", "observation"))
             assert_allclose_td(
-                data_true._storage._storage.select(*keys),
-                data_from_env._storage._storage.select(*keys),
+                data_true[:].select(*keys),
+                data_from_env[:].select(*keys),
             )
         else:
-            leaf_names = data_from_env._storage._storage.keys(True)
+            leaf_names = data_from_env[:].keys(True)
             leaf_names = [
                 name[-1] if isinstance(name, tuple) else name for name in leaf_names
             ]
             assert "truncated" in leaf_names
-            leaf_names = data_true._storage._storage.keys(True)
+            leaf_names = data_true[:].keys(True)
             leaf_names = [
                 name[-1] if isinstance(name, tuple) else name for name in leaf_names
             ]
@@ -1925,12 +1922,12 @@ def test_direct_download(self, task, tmpdir):
             download="force",
             root=root2,
         )
-        keys = set(data_direct._storage._storage.keys(True, True))
-        keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
+        keys = set(data_direct[:].keys(True, True))
+        keys = keys.intersection(data_d4rl[:].keys(True, True))
         assert len(keys)
         assert_allclose_td(
-            data_direct._storage._storage.select(*keys).apply(lambda t: t.float()),
-            data_d4rl._storage._storage.select(*keys).apply(lambda t: t.float()),
+            data_direct[:].select(*keys).apply(lambda t: t.float()),
+            data_d4rl[:].select(*keys).apply(lambda t: t.float()),
         )
 
     @pytest.mark.parametrize(
diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py
index 38fce4a6b7c..a7f4fb0b198 100644
--- a/torchrl/data/datasets/d4rl.py
+++ b/torchrl/data/datasets/d4rl.py
@@ -24,10 +24,10 @@
 from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS
 from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import LazyMemmapStorage, TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
 
 class D4RLExperienceReplay(TensorDictReplayBuffer):
@@ -45,7 +45,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
     the ``("next", "observation")`` of ``"done"`` states are zeroed.
``("next", "observation")`` of ``"done"`` states are zeroed. Args: - name (str): the name of the D4RL env to get the data from. + dataset_id (str): the dataset_id of the D4RL env to get the data from. batch_size (int): the batch size to use during sampling. sampler (Sampler, optional): the sampler to be used. If none is provided a default RandomSampler() will be used. @@ -135,7 +135,7 @@ def _import_d4rl(cls): def __init__( self, - name, + dataset_id, batch_size: int, sampler: Sampler | None = None, writer: Writer | None = None, @@ -155,9 +155,8 @@ def __init__( self.use_truncated_as_done = use_truncated_as_done if root is None: root = _get_root_dir("d4rl") - self.root = root - self.name = name - dataset = None + self.root = Path(root) + self.dataset_id = dataset_id if not from_env and direct_download is None: self._import_d4rl() @@ -200,7 +199,7 @@ def __init__( raise ImportError("Could not import d4rl") from self.D4RL_ERR if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) + dataset = self._get_dataset_from_env(dataset_id, env_kwargs) else: if self.use_truncated_as_done: warnings.warn( @@ -210,13 +209,13 @@ def __init__( "can be absent from the static dataset." ) env_kwargs.update({"terminate_on_end": terminate_on_end}) - dataset = self._get_dataset_direct(name, env_kwargs) + dataset = self._get_dataset_direct(dataset_id, env_kwargs) else: if terminate_on_end is False: raise ValueError( "Using terminate_on_end=False is not compatible with direct_download=True." ) - dataset = self._get_dataset_direct_download(name, env_kwargs) + dataset = self._get_dataset_direct_download(dataset_id, env_kwargs) # Fill unknown next states with 0 dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 @@ -224,16 +223,16 @@ def __init__( dataset = split_trajectories(dataset) dataset["next", "done"][:, -1] = True - storage = LazyMemmapStorage( - dataset.shape[0], scratch_dir=Path(self.root) / name - ) + storage = TensorStorage(dataset.memmap(self._dataset_path)) elif self._is_downloaded(): - storage = TensorStorage(TensorDict.load_memmap(Path(self.root) / name)) + storage = TensorStorage(TensorDict.load_memmap(self._dataset_path)) else: raise RuntimeError( - f"The dataset could not be found in {Path(self.root) / name}." + f"The dataset could not be found in {self._dataset_path}." 
             )
+        if writer is None:
+            writer = ImmutableDatasetWriter()
 
         super().__init__(
             batch_size=batch_size,
             storage=storage,
@@ -244,12 +243,13 @@ def __init__(
             prefetch=prefetch,
             transform=transform,
         )
-        if dataset is not None:
-            # if dataset has just been downloaded
-            self.extend(dataset)
+
+    @property
+    def _dataset_path(self):
+        return Path(self.root) / self.dataset_id
 
     def _is_downloaded(self):
-        return os.path.exists(Path(self.root) / self.name)
+        return os.path.exists(self._dataset_path)
 
     def _get_dataset_direct_download(self, name, env_kwargs):
         """Directly download and use a D4RL dataset."""
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 754d5da9865..9566c1eff10 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -24,7 +24,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
     CompositeSpec,
@@ -34,6 +34,7 @@
 from torchrl.envs.utils import _classproperty
 
 _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
+_has_minari = importlib.util.find_spec("minari", None) is not None
 
 _NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
 _NAME_MATCH["observations"] = "observation"
@@ -193,6 +194,10 @@ def __init__(
         else:
             storage = self._load()
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,
@@ -206,6 +211,8 @@ def __init__(
 
     @_classproperty
     def available_datasets(self):
+        if not _has_minari:
+            raise ImportError("minari library not found.")
         import minari
 
         return minari.list_remote_datasets().keys()
@@ -228,6 +235,8 @@ def metadata_path(self):
         return Path(self.root) / self.dataset_id / "env_metadata.json"
 
     def _download_and_preproc(self):
+        if not _has_minari:
+            raise ImportError("minari library not found.")
         import minari
 
         if _has_tqdm:
diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py
index cfb26b0c2f7..a9c359d9fe6 100644
--- a/torchrl/data/datasets/openml.py
+++ b/torchrl/data/datasets/openml.py
@@ -2,16 +2,23 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
-from typing import Callable, Optional
+import os
+from pathlib import Path
+from typing import Callable
 
 import numpy as np
 from tensordict.tensordict import TensorDict
 
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
-from torchrl.data.replay_buffers.storages import LazyMemmapStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.datasets.utils import _get_root_dir
+from torchrl.data.replay_buffers import (
+    Sampler,
+    SamplerWithoutReplacement,
+    TensorDictReplayBuffer,
+    TensorStorage,
+    Writer,
+)
 
 
 class OpenMLExperienceReplay(TensorDictReplayBuffer):
@@ -52,23 +59,32 @@ def __init__(
         self,
         name: str,
         batch_size: int,
-        sampler: Optional[Sampler] = None,
-        writer: Optional[Writer] = None,
-        collate_fn: Optional[Callable] = None,
+        root: Path | None = None,
+        sampler: Sampler | None = None,
+        writer: Writer | None = None,
+        collate_fn: Callable | None = None,
         pin_memory: bool = False,
-        prefetch: Optional[int] = None,
-        transform: Optional["Transform"] = None,  # noqa-F821
+        prefetch: int | None = None,
+        transform: "Transform" | None = None,  # noqa-F821
     ):
         if sampler is None:
             sampler = SamplerWithoutReplacement()
+        if root is None:
+            root = _get_root_dir("openml")
+        self.root = Path(root)
+        self.dataset_id = name
+
+        if not self._is_downloaded():
+            dataset = self._get_data(
+                name,
+            )
+            storage = TensorStorage(dataset.memmap(self._dataset_path))
+        else:
+            dataset = TensorDict.load_memmap(self._dataset_path)
+            storage = TensorStorage(dataset)
 
-        dataset = self._get_data(
-            name,
-        )
         self.max_outcome_val = dataset["y"].max().item()
-
-        storage = LazyMemmapStorage(dataset.shape[0])
 
         super().__init__(
             batch_size=batch_size,
             storage=storage,
@@ -79,7 +95,13 @@ def __init__(
             prefetch=prefetch,
             transform=transform,
         )
-        self.extend(dataset)
+
+    @property
+    def _dataset_path(self):
+        return self.root / self.dataset_id
+
+    def _is_downloaded(self):
+        return os.path.exists(self._dataset_path)
 
     @classmethod
     def _get_data(cls, dataset_name):
diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py
index aa78a92ff16..d964531e4e8 100644
--- a/torchrl/data/datasets/openx.py
+++ b/torchrl/data/datasets/openx.py
@@ -6,10 +6,11 @@
 import importlib.util
 import io
+import json
 import os
 import tempfile
 
 from pathlib import Path
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Dict, Tuple
 
 import torch
@@ -35,6 +36,10 @@ class OpenXExperienceReplay(TensorDictReplayBuffer):
     spanning 22 robot embodiments, collected through a collaboration between
     21 institutions, demonstrating 527 skills (160266 tasks).
 
+    Website: https://robotics-transformer-x.github.io/
+    GitHub: https://github.com/google-deepmind/open_x_embodiment
+    Paper: https://arxiv.org/abs/2310.08864
+
     .. note::
         Non-tensor data will be written in the tensordict data using the
         :class:`~tensordict.tensorclass.NonTensorData` primitive.
@@ -66,7 +71,7 @@ class for more information on how to interact with non-tensor data
             sampler is set to ``False`` if they wish to enjoy the two different
             behaviours (shuffled and not) within the same code base.
 
-        num_slice (int, optional): the number of slices in a batch. This
+        num_slices (int, optional): the number of slices in a batch. This
Once collected, the batch is presented as a concatenation of sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`. @@ -527,22 +532,31 @@ def __init__( slice_len=None, pad=None, ): + self.shuffle = shuffle + self.dataset_id = dataset_id + self.repo = repo + self.split = split + self._init() + self.base_path = base_path + self.truncate = truncate + self.num_slices = num_slices + self.slice_len = slice_len + self.pad = pad + + def _init(self): if not _has_datasets: raise ImportError( - f"the `datasets` library is required for the dataset {dataset_id}." + f"the `datasets` library is required for the dataset {self.dataset_id}." ) import datasets - dataset = datasets.load_dataset(repo, dataset_id, streaming=True, split=split) - if shuffle: + dataset = datasets.load_dataset( + self.repo, self.dataset_id, streaming=True, split=self.split + ) + if self.shuffle: dataset = dataset.shuffle() self.dataset = dataset self.dataset_iter = iter(dataset) - self.base_path = base_path - self.truncate = truncate - self.num_slices = num_slices - self.slice_len = slice_len - self.pad = pad def __iter__(self): episode = 0 @@ -607,6 +621,34 @@ def get(self, index: int) -> Any: return data[: index.stop] return data + def dumps(self, path): + path = Path(path) + state_dict = self.state_dict() + json.dump(state_dict, path / "state_dict.json") + + def state_dict(self) -> Dict[str, Any]: + return { + "repo": self.repo, + "split": self.split, + "dataset_id": self.dataset_id, + "shuffle": self.shuffle, + "base_path": self.base_path, + "truncated": self.truncate, + "num_slices": self.num_slices, + "slice_len": self.slice_len, + "pad": self.pad, + } + + def loads(self, path): + path = Path(path) + state_dict = json.load(path / "state_dict.json") + self.load_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + for key, val in state_dict.items(): + setattr(self, key, val) + self._init() + def __len__(self): raise RuntimeError( f"{type(self)} does not have a length. Use a downloaded dataset to " @@ -662,6 +704,12 @@ def dumps(self, path): def loads(self, path): ... + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
+
 
 OPENX_KEY_MAP = {
     "is_first": "is_init",
diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py
index 6e9a9bb23f7..62e4e41e982 100644
--- a/torchrl/data/datasets/roboset.py
+++ b/torchrl/data/datasets/roboset.py
@@ -21,7 +21,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
 _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
 _has_h5py = importlib.util.find_spec("h5py", None) is not None
@@ -190,6 +190,10 @@ def __init__(
         else:
             storage = self._load()
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,
diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py
index 815a00ca687..b3313ef8812 100644
--- a/torchrl/data/datasets/vd4rl.py
+++ b/torchrl/data/datasets/vd4rl.py
@@ -25,7 +25,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.envs.transforms import Compose, Resize, ToTensorImage
 from torchrl.envs.utils import _classproperty
 
@@ -220,6 +220,10 @@ def __init__(
                 transform, Resize(image_size, in_keys=["pixels", ("next", "pixels")])
             )
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,
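
Usage sketch (not part of the diff): the snippet below illustrates the behaviour
this patch introduces, namely that dataset replay buffers now default to an
ImmutableDatasetWriter when no writer is passed. D4RLExperienceReplay,
RoundRobinWriter and the dataset_id/writer arguments are taken from the patch
itself; the dataset id "hopper-medium-v2" and the exact exception type raised
on a rejected write are assumptions for illustration only.

    from torchrl.data.datasets.d4rl import D4RLExperienceReplay
    from torchrl.data.replay_buffers.writers import RoundRobinWriter

    # With no explicit `writer`, the buffer defaults to an immutable writer:
    # the memory-mapped dataset on disk is read-only through the buffer.
    data = D4RLExperienceReplay(
        dataset_id="hopper-medium-v2",  # hypothetical id, for illustration
        batch_size=256,
        direct_download=True,
    )
    batch = data.sample()  # sampling is unchanged

    try:
        data.extend(batch)  # the immutable writer should refuse the write
    except Exception as err:  # exact exception type is an assumption
        print(f"dataset is immutable: {err}")

    # To keep the buffer writable (as examples/decision_transformer/utils.py
    # does), pass a mutable writer explicitly at construction time:
    writable = D4RLExperienceReplay(
        dataset_id="hopper-medium-v2",
        batch_size=256,
        direct_download=True,
        writer=RoundRobinWriter(),
    )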