[BugFix] step_mdp nested keys #1339

Conversation
The tests look good, but I suspect this implementation of step_mdp will be drastically slower than the one we have now.
Given the number of issues users have opened lately about environment overhead compared to gym, I'm reluctant to accept any change to step_mdp that makes it slower than it already is.
Let's wait until we merge the two benchmark PRs (one adding a step_mdp benchmark, the other running benchmarks on PRs); it'll then be easier to iterate on this.
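For reference, a quick local sanity check of the overhead concern could look like this (a hedged sketch; the tensordict contents and sizes are made up, and the CI benchmark PRs mentioned above would be the authoritative measure):

```python
import timeit

import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

# A toy rollout-style tensordict with the usual root/"next" layout.
td = TensorDict(
    {
        "action": torch.zeros(4, 1),
        "observation": torch.zeros(4, 3),
        "next": TensorDict(
            {
                "observation": torch.zeros(4, 3),
                "reward": torch.zeros(4, 1),
                "done": torch.zeros(4, 1, dtype=torch.bool),
            },
            batch_size=[4],
        ),
    },
    batch_size=[4],
)

# Time many calls to amortize noise; compare this branch against main.
print(timeit.timeit(lambda: step_mdp(td), number=10_000))
```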
torchrl/envs/utils.py
Outdated
```python
if isinstance(done_key, tuple) and len(done_key) == 1:
    done_key = done_key[0]
if isinstance(reward_key, tuple) and len(reward_key) == 1:
    reward_key = reward_key[0]
if isinstance(action_key, tuple) and len(action_key) == 1:
    action_key = action_key[0]
```
This is expensive; let's not do it if we can avoid it.
We could blend this into unravel_keys in tensordict (which is coded in C++), since it exists there for efficiency purposes.
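In pure Python, the unravelling being folded into tensordict would behave roughly like this (an illustrative sketch; `_unravel_key` is a hypothetical name, and the real implementation lives in tensordict):

```python
def _unravel_key(key):
    """Flatten nested key tuples; unwrap length-1 tuples to plain strings."""
    if isinstance(key, str):
        return key
    flat = []
    for part in key:
        part = _unravel_key(part)
        if isinstance(part, str):
            flat.append(part)
        else:
            flat.extend(part)
    return flat[0] if len(flat) == 1 else tuple(flat)


assert _unravel_key(("done",)) == "done"
assert _unravel_key(("next", ("agents", "reward"))) == ("next", "agents", "reward")
```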
torchrl/envs/utils.py
Outdated
```python
if not exclude_action:
    out._set(action_key, tensordict.get(action_key))
```
Here, if the action has to be kept, we're removing it and adding it back, which is expensive; let's avoid that.
We are not removing it; the earlier exclusion is on "next".
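A small sketch of why this holds (hedged; the tensordict contents are made up): `out` is built from the "next" sub-tensordict, so the root-level action was never in it, and setting it is a plain insertion rather than a remove-and-re-add:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "action": torch.zeros(3, 1),
        "next": TensorDict(
            {
                "observation": torch.zeros(3, 4),
                "reward": torch.zeros(3, 1),
                "done": torch.zeros(3, 1, dtype=torch.bool),
            },
            batch_size=[3],
        ),
    },
    batch_size=[3],
)

out = td.get("next").exclude("reward", "done")  # exclusion touches "next" only
assert "action" not in out.keys()               # the action was never there
out.set("action", td.get("action"))             # single insertion, no removal
```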
torchrl/envs/utils.py
Outdated
```diff
 out = tensordict.get("next").clone(False)
-excluded = set()
+excluded = {action_key}
 if exclude_done:
     excluded.add(done_key)
 if exclude_reward:
     excluded.add(reward_key)
 if len(excluded):
```
That condition will never be false anymore, so we'll always call exclude, which is expensive. Let's avoid that.
ok
torchrl/envs/utils.py
Outdated
```python
# out.update(tensordict.select(*td_keys))
for key in td_keys:
    out._set(key, tensordict.get(key))
excluded = set.union(excluded, set(out.keys(True, True)))
```
Building a set over all the keys is super expensive; I'd avoid it if we can.
torchrl/envs/utils.py
Outdated
```python
td_next = tensordict.get("next")

td_keys = td.keys(True, True)
td_next_keys = td_next.keys(True, True)
```
This is the only time we traverse.
Basically we just visit every key in the tensordict.
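As an illustration (a hedged sketch with made-up contents), `keys(True, True)` is `keys(include_nested=True, leaves_only=True)`: a single traversal that yields every leaf key, with nested ones as tuples:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(2), "b": TensorDict({"c": torch.zeros(2)}, batch_size=[2])},
    batch_size=[2],
)
print(list(td.keys(True, True)))  # e.g. ['a', ('b', 'c')]
```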
torchrl/envs/utils.py
Outdated
```python
# Set the keys from root
if not exclude_action:
    _set_key(dest=out, source=tensordict, key=action_key)
```
Here we handle the action separately just to avoid another if later; performance is the same.
I can see a version of that working, but the way we set the batch size requires multiple accesses to nested tensordicts, which will be time consuming.
If you merge main into this branch you'll get the timing measurements for your solution.
torchrl/envs/utils.py
Outdated
```python
td = tensordict.exclude("next")
td_next = tensordict.get("next")
```
use pop
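That is, something along these lines (a hedged sketch with made-up contents; note that unlike exclude, pop mutates the input tensordict in place):

```python
import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {
        "action": torch.zeros(3, 1),
        "next": TensorDict({"observation": torch.zeros(3, 4)}, batch_size=[3]),
    },
    batch_size=[3],
)

# One traversal instead of two: pop removes "next" and returns it.
td_next = tensordict.pop("next")
td = tensordict
assert "next" not in td.keys()
```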
torchrl/envs/utils.py
Outdated
```python
excluded = {
    done_key if exclude_done else None,
    reward_key if exclude_reward else None,
}
```
Isn't it weird to have None in the set?
Should we not build it iteratively?
I can do it. I thought performance was similar and this was more readable, but I'll change it.
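The iterative construction being suggested would look something like this (a sketch; the key values and flags are stand-ins):

```python
done_key, reward_key = "done", "reward"     # stand-in key values
exclude_done, exclude_reward = True, False  # stand-in flags

excluded = set()
if exclude_done:
    excluded.add(done_key)
if exclude_reward:
    excluded.add(reward_key)

assert None not in excluded  # no placeholder entries to filter out later
```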
torchrl/envs/utils.py
Outdated
```python
if isinstance(key, tuple) and len(key) > 1:  # Setting the batch_sizes
    for k in range(1, len(key)):
        dest[key[:k]].batch_size = source[key[:k]].batch_size
```
If two keys share the same root, we'll set the batch size multiple times, which is time consuming.
Yeah, I know; I'll improve it.
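One way to avoid the repeated assignments (a hedged sketch with a hypothetical helper, not the final fix) is to collect each distinct prefix once before setting any batch size:

```python
def _nested_prefixes(keys):
    """Collect every distinct nested-tensordict prefix exactly once."""
    prefixes = set()
    for key in keys:
        if isinstance(key, tuple) and len(key) > 1:
            for k in range(1, len(key)):
                prefixes.add(key[:k])
    return prefixes


keys = [("agents", "action"), ("agents", "info", "mask"), "done"]
print(_nested_prefixes(keys))
# {('agents',), ('agents', 'info')} -- ('agents',) is visited once, not twice
```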
LGTM