[Feature] Max Value Writer (#1622)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
albertbou92 and vmoens authored Oct 18, 2023
1 parent 38dfc21 commit 55d667e
Showing 6 changed files with 203 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following co
Writer
RoundRobinWriter
TensorDictRoundRobinWriter
TensorDictMaxValueWriter

Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
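
A minimal sketch of the storage advice above, assuming a shared filesystem path is reachable by every node (the path and sizes are hypothetical; ``LazyMemmapStorage`` and its ``scratch_dir`` argument are existing TorchRL API):

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

# Hypothetical shared location reachable by every node; memory-mapped
# tensors live here, so a restarted worker can reattach to the same files.
scratch_dir = "/mnt/shared/rb_scratch"

rb = TensorDictReplayBuffer(
    # Tensors are memory-mapped to scratch_dir instead of being serialised,
    # which keeps inter-process handoff cheap.
    storage=LazyMemmapStorage(max_size=1_000_000, scratch_dir=scratch_dir),
    batch_size=256,
)
rb.extend(TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10]))
batch = rb.sample()  # reads go through the memmap, not pickled tensors
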
64 changes: 63 additions & 1 deletion test/test_rb.py
@@ -38,7 +38,10 @@
ListStorage,
TensorStorage,
)
-from torchrl.data.replay_buffers.writers import RoundRobinWriter
+from torchrl.data.replay_buffers.writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+)
from torchrl.envs.transforms.transforms import (
BinarizeReward,
CatFrames,
@@ -1209,6 +1212,65 @@ def test_load_state_dict(self, storage_in, storage_out, init_out):
assert (s.exclude("index") == 1).all()


@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
def test_max_value_writer(size, batch_size, reward_ranges):
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(size),
sampler=SamplerWithoutReplacement(),
batch_size=batch_size,
writer=TensorDictMaxValueWriter(rank_key="key"),
)

max_reward1, max_reward2, max_reward3 = reward_ranges

td = TensorDict(
{
"key": torch.clamp_max(torch.rand(size), max=max_reward1),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
device="cpu",
)
rb.extend(td)
sample = rb.sample()
assert (sample.get("key") <= max_reward1).all()
assert (0 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))

td = TensorDict(
{
"key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
device="cpu",
)
rb.extend(td)
sample = rb.sample()
assert (sample.get("key") <= max_reward2).all()
assert (max_reward1 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))

td = TensorDict(
{
"key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
device="cpu",
)

for sample in td:
rb.add(sample)

sample = rb.sample()
assert (sample.get("key") <= max_reward3).all()
assert (max_reward2 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -14,6 +14,7 @@
ReplayBuffer,
RoundRobinWriter,
Storage,
TensorDictMaxValueWriter,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
TensorDictRoundRobinWriter,
7 changes: 6 additions & 1 deletion torchrl/data/replay_buffers/__init__.py
@@ -23,4 +23,9 @@
Storage,
TensorStorage,
)
-from .writers import RoundRobinWriter, TensorDictRoundRobinWriter, Writer
+from .writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+    TensorDictRoundRobinWriter,
+    Writer,
+)
11 changes: 6 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -718,12 +718,13 @@ def add(self, data: TensorDictBase) -> int:
data_add = data

index = super()._add(data_add)
-        if is_tensor_collection(data_add):
-            data_add.set("index", index)
+        if index is not None:
+            if is_tensor_collection(data_add):
+                data_add.set("index", index)

-        # priority = self._get_priority(data)
-        # if priority:
-        self.update_tensordict_priority(data_add)
+            # priority = self._get_priority(data)
+            # if priority:
+            self.update_tensordict_priority(data_add)
return index

def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
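
With the ``index is not None`` guard above, ``TensorDictReplayBuffer.add`` tolerates writers that decline an insertion: ``TensorDictMaxValueWriter`` returns ``None`` for a sample that does not beat the buffer's current minimum, so no index is stamped and no priority update runs. A small sketch of the resulting behaviour (values are illustrative):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(2),
    batch_size=2,
    writer=TensorDictMaxValueWriter(rank_key="key"),
)
# Fill both slots with high-ranked samples.
rb.extend(TensorDict({"key": torch.tensor([5.0, 7.0])}, batch_size=[2]))
# 1.0 does not beat the current minimum (5.0): the writer declines the
# sample and ``add`` propagates ``None`` instead of a storage index.
index = rb.add(TensorDict({"key": torch.tensor(1.0)}, batch_size=[]))
assert index is None
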
126 changes: 126 additions & 0 deletions torchrl/data/replay_buffers/writers.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import heapq
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence

@@ -92,3 +93,128 @@ def extend(self, data: Sequence) -> torch.Tensor:
data["index"] = index
self._storage[index] = data
return index


class TensorDictMaxValueWriter(Writer):
"""A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
If rank_key is not provided, the key will be ``("next", "reward")``.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
... storage=LazyTensorStorage(1),
... sampler=SamplerWithoutReplacement(),
... batch_size=1,
... writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
... "key": torch.tensor(range(10)),
... "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
... "key": torch.tensor(range(10, 20)),
... "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
... "key": torch.tensor(range(10)),
... "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
"""

def __init__(self, rank_key=None, **kwargs) -> None:
super().__init__(**kwargs)
self._cursor = 0
self._current_top_values = []
self._rank_key = rank_key
if self._rank_key is None:
self._rank_key = ("next", "reward")

def get_insert_index(self, data: Any) -> int:
"""Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
        if data.batch_dims > 1:
            raise RuntimeError(
                "Expected input tensordict to have no more than 1 dimension, got "
                f"tensordict.batch_size = {data.batch_size}"
            )

        ret = None
        rank_data = data.get(("_data", self._rank_key), None)
        if rank_data is None:
            raise KeyError(f"Rank key {self._rank_key} not found in data.")

        # If the rank key has a time dimension, sum along it.
        rank_data = rank_data.sum(-1).item()

        # If the buffer is not full, insert at the cursor position
        if len(self._current_top_values) < self._storage.max_size:
            ret = self._cursor
            self._cursor = (self._cursor + 1) % self._storage.max_size

            # Add the new rank to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        # If the buffer is full, only insert if the new data beats the current minimum
        elif rank_data > self._current_top_values[0][0]:
            # Pop the entry with the smallest rank and reuse its storage slot
            min_sample = heapq.heappop(self._current_top_values)
            ret = min_sample[1]

            # Add the new rank to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

return ret

    def add(self, data: Any) -> int:
        """Inserts a single element of data at an appropriate index, and returns that index.

        The data passed to this module should be structured as :obj:`[]` or :obj:`[T]`, where
        :obj:`T` is the time dimension. If the data is a trajectory, the rank key will be
        summed over the time dimension.
        """
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
self._storage[index] = data
return index

    def extend(self, data: Sequence) -> None:
        """Inserts a series of data points at appropriate indices.

        The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]`,
        where :obj:`B` is the batch size and :obj:`T` is the time dimension. If the data is
        a trajectory, the rank key will be summed over the time dimension.
        """
data_to_replace = {}
for i, sample in enumerate(data):
index = self.get_insert_index(sample)
if index is not None:
data_to_replace[index] = i

        # Replace the data in the storage all at once
        if data_to_replace:
            keys, values = zip(*data_to_replace.items())
            index = data.get("index")
            values = list(values)
            # Stamp the storage slots into the accepted rows of "index",
            # then use those slots as keys into the storage.
            keys = index[values] = torch.tensor(keys, dtype=index.dtype)
            data.set("index", index)
            self._storage[keys] = data[values]

def _empty(self) -> None:
self._cursor = 0
self._current_top_values = []
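
For reference, the bookkeeping in ``get_insert_index`` is the standard ``heapq`` top-k pattern: a min-heap of ``(rank, slot)`` pairs keeps the weakest retained sample at ``heap[0]``, so each candidate is decided with a single comparison. A standalone sketch of the pattern (illustrative, not part of the commit):

import heapq

def top_k_slots(ranks, k):
    # heap[0] always holds the smallest retained rank, so membership in
    # the top-k is decided with one comparison per candidate.
    heap = []
    for slot, rank in enumerate(ranks):
        if len(heap) < k:
            heapq.heappush(heap, (rank, slot))
        elif rank > heap[0][0]:
            heapq.heapreplace(heap, (rank, slot))
    return sorted(slot for _, slot in heap)

print(top_k_slots([0.3, 0.9, 0.1, 0.7], k=2))  # -> [1, 3]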

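And a sketch of the trajectory case described in the ``add``/``extend`` docstrings: when the rank key carries a time dimension it is summed before ranking. Shapes are illustrative; ``("next", "reward")`` is the writer's default rank key:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(4),
    batch_size=4,
    writer=TensorDictMaxValueWriter(rank_key=("next", "reward")),
)
# Eight 5-step trajectories; each is ranked by its return, i.e. the
# ("next", "reward") entry summed over the trailing time dimension.
td = TensorDict(
    {
        "obs": torch.randn(8, 5, 3),
        ("next", "reward"): torch.rand(8, 5),
    },
    batch_size=[8],
)
rb.extend(td)  # only the four highest-return trajectories are kept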