[Refactor] Graduate Replay Buffer prototype #794

Merged · 8 commits · Jan 7, 2023
1 change: 0 additions & 1 deletion README.md
@@ -212,7 +212,6 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
     scratch_dir="/tmp/"
 )
 buffer = TensorDictPrioritizedReplayBuffer(
-    buffer_size=10000,
     alpha=0.7,
     beta=0.5,
     collate_fn=lambda x: x,
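Note on the change above: after this refactor the buffer classes no longer take a `buffer_size` argument; capacity is carried by the storage component instead. A minimal sketch of the post-refactor construction, using `ListStorage` for brevity (the README's surrounding `scratch_dir` context suggests the actual example uses a memmap-backed storage):

```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import ListStorage

# Capacity now lives in the storage, not in the buffer constructor.
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    priority_key="td_error",
    storage=ListStorage(10_000),
)
buffer.extend(TensorDict({"a": torch.randn(3, 1), "td_error": torch.rand(3)}, [3]))
sample = buffer.sample(2)  # returns the sampled TensorDict directly
```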
8 changes: 4 additions & 4 deletions benchmarks/storage/benchmark_sample_latency_over_rpc.py
@@ -24,7 +24,7 @@
 import torch
 import torch.distributed.rpc as rpc
 from tensordict import TensorDict
-from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
+from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import RandomSampler
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
@@ -92,10 +92,10 @@ def train(self, batch_size: int) -> None:
         if self._ret is None:
             self._ret = ret
         else:
-            self._ret[0].update_(ret[0])
+            self._ret.update_(ret)
         # make sure the content is read
-        self._ret[0]["observation"] + 1
-        self._ret[0]["next_observation"] + 1
+        self._ret["observation"] + 1
+        self._ret["next_observation"] + 1
         return timeit.default_timer() - start_time
 
     def _create_replay_buffer(self) -> rpc.RRef:
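The `self._ret[0]` → `self._ret` edits above track the API change that comes with graduation: `sample()` now returns the sampled tensordict itself rather than a `(data, info)` tuple, so the RPC benchmark no longer indexes into the return value. A short sketch of the new calling convention, assuming a default-constructed buffer:

```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer

rb = TensorDictReplayBuffer()
rb.extend(TensorDict({"observation": torch.randn(8, 4)}, [8]))

# Prototype API: sampled, info = rb.sample(2)
# Graduated API: the tensordict comes back directly.
sampled = rb.sample(2)
assert sampled.batch_size == torch.Size([2])
```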
7 changes: 2 additions & 5 deletions docs/source/reference/data.rst
@@ -19,20 +19,17 @@ widely used replay buffers:
     TensorDictReplayBuffer
     TensorDictPrioritizedReplayBuffer
 
-Composable Replay Buffers (Prototype)
+Composable Replay Buffers
 -------------------------------------
 
-We also provide a prototyped composable replay buffer.
+We also give users the ability to compose a replay buffer using the following components:
 
 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst
 
 .. currentmodule:: torchrl.data.replay_buffers
 
-    torchrl.data.replay_buffers.rb_prototype.ReplayBuffer
-    torchrl.data.replay_buffers.rb_prototype.TensorDictReplayBuffer
-    torchrl.data.replay_buffers.rb_prototype.RemoteTensorDictReplayBuffer
     torchrl.data.replay_buffers.samplers.Sampler
     torchrl.data.replay_buffers.samplers.RandomSampler
     torchrl.data.replay_buffers.samplers.PrioritizedSampler
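With the prototype graduated, the components listed here become the supported way to assemble a custom buffer. A sketch of composing storage, writer, and sampler explicitly, mirroring the `test_shared_storage_prioritized_sampler` test in this PR:

```python
import torch
from tensordict import TensorDict
from torchrl.data import ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

n = 100
rb = ReplayBuffer(
    storage=LazyTensorStorage(n),  # where the data lives
    writer=RoundRobinWriter(),     # how new entries are placed
    sampler=PrioritizedSampler(max_capacity=n, alpha=0.7, beta=0.9),  # how entries are drawn
)
rb.extend(TensorDict({"obs": torch.randn(10, 4)}, [10]))
batch = rb.sample(5)
```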
2 changes: 1 addition & 1 deletion examples/distributed/distributed_replay_buffer.py
@@ -16,7 +16,7 @@
 import torch
 import torch.distributed.rpc as rpc
 from tensordict import TensorDict
-from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
+from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import RandomSampler
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage
 from torchrl.data.replay_buffers.utils import accept_remote_rref_invocation
111 changes: 44 additions & 67 deletions test/test_rb.py
@@ -14,18 +14,20 @@
 from _utils_internal import get_available_devices
 from tensordict.prototype import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
-from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer
-from torchrl.data.replay_buffers import (
-    rb_prototype,
-    samplers,
-    writers,
-)
+from torchrl.data import (
+    PrioritizedReplayBuffer,
+    RemoteTensorDictReplayBuffer,
+    ReplayBuffer,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
+)
+from torchrl.data.replay_buffers import samplers, writers
 from torchrl.data.replay_buffers.samplers import (
     PrioritizedSampler,
     RandomSampler,
     SamplerWithoutReplacement,
 )
 
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
     LazyTensorStorage,
@@ -60,9 +62,9 @@
 @pytest.mark.parametrize(
     "rb_type",
     [
-        rb_prototype.ReplayBuffer,
-        rb_prototype.TensorDictReplayBuffer,
-        rb_prototype.RemoteTensorDictReplayBuffer,
+        ReplayBuffer,
+        TensorDictReplayBuffer,
+        RemoteTensorDictReplayBuffer,
     ],
 )
 @pytest.mark.parametrize(
@@ -87,23 +89,21 @@ def _get_rb(self, rb_type, size, sampler, writer, storage):
         return rb
 
     def _get_datum(self, rb_type):
-        if rb_type is rb_prototype.ReplayBuffer:
+        if rb_type is ReplayBuffer:
             data = torch.randint(100, (1,))
         elif (
-            rb_type is rb_prototype.TensorDictReplayBuffer
-            or rb_type is rb_prototype.RemoteTensorDictReplayBuffer
+            rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer
         ):
             data = TensorDict({"a": torch.randint(100, (1,))}, [])
         else:
             raise NotImplementedError(rb_type)
         return data
 
     def _get_data(self, rb_type, size):
-        if rb_type is rb_prototype.ReplayBuffer:
+        if rb_type is ReplayBuffer:
             data = torch.randint(100, (size, 1))
         elif (
-            rb_type is rb_prototype.TensorDictReplayBuffer
-            or rb_type is rb_prototype.RemoteTensorDictReplayBuffer
+            rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer
         ):
             data = TensorDict(
                 {
@@ -298,7 +298,7 @@ def test_set_tensorclass(self, max_size, shape, storage):
 def test_prototype_prb(priority_key, contiguous, device):
     torch.manual_seed(0)
     np.random.seed(0)
-    rb = rb_prototype.TensorDictReplayBuffer(
+    rb = TensorDictReplayBuffer(
         sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9),
         priority_key=priority_key,
     )
@@ -311,7 +311,7 @@ def test_prototype_prb(priority_key, contiguous, device):
         batch_size=[3],
     ).to(device)
     rb.extend(td1)
-    s, _ = rb.sample(2)
+    s = rb.sample(2)
     assert s.batch_size == torch.Size(
         [
             2,
@@ -330,7 +330,7 @@ def test_prototype_prb(priority_key, contiguous, device):
         batch_size=[5],
     ).to(device)
     rb.extend(td2)
-    s, _ = rb.sample(5)
+    s = rb.sample(5)
     assert s.batch_size == torch.Size([5])
     assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all()
     assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a"))
@@ -353,26 +353,26 @@ def test_prototype_prb(priority_key, contiguous, device):
 
     idx0 = s.get("_idx")[0]
     rb.update_tensordict_priority(s)
-    s, _ = rb.sample(5)
+    s = rb.sample(5)
     assert (val == s.get("a")).sum() >= 1
     torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1))
 
     # test updating values of original td
     td2.set_("a", torch.ones_like(td2.get("a")))
-    s, _ = rb.sample(5)
+    s = rb.sample(5)
     torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1))
 
 
 @pytest.mark.parametrize("stack", [False, True])
-def test_rb_prototype_trajectories(stack):
+def test_replay_buffer_trajectories(stack):
     traj_td = TensorDict(
         {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
         batch_size=[3, 4],
     )
     if stack:
         traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0)
 
-    rb = rb_prototype.TensorDictReplayBuffer(
+    rb = TensorDictReplayBuffer(
         sampler=samplers.PrioritizedSampler(
             5,
             alpha=0.7,
@@ -381,10 +381,10 @@ def test_rb_prototype_trajectories(stack):
         priority_key="td_error",
     )
     rb.extend(traj_td)
-    sampled_td, _ = rb.sample(3)
+    sampled_td = rb.sample(3)
     sampled_td.set("td_error", torch.rand(3))
     rb.update_tensordict_priority(sampled_td)
-    sampled_td, _ = rb.sample(3, include_info=True)
+    sampled_td = rb.sample(3, include_info=True)
     assert (sampled_td.get("_weight") > 0).all()
     assert sampled_td.batch_size == torch.Size([3])
 
@@ -433,7 +433,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
             params = self._default_params_td_prb
         else:
             raise NotImplementedError(rbtype)
-        rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params)
+        rb = rbtype(storage=storage, prefetch=prefetch, **params)
         return rb
 
     def _get_datum(self, rbtype):
@@ -481,17 +481,17 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch):
         rb.extend(batch1)
 
         # Added less data than storage max size
-        if size > 5:
-            assert rb._cursor == 5
+        if size > 5 or storage is None:
+            assert rb._writer._cursor == 5
         # Added more data than storage max size
         elif size < 5:
-            assert rb._cursor == 5 - size
+            assert rb._writer._cursor == 5 - size
         # Added as data as storage max size
         else:
-            assert rb._cursor == 0
+            assert rb._writer._cursor == 0
         batch2 = self._get_data(rbtype, size=size - 1)
         rb.extend(batch2)
-        assert rb._cursor == size - 1
+        assert rb._writer._cursor == size - 1
 
     def test_add(self, rbtype, storage, size, prefetch):
         torch.manual_seed(0)
@@ -575,10 +575,10 @@ def test_prb(priority_key, contiguous, device):
     torch.manual_seed(0)
     np.random.seed(0)
     rb = TensorDictPrioritizedReplayBuffer(
-        5,
         alpha=0.7,
         beta=0.9,
         priority_key=priority_key,
+        storage=ListStorage(5),
     )
     td1 = TensorDict(
         source={
@@ -630,7 +630,7 @@ def test_prb(priority_key, contiguous, device):
     val = s.get("a")[0]
 
     idx0 = s.get("_idx")[0]
-    rb.update_priority(s)
+    rb.update_tensordict_priority(s)
     s = rb.sample(5)
     assert (val == s.get("a")).sum() >= 1
     torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1))
@@ -651,16 +651,16 @@ def test_rb_trajectories(stack):
         traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0)
 
     rb = TensorDictPrioritizedReplayBuffer(
-        5,
         alpha=0.7,
         beta=0.9,
         priority_key="td_error",
+        storage=ListStorage(5),
     )
     rb.extend(traj_td)
     sampled_td = rb.sample(3)
     sampled_td.set("td_error", torch.rand(3))
-    rb.update_priority(sampled_td)
-    sampled_td = rb.sample(3, return_weight=True)
+    rb.update_tensordict_priority(sampled_td)
+    sampled_td = rb.sample(3, include_info=True)
     assert (sampled_td.get("_weight") > 0).all()
     assert sampled_td.batch_size == torch.Size([3])
 
@@ -680,12 +680,12 @@ def test_shared_storage_prioritized_sampler():
     sampler0 = RandomSampler()
     sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1)
 
-    rb0 = rb_prototype.ReplayBuffer(
+    rb0 = ReplayBuffer(
         storage=storage,
         writer=writer,
         sampler=sampler0,
     )
-    rb1 = rb_prototype.ReplayBuffer(
+    rb1 = ReplayBuffer(
         storage=storage,
         writer=writer,
         sampler=sampler1,
@@ -708,25 +708,8 @@ def test_shared_storage_prioritized_sampler():
     assert rb1._sampler._sum_tree.query(0, 70) == 50
-
-
-def test_legacy_rb_does_not_attach():
-    n = 10
-    storage = LazyMemmapStorage(n)
-    writer = RoundRobinWriter()
-    sampler = RandomSampler()
-    rb = ReplayBuffer(storage=storage, size=n, prefetch=0)
-    prb = rb_prototype.ReplayBuffer(
-        storage=storage,
-        writer=writer,
-        sampler=sampler,
-    )
-
-    assert len(storage._attached_entities) == 1
-    assert prb in storage._attached_entities
-    assert rb not in storage._attached_entities
 
 
 def test_append_transform():
-    rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0))
+    rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0))
     td = TensorDict(
         {
             "observation": torch.randn(2, 4, 3, 16),
@@ -741,7 +724,7 @@ def test_append_transform():
 
     rb.append_transform(flatten)
 
-    sampled, _ = rb.sample(1)
+    sampled = rb.sample(1)
     assert sampled.get("observation_cat").shape[-1] == 32
 
 
@@ -750,29 +733,25 @@ def test_init_transform():
         -2, -1, in_keys=["observation"], out_keys=["flattened"]
     )
 
-    rb = rb_prototype.ReplayBuffer(
-        collate_fn=lambda x: torch.stack(x, 0), transform=flatten
-    )
+    rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten)
 
     td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])
     rb.add(td)
-    sampled, _ = rb.sample(1)
+    sampled = rb.sample(1)
     assert sampled.get("flattened").shape[-1] == 48
 
 
 def test_insert_transform():
     flatten = FlattenObservation(
         -2, -1, in_keys=["observation"], out_keys=["flattened"]
     )
-    rb = rb_prototype.ReplayBuffer(
-        collate_fn=lambda x: torch.stack(x, 0), transform=flatten
-    )
+    rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=flatten)
     td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, [])
     rb.add(td)
 
     rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"]))
 
-    sampled, _ = rb.sample(1)
+    sampled = rb.sample(1)
     assert sampled.get("flattened").shape[-1] == 48
 
     with pytest.raises(ValueError):
@@ -810,7 +789,7 @@ def test_insert_transform():
 
 @pytest.mark.parametrize("transform", transforms)
 def test_smoke_replay_buffer_transform(transform):
-    rb = rb_prototype.ReplayBuffer(
+    rb = ReplayBuffer(
         transform=transform(in_keys="observation"),
     )
 
@@ -833,9 +812,7 @@ def test_smoke_replay_buffer_transform(transform):
 
 @pytest.mark.parametrize("transform", transforms)
 def test_smoke_replay_buffer_transform_no_inkeys(transform):
-    rb = rb_prototype.ReplayBuffer(
-        collate_fn=lambda x: torch.stack(x, 0), transform=transform()
-    )
+    rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), transform=transform())
 
     td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
     rb.add(td)
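Taken together, the test updates above spell out the migration recipe: import the buffer classes from `torchrl.data` instead of `rb_prototype`, pass a storage instead of a `size`/`buffer_size` argument, unpack `sample()` directly, and call `update_tensordict_priority` (with `include_info=True` replacing `return_weight=True`). A condensed sketch of the new priority-update loop, adapted from `test_rb_trajectories`:

```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import ListStorage

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    priority_key="td_error",
    storage=ListStorage(5),  # capacity moves into the storage
)
rb.extend(TensorDict({"obs": torch.randn(3, 4)}, [3]))

sampled = rb.sample(3)                     # no (data, info) tuple anymore
sampled.set("td_error", torch.rand(3))     # write fresh priorities
rb.update_tensordict_priority(sampled)     # renamed from update_priority
sampled = rb.sample(3, include_info=True)  # sampler info lands in the tensordict
assert (sampled.get("_weight") > 0).all()
```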
4 changes: 2 additions & 2 deletions test/test_rb_distributed.py
@@ -11,7 +11,7 @@
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 from tensordict.tensordict import TensorDict
-from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
+from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import RandomSampler
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage
 from torchrl.data.replay_buffers.writers import RoundRobinWriter
@@ -50,7 +50,7 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl
     if name == "TRAINER":
         buffer = _construct_buffer("BUFFER")
         _, inserted = _add_random_tensor_dict_to_buffer(buffer)
-        sampled, _ = _sample_from_buffer(buffer, 1)
+        sampled = _sample_from_buffer(buffer, 1)
         assert type(sampled) is type(inserted) is TensorDict
         assert (sampled == inserted)["a"].item()
 