[Feature] Max Value Writer #1622

Merged (22 commits) on Oct 18, 2023
comments feedback
albertbou92 committed Oct 18, 2023
commit 8ec741bcc2e899a93eac90ad41c33903cf644ace
test/test_rb.py (4 changes: 3 additions & 1 deletion)
@@ -1158,7 +1158,9 @@ def test_max_value_writer(size, batch_size, reward_ranges):
         },
         batch_size=size,
     )
-    rb.extend(td)
+
+    for sample in td:
+        rb.add(sample)
     sample = rb.sample()
     assert (sample.get("key") <= max_reward3).all()
     assert (max_reward2 <= sample.get("key")).all()
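As context for the change above, here is a minimal usage sketch of the per-sample `add` pattern the test now exercises. The writer class name (`TensorDictMaxValueWriter`) and the `rank_key` argument are inferred from this PR's `writers.py` and test; the storage type, sizes, and values are illustrative assumptions, not part of the diff.

```python
# Hedged sketch: assumes the writer added in this PR is TensorDictMaxValueWriter
# and that it is passed to TensorDictReplayBuffer via the `writer` argument.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.writers import TensorDictMaxValueWriter

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(8),                     # keeps at most 8 elements
    writer=TensorDictMaxValueWriter(rank_key="key"),  # rank elements by "key"
    batch_size=4,
)

td = TensorDict(
    {"key": torch.arange(20, dtype=torch.float32), "obs": torch.randn(20, 3)},
    batch_size=[20],
)

# Add one element at a time, as the updated test does; only the samples with
# the largest "key" values should remain in storage.
for sample in td:
    rb.add(sample)

batch = rb.sample()
print(batch.get("key"))  # expected: values drawn from the top-8 keys (12..19)
```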
torchrl/data/replay_buffers/writers.py (22 changes: 19 additions & 3 deletions)
@@ -144,11 +144,17 @@ def __init__(self, rank_key=None, **kwargs) -> None:

     def get_insert_index(self, data: Any) -> int:
         """Returns the index where the data should be inserted, or None if it should not be inserted."""
+        if data.batch_dims > 1:
+            raise RuntimeError(
+                "Expected input tensordict to have no more than 1 dimension, got "
+                f"tensordict.batch_size = {data.batch_size}"
+            )
+
         ret = None
         rank_data = data.get(("_data", self._rank_key))
 
         # Sum the rank key, in case it is a whole trajectory
-        rank_data = rank_data.sum().item()
+        rank_data = rank_data.sum(-1).item()
 
         if rank_data is None:
             raise KeyError(f"Rank key {self._rank_key} not found in data.")
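To make the insert-index logic above easier to follow, here is a simplified, standalone sketch of the keep-the-top-N idea: a min-heap of (rank, slot) pairs where a new element only displaces the current minimum. This is an illustration, not TorchRL's actual implementation, and the class and method names below are made up.

```python
import heapq
from typing import Optional


class TopKSlots:
    """Illustrative stand-in: keeps only the N highest-ranked elements."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._heap: list[tuple[float, int]] = []  # (rank, slot), min rank at root

    def get_insert_index(self, rank: float) -> Optional[int]:
        # While not full, every element gets the next free slot.
        if len(self._heap) < self.capacity:
            slot = len(self._heap)
            heapq.heappush(self._heap, (rank, slot))
            return slot
        # When full, only elements that beat the current minimum are stored,
        # reusing the slot of the evicted minimum.
        min_rank, min_slot = self._heap[0]
        if rank > min_rank:
            heapq.heapreplace(self._heap, (rank, min_slot))
            return min_slot
        return None  # ranked too low: not inserted


slots = TopKSlots(capacity=3)
print([slots.get_insert_index(r) for r in [1.0, 5.0, 2.0, 4.0, 0.5]])
# [0, 1, 2, 0, None]: 4.0 reuses slot 0 (evicting 1.0), 0.5 is rejected
```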
@@ -175,15 +181,25 @@ def get_insert_index(self, data: Any) -> int:
         return ret
 
     def add(self, data: Any) -> int:
-        """Inserts a single element of data at an appropriate index, and returns that index."""
+        """Inserts a single element of data at an appropriate index, and returns that index.
+
+        The data passed to this module should be structured as :obj:`[]` or :obj:`[T]`, where
+        :obj:`T` is the time dimension. If the data is a trajectory, the rank key will be summed
+        over the time dimension.
+        """
         index = self.get_insert_index(data)
         if index is not None:
             data.set("index", index)
             self._storage[index] = data
         return index
 
     def extend(self, data: Sequence) -> None:
-        """Inserts a series of data points at appropriate indices."""
+        """Inserts a series of data points at appropriate indices.
+
+        The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]`, where
+        :obj:`B` is the batch size and :obj:`T` the time dimension. If the data is a trajectory,
+        the rank key will be summed over the time dimension.
+        """
         data_to_replace = {}
         for i, sample in enumerate(data):
             index = self.get_insert_index(sample)
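Finally, a hedged sketch of the `[B, T]` case described in the new `extend` docstring: each row of the batch is a trajectory, and the writer ranks rows by the rank key summed over the time dimension. Same naming assumptions as the earlier sketch; the storage size and reward values are invented for illustration.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.writers import TensorDictMaxValueWriter

# Keep only the 2 best trajectories, ranked by per-step "reward" summed over T.
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(2),
    writer=TensorDictMaxValueWriter(rank_key="reward"),
    batch_size=1,
)

B, T = 4, 5
trajs = TensorDict(
    {
        "reward": torch.arange(B * T, dtype=torch.float32).reshape(B, T),
        "obs": torch.randn(B, T, 3),
    },
    batch_size=[B, T],
)

# extend() iterates over the leading batch dimension; each [T]-shaped
# trajectory is ranked by reward.sum(-1), so only the two rows with the
# highest summed reward (rows 2 and 3 here) should end up in storage.
rb.extend(trajs)
```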