Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Stack transform #2567

Merged
merged 1 commit into from
Dec 4, 2024
Merged

Conversation

kurtamohler
Copy link
Collaborator

Description

Adds a transform that stacks tensors and specs from different keys of a tensordict into a common key.

Motivation and Context

close #2566

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@kurtamohler kurtamohler requested a review from vmoens November 14, 2024 21:13
Copy link

pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2567

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 9 Unrelated Failures

As of commit 8e044a1 with merge base 1cffffe (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 14, 2024
@kurtamohler
Copy link
Collaborator Author

Looks like there is a minor bug if I try to use this on UnityMLAgentsEnv and then do a rollout. I'll fix that and add a test

@vmoens vmoens added the enhancement New feature or request label Nov 15, 2024
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this, long awaited feature!
Just left a couple of comments on the default dim and test set

torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
test/test_transforms.py Show resolved Hide resolved
@kurtamohler kurtamohler force-pushed the Stack-Transform-0 branch 3 times, most recently from 23f7e1b to f443812 Compare November 19, 2024 05:37
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks this is superb
I'd like to discuss the inverse transform:
Would it make sense in the inverse to get an entry (from the input_spec) and unbind it?
Like: you have a single action with leading dim of 2, and map it to ("agent0", "action"), ("agent1", "action"). The spec seen from the outside is the stack of the 2 specs (as it is for stuff processed in forward).
Would that make sense?

torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
@kurtamohler
Copy link
Collaborator Author

kurtamohler commented Nov 20, 2024

Would it make sense in the inverse to get an entry (from the input_spec) and unbind it?
Like: you have a single action with leading dim of 2, and map it to ("agent0", "action"), ("agent1", "action"). The spec seen from the outside is the stack of the 2 specs (as it is for stuff processed in forward).
Would that make sense?

Yes, I think that does make sense, and that is exactly what happens. For instance, if I add a line to print the output of Stack._inv_call at the end of the function and then run the following script:

Click to expand
from torchrl.envs import Stack, TransformedEnv
from torchrl.envs import UnityMLAgentsEnv

base_env = UnityMLAgentsEnv(registered_name='3DBall')

try:
    t = Stack(
        in_keys=[("group_0", f"agent_{idx}") for idx in range(12)],
        out_key="agents",
    )   
    env = TransformedEnv(base_env, t)
    action = env.full_action_spec.rand()
    print("-------------------------")
    print(action)
    print("-------------------------")
    env.step(action)

finally:
    base_env.close()

I get the following output:

Click to expand
-------------------------
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                continuous_action: Tensor(shape=torch.Size([12, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([12]),
            device=None,
            is_shared=False),
        group_0: TensorDict(
            fields={
            },
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
-------------------------
TensorDict(
    fields={
        group_0: TensorDict(
            fields={
                agent_0: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_10: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_11: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_1: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_2: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_3: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_4: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_5: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_6: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_7: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_8: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_9: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Which shows that the inverse of Stack is able to correctly unbind the stacked actions into the format that the unity env expects. BTW, this functionality is being tested in the environment unit test I added.

But one thing from the above example that I'd like to fix is that env.full_action_spec still contains the "group_0" key, with an empty Composite spec, which is left over from exclude-ing all the keys under that group. I'm not sure what's the most efficient way to prune empty specs/tensordicts like those.

EDIT: I found a solution that I think is good--during __init__, build up a list of all the parent keys of in_keys. In the above example, that list would just be [("group_0",)]. Then in _transform_spec and _call, after deleting the in_keys from a spec/tensordict, check if we need to delete any of the parent keys as well. I'll push an update when I have it working and tested.

@kurtamohler kurtamohler added the Environments Adds or modifies an environment wrapper label Nov 20, 2024
@kurtamohler kurtamohler force-pushed the Stack-Transform-0 branch 2 times, most recently from a6e56be to cec32d2 Compare November 21, 2024 04:18
@vmoens
Copy link
Contributor

vmoens commented Nov 21, 2024

Ah got it, sorry I was expecting something like in_keys or such. Maybe let's make clear in the doc strings that in_keys can be part of the input (usually it's reserved to output keys)

Maybe this? If that's what you need you can mention it in the doc of the transform.

https://pytorch.org/rl/main/reference/generated/torchrl.envs.transforms.RemoveEmptySpecs.html#torchrl.envs.transforms.RemoveEmptySpecs

@kurtamohler
Copy link
Collaborator Author

Ah got it, sorry I was expecting something like in_keys or such. Maybe let's make clear in the doc strings that in_keys can be part of the input (usually it's reserved to output keys)

I'm sorry, I'm not sure what you mean by this. in_keys is always supposed to be part of the input, isn't it? Are you talking about the inverse operation here, or the leftover empty specs/tensordicts, or something else?

RemoveEmptySpecs would do what I want, but it's not very efficient since it has to check through the entire spec/tensordict. Wouldn't it be better if Stack has the responsibility of cleaning up any empty keys left over? I don't think the way I implemented it is overly complicated, and it should give significantly better performance than RemoveEmptySpecs, since it knows which keys to check ahead of time. But I suppose one could make the argument that in many cases, it isn't really necessary to remove the empty specs/tensordicts, and not removing them would be more efficient anyway

@vmoens
Copy link
Contributor

vmoens commented Nov 22, 2024

I'm sorry, I'm not sure what you mean by this. in_keys is always supposed to be part of the input, isn't it? Are you talking about the inverse operation here, or the leftover empty specs/tensordicts, or something else?

Yeah usually anything you iterate over during inverse pass is passed through the inv_keys and anything you process during forward is in the in_keys.
It's useful to keep them separated because it might as well be the case that there is a key in the input named as an output but you just want to process it one way and not the other.

@vmoens
Copy link
Contributor

vmoens commented Nov 22, 2024

I don't think the way I implemented it is overly complicated, and it should give significantly better performance than RemoveEmptySpecs, since it knows which keys to check ahead of time.

Ok that sounds good!

@kurtamohler
Copy link
Collaborator Author

Yeah usually anything you iterate over during inverse pass is passed through the inv_keys and anything you process during forward is in the in_keys.

What I think you're saying is that most of the existing transforms that have an inverse allow the user to explicitly set the inverse keys in __init__. But with Stack, the inverse keys are automatically assigned to in_key_inv = out_key and out_keys_inv = in_keys. So we should document that difference. Is that what you're saying?

I guess we also might as well allow the user to override the default in_key_inv and out_keys_inv.

@vmoens
Copy link
Contributor

vmoens commented Nov 23, 2024

Yep that's what I'm saying.
But also I'm a bit worried about this behaviour: in some cases, the observations are part of the input tensordict but they're not input to the env (not registered in the input spec). With the current state of things, they will be split before being fed to the base env when it's not necessary for them to be (it would only be the case if they were part of the input spec), inducing some overhead (with the unbind function).

So I'd rather ask users to indicate the ˋin_keys and ˋin_keys_inv separately.

@kurtamohler
Copy link
Collaborator Author

So I'd rather ask users to indicate the ˋin_keys and ˋin_keys_inv separately.

Got it, that makes sense. I'll make that change

@kurtamohler
Copy link
Collaborator Author

kurtamohler commented Nov 25, 2024

Actually, I guess there are still some decisions to make about the specific behavior of in_keys_inv which we should discuss. What do you think about this?

        in_keys_inv (NestedKey or sequence of NestedKey): keys to be unstacked
            during :meth:`~.inv`. To fully invert the output of
            :meth:`~forward`, set equal to ``out_key``. To only invert part of
            the :meth:`~forward` output, specify subkeys of ``out_key``. For
            instance, if ``out_key=("agents",)`` and ``in_keys_inv=[("agents",
            "action"), ("agents", "state")]``, then :meth:`~.inv` will only
            unstack the "action" and "state" keys under "agents", leaving
            everything else under "agents" unmodified.

We'll probably have to enforce that all of the keys specified in in_keys_inv start with the prefix out_key. For instance, if out_key=("agents",), then each of the in_keys_inv need to be ("agents",) + <some nested key>.

We can still infer the out_keys_inv automatically. But would there be a reason to allow the user to set out_keys_inv as well? If we did allow the user to set out_keys_inv, then looking at the following example:

in_keys=[("agent0",), ("agent1",), ...]
out_key=("agents",)
in_keys_inv=[("agents", "action"), ("agents", "state")]

out_keys_inv would need to be able to indicate a mapping where both the "action" and "state" gets split into multiple agents, like this:

("agents", "action") --> [("agent0", "action"), ("agent1", "action"), ...]
("agents", "state") --> [("agent0", "state"), ("agent1", "state"), ...]

To me, it seems that a natural representation of out_keys_inv would be a list of lists of nested keys:

out_keys_inv = [
    [("agent0", "action"), ("agent1", "action"), ...],
    [("agent0", "state"), ("agent1", "state"), ...]
]

where out_keys_inv[i] is the list of keys to split the key in_keys_inv[i] into. But (I think) none of the other transforms specify keys in the format of a list of lists of nested keys, so it would be out of the ordinary. Plus, maybe it's a bit more verbose than it really needs to be.

Instead, maybe out_keys_inv could just be:

out_keys_inv = [
    ("agent0",), ("agent1",), ...,
]

(Which is just equal to in_keys in this case)

But the docs would need to make it clear that out_keys_inv specifies the prefixes of the actual keys to which the inverse op assigns all the unstacked elements, and that the mapping between in_keys_inv and the output keys of the inverse operation really goes like:

in_keys_inv[i] --> [prefix + in_keys_inv[i][len(out_key) : ] for prefix in out_keys_inv ]

Where i iterates over in_keys_inv. Note that out_key here is the output key of the forward op.

What do you think? Is there a better decision for the behavior of in_keys_inv? Can you think of a reason why users would want to specify out_keys_inv? If not, then I suppose we can leave it out for now until someone requests that feature

@vmoens
Copy link
Contributor

vmoens commented Nov 27, 2024

in_keys_inv=[("agents", "action"), ("agents", "state")]

@kurtamohler what about we simply restrict the transform to stack X keys into 1 during forward and split 1 entry into Y during inverse?

If users want to unbind multiple keys they should just have multiple transforms.

@kurtamohler
Copy link
Collaborator Author

Ok that sounds good to me!

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor comments but otherwise LGTM

Comment on lines 4329 to 4340
This transform is useful for environments that have multiple agents with
identical specs under different keys. The specs and tensordicts for the
agents can be stacked together under a shared key, in order to run MARL
algorithms that expect the tensors for observations, rewards, etc. to
contain batched data for all the agents.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather have a general statement about the transform first - some will use that just to stack tensors together, not in MARL settings (eg multiple images in one)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a note saying that multiple stacks will require multiple transforms?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, let me know if that's good

@vmoens vmoens merged commit 594462d into pytorch:main Dec 4, 2024
63 of 76 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request Environments Adds or modifies an environment wrapper
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Transform that stacks data for agents with identical specs
3 participants