Skip to content

Commit

Permalink
[Feature] Immutable writer for datasets (pytorch#1781)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 9, 2024
1 parent 58571f0 commit 8194565
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 82 deletions.
2 changes: 1 addition & 1 deletion examples/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_replay_buffer(

def make_offline_replay_buffer(rb_cfg):
data = D4RLExperienceReplay(
rb_cfg.dataset,
dataset_id=rb_cfg.dataset,
split_trajs=False,
batch_size=rb_cfg.batch_size,
sampler=SamplerWithoutReplacement(drop_last=False),
Expand Down
39 changes: 23 additions & 16 deletions examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from tensordict.nn import TensorDictModule

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import (
LazyMemmapStorage,
RoundRobinWriter,
TensorDictReplayBuffer,
TensorStorage,
)
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import RandomSampler
from torchrl.envs import (
Expand Down Expand Up @@ -234,33 +239,35 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
exclude,
)
data = D4RLExperienceReplay(
rb_cfg.dataset,
dataset_id=rb_cfg.dataset,
split_trajs=True,
batch_size=rb_cfg.batch_size,
sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False),
transform=transforms,
transform=None,
use_truncated_as_done=True,
direct_download=True,
prefetch=4,
writer=RoundRobinWriter(),
)
loc = (
data._storage._storage.get(("_data", "observation"))
.flatten(0, -2)
.mean(axis=0)
.float()
)
std = (
data._storage._storage.get(("_data", "observation"))
.flatten(0, -2)
.std(axis=0)
.float()
)

# since we're not extending the data, adding keys can only be done via
# the creation of a new storage
data_memmap = data[:]
with data_memmap.unlock_():
data_memmap = r2g.inv(data_memmap)
data._storage = TensorStorage(data_memmap)

loc = data[:]["observation"].flatten(0, -2).mean(axis=0).float()
std = data[:]["observation"].flatten(0, -2).std(axis=0).float()

obsnorm = ObservationNorm(
loc=loc,
scale=std,
in_keys=["observation_cat", ("next", "observation_cat")],
standard_normal=True,
)
for t in transforms:
data.append_transform(t)
data.append_transform(obsnorm)
return data, loc, std

Expand Down Expand Up @@ -300,7 +307,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001):
batch_size=rb_cfg.batch_size,
)
# init buffer with offline data
offline_data = offline_buffer.sample(100000)
offline_data = offline_buffer[:100000]
offline_data.del_("index")
replay_buffer.extend(offline_data.clone().detach().to_tensordict())
# add transforms after offline data extension to not trigger reward-to-go calculation
Expand Down
2 changes: 1 addition & 1 deletion examples/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def make_replay_buffer(

def make_offline_replay_buffer(rb_cfg):
data = D4RLExperienceReplay(
name=rb_cfg.dataset,
dataset_id=rb_cfg.dataset,
split_trajs=False,
batch_size=rb_cfg.batch_size,
sampler=SamplerWithoutReplacement(drop_last=False),
Expand Down
27 changes: 12 additions & 15 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,28 +1873,25 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir
root=root3,
)
if not use_truncated_as_done:
keys = set(data_from_env._storage._storage.keys(True, True))
keys = keys.intersection(data_true._storage._storage.keys(True, True))
assert (
data_true._storage._storage.shape
== data_from_env._storage._storage.shape
)
keys = set(data_from_env[:].keys(True, True))
keys = keys.intersection(data_true[:].keys(True, True))
assert data_true[:].shape == data_from_env[:].shape
# for some reason, qlearning_dataset overwrites the next obs that is contained in the buffer,
# resulting in tiny changes in the value contained for that key. Over 99.99% of the values
# match, but the test still fails because of this.
# We exclude that entry from the comparison.
keys.discard(("_data", "next", "observation"))
keys.discard(("next", "observation"))
assert_allclose_td(
data_true._storage._storage.select(*keys),
data_from_env._storage._storage.select(*keys),
data_true[:].select(*keys),
data_from_env[:].select(*keys),
)
else:
leaf_names = data_from_env._storage._storage.keys(True)
leaf_names = data_from_env[:].keys(True)
leaf_names = [
name[-1] if isinstance(name, tuple) else name for name in leaf_names
]
assert "truncated" in leaf_names
leaf_names = data_true._storage._storage.keys(True)
leaf_names = data_true[:].keys(True)
leaf_names = [
name[-1] if isinstance(name, tuple) else name for name in leaf_names
]
Expand Down Expand Up @@ -1925,12 +1922,12 @@ def test_direct_download(self, task, tmpdir):
download="force",
root=root2,
)
keys = set(data_direct._storage._storage.keys(True, True))
keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
keys = set(data_direct[:].keys(True, True))
keys = keys.intersection(data_d4rl[:].keys(True, True))
assert len(keys)
assert_allclose_td(
data_direct._storage._storage.select(*keys).apply(lambda t: t.float()),
data_d4rl._storage._storage.select(*keys).apply(lambda t: t.float()),
data_direct[:].select(*keys).apply(lambda t: t.float()),
data_d4rl[:].select(*keys).apply(lambda t: t.float()),
)

@pytest.mark.parametrize(
Expand Down
40 changes: 20 additions & 20 deletions torchrl/data/datasets/d4rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS

from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage, TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


class D4RLExperienceReplay(TensorDictReplayBuffer):
Expand All @@ -45,7 +45,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
the ``("next", "observation")`` of ``"done"`` states are zeroed.
Args:
name (str): the name of the D4RL env to get the data from.
dataset_id (str): the dataset_id of the D4RL env to get the data from.
batch_size (int): the batch size to use during sampling.
sampler (Sampler, optional): the sampler to be used. If none is provided
a default RandomSampler() will be used.
Expand Down Expand Up @@ -135,7 +135,7 @@ def _import_d4rl(cls):

def __init__(
self,
name,
dataset_id,
batch_size: int,
sampler: Sampler | None = None,
writer: Writer | None = None,
Expand All @@ -155,9 +155,8 @@ def __init__(
self.use_truncated_as_done = use_truncated_as_done
if root is None:
root = _get_root_dir("d4rl")
self.root = root
self.name = name
dataset = None
self.root = Path(root)
self.dataset_id = dataset_id

if not from_env and direct_download is None:
self._import_d4rl()
Expand Down Expand Up @@ -200,7 +199,7 @@ def __init__(
raise ImportError("Could not import d4rl") from self.D4RL_ERR

if from_env:
dataset = self._get_dataset_from_env(name, env_kwargs)
dataset = self._get_dataset_from_env(dataset_id, env_kwargs)
else:
if self.use_truncated_as_done:
warnings.warn(
Expand All @@ -210,30 +209,30 @@ def __init__(
"can be absent from the static dataset."
)
env_kwargs.update({"terminate_on_end": terminate_on_end})
dataset = self._get_dataset_direct(name, env_kwargs)
dataset = self._get_dataset_direct(dataset_id, env_kwargs)
else:
if terminate_on_end is False:
raise ValueError(
"Using terminate_on_end=False is not compatible with direct_download=True."
)
dataset = self._get_dataset_direct_download(name, env_kwargs)
dataset = self._get_dataset_direct_download(dataset_id, env_kwargs)
# Fill unknown next states with 0
dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0

if split_trajs:
dataset = split_trajectories(dataset)
dataset["next", "done"][:, -1] = True

storage = LazyMemmapStorage(
dataset.shape[0], scratch_dir=Path(self.root) / name
)
storage = TensorStorage(dataset.memmap(self._dataset_path))
elif self._is_downloaded():
storage = TensorStorage(TensorDict.load_memmap(Path(self.root) / name))
storage = TensorStorage(TensorDict.load_memmap(self._dataset_path))
else:
raise RuntimeError(
f"The dataset could not be found in {Path(self.root) / name}."
f"The dataset could not be found in {self._dataset_path}."
)

if writer is None:
writer = ImmutableDatasetWriter()
super().__init__(
batch_size=batch_size,
storage=storage,
Expand All @@ -244,12 +243,13 @@ def __init__(
prefetch=prefetch,
transform=transform,
)
if dataset is not None:
# if dataset has just been downloaded
self.extend(dataset)

@property
def _dataset_path(self):
return Path(self.root) / self.dataset_id

def _is_downloaded(self):
return os.path.exists(Path(self.root) / self.name)
return os.path.exists(self._dataset_path)

def _get_dataset_direct_download(self, name, env_kwargs):
"""Directly download and use a D4RL dataset."""
Expand Down
11 changes: 10 additions & 1 deletion torchrl/data/datasets/minari_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
Expand All @@ -34,6 +34,7 @@
from torchrl.envs.utils import _classproperty

_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
_has_minari = importlib.util.find_spec("minari", None) is not None

_NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
_NAME_MATCH["observations"] = "observation"
Expand Down Expand Up @@ -193,6 +194,10 @@ def __init__(
else:
storage = self._load()
storage = TensorStorage(storage)

if writer is None:
writer = ImmutableDatasetWriter()

super().__init__(
storage=storage,
sampler=sampler,
Expand All @@ -206,6 +211,8 @@ def __init__(

@_classproperty
def available_datasets(self):
if not _has_minari:
raise ImportError("minari library not found.")
import minari

return minari.list_remote_datasets().keys()
Expand All @@ -228,6 +235,8 @@ def metadata_path(self):
return Path(self.root) / self.dataset_id / "env_metadata.json"

def _download_and_preproc(self):
if not _has_minari:
raise ImportError("minari library not found.")
import minari

if _has_tqdm:
Expand Down
54 changes: 38 additions & 16 deletions torchrl/data/datasets/openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Callable, Optional
import os
from pathlib import Path
from typing import Callable

import numpy as np
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement

from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers.writers import Writer
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers import (
Sampler,
SamplerWithoutReplacement,
TensorDictReplayBuffer,
TensorStorage,
Writer,
)


class OpenMLExperienceReplay(TensorDictReplayBuffer):
Expand Down Expand Up @@ -52,23 +59,32 @@ def __init__(
self,
name: str,
batch_size: int,
sampler: Optional[Sampler] = None,
writer: Optional[Writer] = None,
collate_fn: Optional[Callable] = None,
root: Path | None = None,
sampler: Sampler | None = None,
writer: Writer | None = None,
collate_fn: Callable | None = None,
pin_memory: bool = False,
prefetch: Optional[int] = None,
transform: Optional["Transform"] = None, # noqa-F821
prefetch: int | None = None,
transform: "Transform" | None = None, # noqa-F821
):

if sampler is None:
sampler = SamplerWithoutReplacement()
if root is None:
root = _get_root_dir("openml")
self.root = Path(root)
self.dataset_id = name

if not self._is_downloaded():
dataset = self._get_data(
name,
)
storage = TensorStorage(dataset.memmap(self._dataset_path))
else:
dataset = TensorDict.load_memmap(self._dataset_path)
storage = TensorStorage(dataset)

dataset = self._get_data(
name,
)
self.max_outcome_val = dataset["y"].max().item()

storage = LazyMemmapStorage(dataset.shape[0])
super().__init__(
batch_size=batch_size,
storage=storage,
Expand All @@ -79,7 +95,13 @@ def __init__(
prefetch=prefetch,
transform=transform,
)
self.extend(dataset)

@property
def _dataset_path(self):
return self.root / self.dataset_id

def _is_downloaded(self):
return os.path.exists(self._dataset_path)

@classmethod
def _get_data(cls, dataset_name):
Expand Down
Loading

0 comments on commit 8194565

Please sign in to comment.