Skip to content

Commit

Permalink
[Feature] Batched actions wrapper (#2018)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 19, 2024
1 parent f6fbc44 commit 76b296d
Show file tree
Hide file tree
Showing 9 changed files with 435 additions and 7 deletions.
5 changes: 3 additions & 2 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ The main features are:
and/or return the distribution of interest;
- Custom containers for Q-Value learning, model-based agents and others.

SafeModules
~~~~~~~~~~~
TensorDictModules and SafeModules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchRL :class:`~torchrl.modules.tensordict_module.SafeModule` allows you to
check the you model output matches what is to be expected for the environment.
Expand All @@ -52,6 +52,7 @@ projected (in a L1-manner) into the desired domain.
:template: rl_template_noinherit.rst

Actor
MultiStepActorWrapper
SafeModule
SafeSequential
TanhModule
Expand Down
62 changes: 62 additions & 0 deletions examples/agents/multi-step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Example of a dummy multi-step agent.
A multi-step actor predicts a macro (or an action sequence) and executes it regardless of the observations
coming in the meantime.
The core component of this example is the `MultiStepActorWrapper` class.
`MultiStepActorWrapper` handles the calls to the actor when the macro has run out of actions or
when the environment has been reset (which is indicated by the InitTracker transform).
"""

import torch.nn
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.envs import (
CatFrames,
Compose,
GymEnv,
InitTracker,
SerialEnv,
TransformedEnv,
)
from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper

time_steps = 6
n_obs = 4
n_action = 2
batch = 5


# Transforms a CatFrames in a stack of frames
def reshape_cat(data: torch.Tensor):
return data.unflatten(-1, (time_steps, n_obs))


# an actor that reads `time_steps` frames and outputs one action per frame
# (actions are conditioned on the observation of `time_steps` in the past)
actor_base = Seq(
Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
Mod(
torch.nn.Linear(n_obs, n_action),
in_keys=["obs_cat_reshape"],
out_keys=["action"],
),
)
# Wrap the actor to dispatch the actions
actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)

env = TransformedEnv(
SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
Compose(
InitTracker(),
CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1),
),
)

print(env.rollout(100, policy=actor, break_when_any_done=False))
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ per-file-ignores =
torchrl/objectives/value/advantages.py: TOR101
tutorials/*/**.py: T201
build_tools/setup_helpers/extension.py: T201
examples/torchrl_features/*.py: T201
examples/*.py: T201
test/opengl_rendering.py: T201
*/**/run-clang-format.py: T201

Expand Down
110 changes: 108 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest
import torch
from mocking_classes import DiscreteActionVecMockEnv
from mocking_classes import CountingEnv, DiscreteActionVecMockEnv
from tensordict import pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from torch import nn
Expand All @@ -16,7 +16,14 @@
CompositeSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs import (
CatFrames,
Compose,
EnvCreator,
InitTracker,
SerialEnv,
TransformedEnv,
)
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import (
AdditiveGaussianWrapper,
Expand All @@ -25,6 +32,7 @@
GRUModule,
LSTMModule,
MLP,
MultiStepActorWrapper,
NormalParamWrapper,
OnlineDTActor,
ProbabilisticActor,
Expand Down Expand Up @@ -1443,6 +1451,104 @@ def test_dt_inference_wrapper(self, online):
) - set(inference_actor.in_keys)


class TestBatchedActor:
def test_batched_actor_exceptions(self):
time_steps = 5
actor_base = TensorDictModule(
lambda x: torch.ones(
x.shape[0], time_steps, 1, device=x.device, dtype=x.dtype
),
in_keys=["observation_cat"],
out_keys=["action"],
)
with pytest.raises(ValueError, match="Only a single init_key can be passed"):
MultiStepActorWrapper(actor_base, n_steps=time_steps, init_key=["init_key"])

n_obs = 1
n_action = 1
batch = 2

# The second env has frequent resets, the first none
base_env = SerialEnv(
batch,
[lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)],
)
env = TransformedEnv(
base_env,
CatFrames(
N=time_steps,
in_keys=["observation"],
out_keys=["observation_cat"],
dim=-1,
),
)
actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
with pytest.raises(KeyError, match="No init key was passed"):
env.rollout(2, actor)

env = TransformedEnv(
base_env,
Compose(
InitTracker(),
CatFrames(
N=time_steps,
in_keys=["observation"],
out_keys=["observation_cat"],
dim=-1,
),
),
)
td = env.rollout(10)[..., -1]["next"]
actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
with pytest.raises(RuntimeError, match="Cannot initialize the wrapper"):
env.rollout(10, actor, tensordict=td, auto_reset=False)

actor = MultiStepActorWrapper(actor_base, n_steps=time_steps - 1)
with pytest.raises(RuntimeError, match="The action's time dimension"):
env.rollout(10, actor)

@pytest.mark.parametrize("time_steps", [3, 5])
def test_batched_actor_simple(self, time_steps):

batch = 2

# The second env has frequent resets, the first none
base_env = SerialEnv(
batch,
[lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)],
)
env = TransformedEnv(
base_env,
Compose(
InitTracker(),
CatFrames(
N=time_steps,
in_keys=["observation"],
out_keys=["observation_cat"],
dim=-1,
),
),
)

actor_base = TensorDictModule(
lambda x: torch.ones(
x.shape[0], time_steps, 1, device=x.device, dtype=x.dtype
),
in_keys=["observation_cat"],
out_keys=["action"],
)
actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
# rollout = env.rollout(100, break_when_any_done=False)
rollout = env.rollout(50, actor, break_when_any_done=False)
unique = rollout[0]["observation"].unique()
predicted = torch.arange(unique.numel())
assert (unique == predicted).all()
assert (
rollout[1]["observation"]
== (torch.arange(50) % 6).reshape_as(rollout[1]["observation"])
).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
LSTM,
LSTMCell,
LSTMModule,
MultiStepActorWrapper,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DistributionalQValueHook,
DistributionalQValueModule,
LMHeadActorValueOperator,
MultiStepActorWrapper,
ProbabilisticActor,
QValueActor,
QValueHook,
Expand Down
Loading

1 comment on commit 76b296d

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 76b296d Previous: f6fbc44 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 332.38019103358624 iter/sec (stddev: 0.011652557506936787) 755.4771793311247 iter/sec (stddev: 0.00007278786761598838) 2.27

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.