[Feature] Make advantages compatible with Terminated, Truncated, Done (#1581)

Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
3 people authored Oct 2, 2023
1 parent 3785609 commit 106368f
Showing 22 changed files with 1,203 additions and 337 deletions.
11 changes: 4 additions & 7 deletions examples/multiagent/iql.py
@@ -21,6 +21,7 @@
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training
from utils.utils import DoneTransform


def rendering_callback(env, td):
@@ -111,6 +112,7 @@ def train(cfg: "DictConfig"): # noqa: F821
storing_device=cfg.train.device,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
)

replay_buffer = TensorDictReplayBuffer(
@@ -125,6 +127,8 @@ def train(cfg: "DictConfig"): # noqa: F821
action=env.action_key,
value=("agents", "chosen_action_value"),
reward=env.reward_key,
done=("agents", "done"),
terminated=("agents", "terminated"),
)
loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
@@ -144,13 +148,6 @@ def train(cfg: "DictConfig"): # noqa: F821

sampling_time = time.time() - sampling_start

tensordict_data.set(
("next", "done"),
tensordict_data.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict_data.get(("next", env.reward_key)).shape),
) # We need to expand the done to match the reward shape

current_frames = tensordict_data.numel()
total_frames += current_frames
data_view = tensordict_data.reshape(-1)
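The three TD0-based trainers (iql.py above, maddpg_iddpg.py and sac.py below) all receive the same edit: the hand-rolled expansion of ("next", "done") inside the training loop is removed, the DoneTransform postproc on the collector writes per-agent done and terminated entries instead, and the loss is pointed at those nested keys. The practical effect on the TD0 target is that bootstrapping is suppressed only on true terminations, while a truncated episode still bootstraps from the next state value. A minimal sketch of that target semantics with per-agent tensors of shape [batch, n_agents, 1]; this is an illustration, not TorchRL's implementation:

```python
import torch


def td0_target(reward, next_value, terminated, gamma=0.99):
    # Bootstrap from V(s') unless the episode truly terminated;
    # a truncated episode (done but not terminated) keeps the bootstrap term.
    return reward + gamma * next_value * (~terminated).float()


reward = torch.randn(8, 3, 1)
next_value = torch.randn(8, 3, 1)
terminated = torch.zeros(8, 3, 1, dtype=torch.bool)
print(td0_target(reward, next_value, terminated).shape)  # torch.Size([8, 3, 1])
```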
11 changes: 4 additions & 7 deletions examples/multiagent/maddpg_iddpg.py
@@ -26,6 +26,7 @@
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training
from utils.utils import DoneTransform


def rendering_callback(env, td):
@@ -133,6 +134,7 @@ def train(cfg: "DictConfig"): # noqa: F821
storing_device=cfg.train.device,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
)

replay_buffer = TensorDictReplayBuffer(
@@ -147,6 +149,8 @@ def train(cfg: "DictConfig"): # noqa: F821
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
reward=env.reward_key,
done=("agents", "done"),
terminated=("agents", "terminated"),
)
loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
@@ -170,13 +174,6 @@ def train(cfg: "DictConfig"): # noqa: F821

sampling_time = time.time() - sampling_start

tensordict_data.set(
("next", "done"),
tensordict_data.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict_data.get(("next", env.reward_key)).shape),
) # We need to expand the done to match the reward shape

current_frames = tensordict_data.numel()
total_frames += current_frames
data_view = tensordict_data.reshape(-1)
16 changes: 8 additions & 8 deletions examples/multiagent/mappo_ippo.py
@@ -22,6 +22,7 @@
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training
from utils.utils import DoneTransform


def rendering_callback(env, td):
@@ -126,6 +127,7 @@ def train(cfg: "DictConfig"): # noqa: F821
storing_device=cfg.train.device,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
)

replay_buffer = TensorDictReplayBuffer(
@@ -142,7 +144,12 @@ def train(cfg: "DictConfig"): # noqa: F821
entropy_coef=cfg.loss.entropy_eps,
normalize_advantage=False,
)
loss_module.set_keys(reward=env.reward_key, action=env.action_key)
loss_module.set_keys(
reward=env.reward_key,
action=env.action_key,
done=("agents", "done"),
terminated=("agents", "terminated"),
)
loss_module.make_value_estimator(
ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda
)
@@ -165,13 +172,6 @@ def train(cfg: "DictConfig"): # noqa: F821

sampling_time = time.time() - sampling_start

tensordict_data.set(
("next", "done"),
tensordict_data.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict_data.get(("next", env.reward_key)).shape),
) # We need to expand the done to match the reward shape

with torch.no_grad():
loss_module.value_estimator(
tensordict_data,
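mappo_ippo.py is the example that builds a GAE value estimator, which is where separating terminated from done matters most for advantages: a true termination should zero out the bootstrap from V(s_{t+1}), while any end of trajectory (terminated or truncated) should stop the advantage trace from leaking across episode boundaries. The reverse scan below sketches those semantics on plain tensors of shape [T, n_agents, 1]; it is a hand-written illustration of how the two flags enter the recursion, not TorchRL's vectorized GAE:

```python
import torch


def gae_sketch(reward, value, next_value, terminated, done, gamma=0.9, lmbda=0.9):
    advantage = torch.zeros_like(reward)
    running = torch.zeros_like(reward[0])
    for t in reversed(range(reward.shape[0])):
        not_terminated = (~terminated[t]).float()  # gates the bootstrap from V(s')
        not_done = (~done[t]).float()              # cuts the trace at episode ends
        delta = reward[t] + gamma * next_value[t] * not_terminated - value[t]
        running = delta + gamma * lmbda * not_done * running
        advantage[t] = running
    return advantage
```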
13 changes: 6 additions & 7 deletions examples/multiagent/sac.py
@@ -23,6 +23,7 @@
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training
from utils.utils import DoneTransform


def rendering_callback(env, td):
@@ -179,6 +180,7 @@ def train(cfg: "DictConfig"): # noqa: F821
storing_device=cfg.train.device,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
)

replay_buffer = TensorDictReplayBuffer(
@@ -198,6 +200,8 @@ def train(cfg: "DictConfig"): # noqa: F821
state_action_value=("agents", "state_action_value"),
action=env.action_key,
reward=env.reward_key,
done=("agents", "done"),
terminated=("agents", "terminated"),
)
else:
loss_module = DiscreteSACLoss(
@@ -211,6 +215,8 @@ def train(cfg: "DictConfig"): # noqa: F821
action_value=("agents", "action_value"),
action=env.action_key,
reward=env.reward_key,
done=("agents", "done"),
terminated=("agents", "terminated"),
)

loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
@@ -235,13 +241,6 @@ def train(cfg: "DictConfig"): # noqa: F821

sampling_time = time.time() - sampling_start

tensordict_data.set(
("next", "done"),
tensordict_data.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict_data.get(("next", env.reward_key)).shape),
) # We need to expand the done to match the reward shape

current_frames = tensordict_data.numel()
total_frames += current_frames
data_view = tensordict_data.reshape(-1)
41 changes: 41 additions & 0 deletions examples/multiagent/utils/utils.py
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from tensordict import unravel_key
from torchrl.envs import Transform


def swap_last(source, dest):
source = unravel_key(source)
dest = unravel_key(dest)
if isinstance(source, str):
if isinstance(dest, str):
return dest
return dest[-1]
if isinstance(dest, str):
return source[:-1] + (dest,)
return source[:-1] + (dest[-1],)


class DoneTransform(Transform):
"""Expands the 'done' entries (incl. terminated) to match the reward shape.
Can be appended to a replay buffer or a collector.
"""

def __init__(self, reward_key, done_keys):
super().__init__()
self.reward_key = reward_key
self.done_keys = done_keys

def forward(self, tensordict):
for done_key in self.done_keys:
new_name = swap_last(self.reward_key, done_key)
tensordict.set(
("next", new_name),
tensordict.get(("next", done_key))
.unsqueeze(-1)
.expand(tensordict.get(("next", self.reward_key)).shape),
)
return tensordict
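For reference, a small check of what the new helper does; the import path and the toy shapes below are illustrative assumptions, not part of the commit. swap_last replaces the last element of the reward key with the last element of the done key, so a root-level "done" lands under ("agents", "done") and is expanded to the per-agent reward shape — the same expansion the training loops above used to do inline:

```python
import torch
from tensordict import TensorDict

# Assumes the new module is importable as in the examples above.
from utils.utils import DoneTransform, swap_last

# A toy batch of 5 transitions with 3 agents: per-agent reward,
# root-level done/terminated with a trailing singleton dimension.
td = TensorDict(
    {
        "next": {
            "agents": {"reward": torch.zeros(5, 3, 1)},
            "done": torch.zeros(5, 1, dtype=torch.bool),
            "terminated": torch.zeros(5, 1, dtype=torch.bool),
        }
    },
    batch_size=[5],
)

assert swap_last(("agents", "reward"), "done") == ("agents", "done")

t = DoneTransform(reward_key=("agents", "reward"), done_keys=["done", "terminated"])
td = t.forward(td)
print(td.get(("next", "agents", "done")).shape)        # torch.Size([5, 3, 1])
print(td.get(("next", "agents", "terminated")).shape)  # torch.Size([5, 3, 1])
```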