[Feature] Immutable writer for datasets #1781

Merged: 16 commits, Jan 9, 2024

Commit: init
vmoens committed Jan 8, 2024
commit ea892cc7f7b4a8821e013234829f0f162e79269a
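The change applied in each dataset file below is the same: when no `writer` is supplied, default to an `ImmutableDatasetWriter` so the pre-built dataset storage is not modified through the replay buffer. A minimal, self-contained sketch of that pattern follows; `ReadOnlyWriter` and `DatasetBuffer` are hypothetical stand-ins written for illustration, not the actual torchrl classes, and the assumption that the immutable writer rejects writes is inferred from its name rather than stated in this diff.

```python
# Illustrative sketch only: "ReadOnlyWriter" and "DatasetBuffer" are hypothetical
# stand-ins; only the "if writer is None: writer = ImmutableDatasetWriter()"
# default comes from the diff below.


class ReadOnlyWriter:
    """Writer that refuses any attempt to modify a read-only dataset."""

    def add(self, data):
        raise RuntimeError("This dataset is immutable; writes are not allowed.")

    def extend(self, data):
        raise RuntimeError("This dataset is immutable; writes are not allowed.")


class DatasetBuffer:
    def __init__(self, storage, writer=None):
        # Mirror the commit: fall back to a read-only writer when none is given.
        if writer is None:
            writer = ReadOnlyWriter()
        self.storage = storage
        self.writer = writer

    def extend(self, data):
        # All writes go through the writer, so the default buffer is read-only.
        return self.writer.extend(data)


buffer = DatasetBuffer(storage=[1, 2, 3])
try:
    buffer.extend([4, 5])
except RuntimeError as err:
    print(err)  # -> This dataset is immutable; writes are not allowed.
```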
4 changes: 3 additions & 1 deletion torchrl/data/datasets/d4rl.py
@@ -27,7 +27,7 @@
 from torchrl.data.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage, TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


 class D4RLExperienceReplay(TensorDictReplayBuffer):
@@ -234,6 +234,8 @@ def __init__(
                 f"The dataset could not be found in {Path(self.root) / name}."
             )

+        if writer is None:
+            writer = ImmutableDatasetWriter()
         super().__init__(
             batch_size=batch_size,
             storage=storage,
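In practice, a `D4RLExperienceReplay` built without an explicit `writer` now becomes read-only by default. A hedged usage sketch, assuming the existing constructor that takes a dataset name and `batch_size`, and assuming `ImmutableDatasetWriter` raises when a write is attempted:

```python
# Hedged usage sketch: the dataset name, the download step and the exact error
# raised are assumptions; only the "writer defaults to ImmutableDatasetWriter"
# behaviour comes from the diff above.
from torchrl.data.datasets.d4rl import D4RLExperienceReplay

data = D4RLExperienceReplay("maze2d-umaze-v1", batch_size=32)  # no writer passed
batch = data.sample()  # sampling is unaffected

try:
    data.extend(batch)  # expected to be rejected by the default immutable writer
except Exception as err:
    print(f"write rejected: {err}")
```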
11 changes: 10 additions & 1 deletion torchrl/data/datasets/minari_data.py
@@ -24,7 +24,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
     CompositeSpec,
@@ -34,6 +34,7 @@
 from torchrl.envs.utils import _classproperty

 _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
+_has_minari = importlib.util.find_spec("minari", None) is not None

 _NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
 _NAME_MATCH["observations"] = "observation"
@@ -193,6 +194,10 @@ def __init__(
         else:
             storage = self._load()
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,
@@ -206,6 +211,8 @@

     @_classproperty
     def available_datasets(self):
+        if not _has_minari:
+            raise ImportError("minari library not found.")
         import minari

         return minari.list_remote_datasets().keys()
@@ -228,6 +235,8 @@ def metadata_path(self):
         return Path(self.root) / self.dataset_id / "env_metadata.json"

     def _download_and_preproc(self):
+        if not _has_minari:
+            raise ImportError("minari library not found.")
         import minari

         if _has_tqdm:
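Besides the writer default, this file adds a `_has_minari` flag so the optional `minari` dependency is only required when it is actually used, mirroring the existing `_has_tqdm` check. A generic sketch of that lazy-check pattern follows; the function name `list_available_datasets` is hypothetical, while `minari.list_remote_datasets()` is the call used in the diff above.

```python
# Generic sketch of the optional-dependency guard used for `minari` above.
# `list_available_datasets` is a hypothetical wrapper written for illustration.
import importlib.util

# Cheap check at import time: locates the package without importing it.
_has_minari = importlib.util.find_spec("minari") is not None


def list_available_datasets():
    # Fail with a clear message instead of a bare ModuleNotFoundError deep
    # inside the call, and import lazily so the check costs nothing otherwise.
    if not _has_minari:
        raise ImportError("minari library not found.")
    import minari

    return minari.list_remote_datasets().keys()
```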
5 changes: 5 additions & 0 deletions torchrl/data/datasets/openml.py
@@ -8,6 +8,7 @@
 import numpy as np
 from tensordict.tensordict import TensorDict

+from torchrl.data import ImmutableDatasetWriter
 from torchrl.data.replay_buffers import (
     LazyMemmapStorage,
     Sampler,
@@ -72,6 +73,10 @@ def __init__(
         self.max_outcome_val = dataset["y"].max().item()

         storage = LazyMemmapStorage(dataset.shape[0])
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             batch_size=batch_size,
             storage=storage,
6 changes: 5 additions & 1 deletion torchrl/data/datasets/roboset.py
@@ -21,7 +21,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import Writer, ImmutableDatasetWriter

 _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
 _has_h5py = importlib.util.find_spec("h5py", None) is not None
@@ -190,6 +190,10 @@ def __init__(
         else:
             storage = self._load()
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,
6 changes: 5 additions & 1 deletion torchrl/data/datasets/vd4rl.py
@@ -25,7 +25,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer

 from torchrl.envs.transforms import Compose, Resize, ToTensorImage
 from torchrl.envs.utils import _classproperty
@@ -220,6 +220,10 @@
                 transform, Resize(image_size, in_keys=["pixels", ("next", "pixels")])
             )
         storage = TensorStorage(storage)
+
+        if writer is None:
+            writer = ImmutableDatasetWriter()
+
         super().__init__(
             storage=storage,
             sampler=sampler,