[Feature] Immutable writer for datasets #1781

Merged (16 commits) on Jan 9, 2024
amend
vmoens committed Jan 9, 2024
commit 4f320458b530f63d9b73c088b76c4ef147824c7f
16 changes: 15 additions & 1 deletion examples/decision_transformer/utils.py
@@ -10,7 +10,12 @@
from tensordict.nn import TensorDictModule

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import (
LazyMemmapStorage,
RoundRobinWriter,
TensorDictReplayBuffer,
TensorStorage,
)
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import RandomSampler
from torchrl.envs import (
@@ -242,7 +247,16 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
use_truncated_as_done=True,
direct_download=True,
prefetch=4,
writer=RoundRobinWriter(),
)

# since we're not extending the data, adding keys can only be done via
# the creation of a new storage
data_memmap = data[:]
with data_memmap.unlock_():
data_memmap = r2g.inv(data_memmap)
data._storage = TensorStorage(data_memmap)

loc = data[:]["observation"].flatten(0, -2).mean(axis=0).float()
std = data[:]["observation"].flatten(0, -2).std(axis=0).float()
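The hunk above follows a general pattern: because the dataset storage is not being extended in place, new keys can only be introduced by materializing the stored data, transforming it, and wrapping the result in a fresh storage. A minimal toy sketch of that copy-transform-rebuild pattern, using a hypothetical `Storage` class and an `inv_transform` stand-in for `r2g.inv` (neither is the torchrl API):

```python
# Toy sketch of the "rebuild storage to add keys" pattern from the diff.
# `Storage` and `inv_transform` are illustrative stand-ins, not torchrl classes.

class Storage:
    """Minimal read-only storage wrapping a dict of equal-length lists."""

    def __init__(self, data):
        self._data = dict(data)  # shallow copy; treated as immutable afterwards

    def __getitem__(self, key):
        return self._data[key]

    def snapshot(self):
        # Materialize the full contents, analogous to `data[:]` in the diff.
        return {k: list(v) for k, v in self._data.items()}


def inv_transform(table):
    # Stand-in for `r2g.inv`: derive a new key from existing ones,
    # here a cumulative return-to-go computed from the reward column.
    rewards = table["reward"]
    rtg, total = [], sum(rewards)
    for r in rewards:
        rtg.append(total)
        total -= r
    table["return_to_go"] = rtg
    return table


buffer_storage = Storage({"reward": [1.0, 2.0, 3.0]})
# The storage cannot be extended in place, so copy, transform, and rebuild:
table = inv_transform(buffer_storage.snapshot())
buffer_storage = Storage(table)
print(buffer_storage["return_to_go"])  # [6.0, 5.0, 3.0]
```

The design choice mirrors the PR: rather than mutating an immutable storage, the new storage simply replaces the old one on the buffer (here, `data._storage = TensorStorage(...)`), leaving the original memmap untouched.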
