Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Max Value Writer #1622

Merged
merged 22 commits into from
Oct 18, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix
  • Loading branch information
albertbou92 committed Oct 11, 2023
commit bd32b3c4232d261ef46e7d0f6580ac214ac9223d
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(self, rank_key=None, **kw) -> None:
def add(self, data: Any) -> int:

ret = None
rank_data = data.get("_data", self._rank_key)
rank_data = data.get("_data").get(self._rank_key)
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved

# Sum the rank key, in case it is a whole trajectory
rank_data = rank_data.sum().item()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe?
Maybe we should document what are the expected shapes for this class, eg

[B, T]

but not

[B1, B2, T]

Another option is to check the number of dimensions of the ranking key OR the name of the last dim of the input tensordict (which should be "time").

Not raising any exception and just doing a plain sum could lead to surprising results I think

Copy link
Contributor Author

@albertbou92 albertbou92 Oct 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the first option. Since the ranking value has to be a single float we only allow data of the shape [] and [T] for the add method and [B] and [B, T] for the extend method. If data has a time dimension, we sum along it. If too many dimensions are provided, an error is raised.

I did not go for checking the dimension names because it seemed to restrictive. I don't think time dimension is always labelled

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not always but mostly
if you get your data from env.rollout or collector, it will.
If from there you store the data in a rb, it will keep the tag.
But if you reshape or do other stuff it could go away.

Expand Down