[BugFix] step_mdp nested keys #1339

Merged
merged 17 commits into from
Jun 30, 2023

Conversation

matteobettini
Contributor

No description provided.

Signed-off-by: Matteo Bettini <matbet@meta.com>
@facebook-github-bot added the CLA Signed label Jun 30, 2023
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Contributor

@vmoens left a comment

The tests look good, but I suspect this implementation of step_mdp will be drastically slower than the one we have now.
Given the number of issues users have opened lately about environment overhead compared to gym, I'm reluctant to make any change to step_mdp that makes it slower than it already is.
Let's wait until we merge the two benchmark PRs (the one with step_mdp and the one that runs benchmarks in PRs); then it'll be easier to iterate on this.

Comment on lines 158 to 163
if isinstance(done_key, tuple) and len(done_key) == 1:
    done_key = done_key[0]
if isinstance(reward_key, tuple) and len(reward_key) == 1:
    reward_key = reward_key[0]
if isinstance(action_key, tuple) and len(action_key) == 1:
    action_key = action_key[0]
Contributor

This is expensive; let's not do it if we can avoid it.
If it's there for efficiency purposes, we could fold it into unravel_keys in tensordict (which is coded in C++).
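
One way to read this suggestion is to do the unwrapping in a single place, once, instead of repeating the three isinstance checks on every call. The following is only a pure-Python sketch of that idea: `unwrap_single_key` is a hypothetical helper, not tensordict's `unravel_keys`, and the example keys are made up.

```python
from typing import Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def unwrap_single_key(key: NestedKey) -> NestedKey:
    """Return the bare string for a one-element tuple key, else the key unchanged."""
    if isinstance(key, tuple) and len(key) == 1:
        return key[0]
    return key


# Normalize the keys once (for example when they are first configured),
# so the per-step code path does not need to repeat the checks.
done_key, reward_key, action_key = map(
    unwrap_single_key, (("done",), ("agents", "reward"), "action")
)
```

Whether this is actually cheaper inside step_mdp would need to be confirmed with the benchmarks mentioned above.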

Comment on lines 201 to 202
if not exclude_action:
    out._set(action_key, tensordict.get(action_key))
Contributor

Here, if the action has to be kept, we're removing it and adding it back, which is expensive. Let's avoid that.

Contributor Author

We are not removing it: the exclusion above is applied to "next".

out = tensordict.get("next").clone(False)
excluded = set()
excluded = {action_key}
if exclude_done:
    excluded.add(done_key)
if exclude_reward:
    excluded.add(reward_key)
if len(excluded):
Contributor

With this change the set is never empty, so we'll always be calling exclude, which is expensive. Let's avoid that.

Contributor Author

ok
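
The concern in this thread can be sketched as a fast path: start from an empty set, add only the keys that really need dropping, and skip the exclusion entirely when the set stays empty. This is a minimal illustration that uses a plain dict as a stand-in for the "next" tensordict; `filter_next` is a hypothetical name and not part of the PR.

```python
def filter_next(next_td: dict, done_key, reward_key,
                exclude_done: bool, exclude_reward: bool) -> dict:
    """Drop done/reward entries only when asked to; otherwise just shallow-copy."""
    excluded = set()
    if exclude_done:
        excluded.add(done_key)
    if exclude_reward:
        excluded.add(reward_key)
    if not excluded:
        # Fast path: nothing to drop, so no filtering pass is needed.
        return dict(next_td)
    return {k: v for k, v in next_td.items() if k not in excluded}


step = {"obs": 1, "done": False, "reward": 0.5}
kept_all = filter_next(step, "done", "reward", False, False)
no_done = filter_next(step, "done", "reward", True, False)
```

Seeding the set with the action key, as in the snippet above, would make the fast path unreachable, which is the point the reviewer raises.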

# out.update(tensordict.select(*td_keys))
for key in td_keys:
    out._set(key, tensordict.get(key))
excluded = set.union(excluded, set(out.keys(True, True)))
Contributor

Building a set from all the keys is super expensive; I'd avoid it if we can.

Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
td_next = tensordict.get("next")

td_keys = td.keys(True, True)
td_next_keys = td_next.keys(True, True)
Contributor Author

This is the only time we traverse the keys.
Basically, we just visit every key in tensordict once.
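
To make "visit every key once" concrete, here is a small sketch of a single-pass leaf-key traversal over nested plain dicts. It only imitates what `keys(True, True)` is used for in the snippet above (nested keys as tuples, leaves only); `leaf_keys` is a hypothetical helper and the real tensordict implementation may differ.

```python
def leaf_keys(d: dict, prefix: tuple = ()):
    """Yield each leaf key exactly once: strings at the root, tuples when nested."""
    for name, value in d.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            # Descend into the nested mapping instead of yielding it.
            yield from leaf_keys(value, path)
        else:
            yield path[0] if len(path) == 1 else path


data = {"obs": 1, "agents": {"reward": 2, "info": {"pos": 3}}}
all_leaves = list(leaf_keys(data))
```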


# Set the keys from root
if not exclude_action:
    _set_key(dest=out, source=tensordict, key=action_key)
Contributor Author

Here we handle the action separately just to avoid another if later; performance is the same.

Contributor

@vmoens left a comment

I can see a version of this working, but the way we set the batch size requires multiple accesses to nested tensordicts, which will be time-consuming.
If you merge main into this branch, you'll get timing measurements for your solution.

Comment on lines 191 to 192
td = tensordict.exclude("next")
td_next = tensordict.get("next")
Contributor

use pop
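
As a rough illustration of the suggestion, popping "next" yields both pieces in one step rather than an exclude followed by a get. The sketch below uses a plain dict as a stand-in for a tensordict, and `split_next` is a hypothetical name; whether this is faster on a real tensordict is not measured here.

```python
def split_next(td: dict):
    """Split a step mapping into (root without "next", the "next" entry)."""
    root = dict(td)             # shallow copy: values are shared, not duplicated
    td_next = root.pop("next")  # removes "next" from the copy and returns it
    return root, td_next


original = {"obs": 1, "action": 0, "next": {"obs": 2, "done": False}}
root, td_next = split_next(original)
```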

Comment on lines 200 to 203
excluded = {
    done_key if exclude_done else None,
    reward_key if exclude_reward else None,
}
Contributor

Isn't it weird to have None in the set?
Shouldn't we build it iteratively?

Contributor Author

I can do it. I thought performance was similar and it was more readable, but I'll change it.
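
One concrete reason the None placeholder can be awkward, shown as a small sketch with made-up key names: when nothing is excluded, the literal form still produces a non-empty set, so any emptiness check on it no longer means "nothing to exclude".

```python
exclude_done, exclude_reward = False, False
done_key, reward_key = "done", "reward"

# Literal form from the snippet above: None leaks into the set.
with_placeholder = {
    done_key if exclude_done else None,
    reward_key if exclude_reward else None,
}

# Iterative form: the set stays empty when nothing is excluded.
iterative = set()
if exclude_done:
    iterative.add(done_key)
if exclude_reward:
    iterative.add(reward_key)
```

Here `with_placeholder` is `{None}` (length 1) while `iterative` is empty.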

Comment on lines 225 to 227
if isinstance(key, tuple) and len(key) > 1:  # Setting the batch_sizes
    for k in range(1, len(key)):
        dest[key[:k]].batch_size = source[key[:k]].batch_size
Contributor

If two keys share the same root, we'll set the batch size multiple times, which is time-consuming.

Contributor Author

Yeah, I know; I'll improve it.
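
One possible way to avoid the repeated work, sketched in plain Python and not taken from the merged code: collect the distinct key prefixes first, then assign the batch size once per prefix. `unique_prefixes` is a hypothetical helper and the keys are invented for the example.

```python
def unique_prefixes(keys):
    """Collect each proper prefix of every tuple key once, in first-seen order."""
    seen = {}  # dict used as an ordered set
    for key in keys:
        if isinstance(key, tuple):
            for k in range(1, len(key)):
                seen.setdefault(key[:k], None)
    return list(seen)


keys = [("agents", "obs"), ("agents", "info", "pos"), ("agents", "info", "vel"), "done"]
prefixes = unique_prefixes(keys)
# A per-key loop like the one above would touch 1 + 2 + 2 = 5 prefixes for
# these keys; deduplicating first leaves only the 2 distinct ones.
```

The batch-size assignment would then run over `prefixes` instead of inside the per-key loop.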

Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
@vmoens added the bug label Jun 30, 2023
@vmoens merged commit 12cbe72 into pytorch:main Jun 30, 2023
Contributor

@vmoens left a comment

LGTM

@matteobettini deleted the step_mdp_nested branch July 3, 2023 07:47
Labels
bug, CLA Signed
3 participants