Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Feature] Batched actions wrapper #2018

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
amend
  • Loading branch information
vmoens committed Mar 19, 2024
commit 504a55f3812a4f7e75f154398ff32165cb0bc468
4 changes: 4 additions & 0 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,10 @@ def test_batched_actor_exceptions(self):
with pytest.raises(RuntimeError, match="Cannot initialize the wrapper"):
env.rollout(10, actor, tensordict=td, auto_reset=False)

actor = BatchedActionWrapper(actor_base, n_steps=time_steps - 1)
with pytest.raises(RuntimeError, match="The action's time dimension"):
env.rollout(10, actor)

@pytest.mark.parametrize("time_steps", [3, 5])
def test_batched_actor_simple(self, time_steps):

Expand Down
19 changes: 18 additions & 1 deletion torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,12 +2127,24 @@ class BatchedActionWrapper(TensorDictModuleBase):
because a "done" state has been encountered. Unlike ``action_keys``,
this key must be unique.

Args:
actor (TensorDictModuleBase): An actor.
n_steps (int): the number of actions the actor outputs at once
(lookahead window).

Keyword Args:
action_keys (list of NestedKeys, optional): the action keys from
the environment. Can be retrieved from ``env.action_keys``.
Defaults to all ``out_keys`` of the ``actor`` which end
with the ``"action"`` string.
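init_key (NestedKey, optional): the key of the init entry read by
the wrapper. Unlike ``action_keys``, this key must be unique.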
"""

def __init__(
self,
actor: TensorDictModule,
actor: TensorDictModuleBase,
n_steps: int,
*,
action_keys: List[NestedKey] | None = None,
init_key: List[NestedKey] | None = None,
):
Expand Down Expand Up @@ -2210,6 +2222,11 @@ def forward(
action_entry = parent_td.get(action_key_orig[-1], None)
if action_entry is None:
raise self._NO_INIT_ERR
if action_entry.shape[parent_td.ndim] != self.n_steps:
raise RuntimeError(
f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). "
f"The action shape was {action_entry.shape}."
)
base_idx = (
slice(
None,
Expand Down
Loading