Skip to content

Commit

Permalink
[BugFix] Multiagent "auto" entropy fix in SAC (pytorch#1494)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
matteobettini and vmoens committed Oct 10, 2023
1 parent e28ef5c commit 3cc870a
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 11 deletions.
5 changes: 4 additions & 1 deletion examples/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ def train(cfg: "DictConfig"): # noqa: F821

if cfg.env.continuous_actions:
loss_module = SACLoss(
actor_network=policy, qvalue_network=value_module, delay_qvalue=True
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
action_spec=env.unbatched_action_spec,
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
Expand Down
13 changes: 12 additions & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,19 @@ def target_entropy(self):
)
if not isinstance(action_spec, CompositeSpec):
action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
target_entropy = -float(
np.prod(action_spec[self.tensor_keys.action].shape)
action_spec[self.tensor_keys.action]
.shape[len(action_container_shape) :]
.numel()
)
self.register_buffer(
"target_entropy_buffer", torch.tensor(target_entropy, device=device)
Expand Down
17 changes: 14 additions & 3 deletions torchrl/objectives/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from dataclasses import dataclass
from typing import Union

import numpy as np

import torch
from tensordict.nn import dispatch
from tensordict.tensordict import TensorDict, TensorDictBase
Expand Down Expand Up @@ -127,7 +125,20 @@ def __init__(
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
target_entropy = -float(np.prod(actor_network.spec["action"].shape))
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = actor_network.spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = actor_network.spec.shape
target_entropy = -float(
actor_network.spec[self.tensor_keys.action]
.shape[len(action_container_shape) :]
.numel()
)
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
Expand Down
14 changes: 12 additions & 2 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from numbers import Number
from typing import Union

import numpy as np
import torch

from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential
Expand Down Expand Up @@ -352,8 +351,19 @@ def target_entropy(self):
)
if not isinstance(action_spec, CompositeSpec):
action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
target_entropy = -float(
np.prod(action_spec[self.tensor_keys.action].shape)
action_spec[self.tensor_keys.action]
.shape[len(action_container_shape) :]
.numel()
)
self.register_buffer(
"target_entropy_buffer", torch.tensor(target_entropy, device=device)
Expand Down
16 changes: 13 additions & 3 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

import numpy as np
import torch

from tensordict.nn import dispatch, make_functional, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import Tensor

from torchrl.data import CompositeSpec, TensorSpec
from torchrl.data.utils import _find_action_space
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.modules import ProbabilisticActor
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
Expand Down Expand Up @@ -401,8 +400,19 @@ def target_entropy(self):
)
if not isinstance(action_spec, CompositeSpec):
action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
target_entropy = -float(
np.prod(action_spec[self.tensor_keys.action].shape)
action_spec[self.tensor_keys.action].shape[
len(action_container_shape) :
].numel()
)
self.register_buffer(
"target_entropy_buffer", torch.tensor(target_entropy, device=device)
Expand Down
13 changes: 12 additions & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,18 @@ def __init__(
)
elif action_spec is not None:
if isinstance(action_spec, CompositeSpec):
action_spec = action_spec[self.tensor_keys.action]
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
action_spec = action_spec[self.tensor_keys.action][
(0,) * len(action_container_shape)
]
if not isinstance(action_spec, BoundedTensorSpec):
raise ValueError(
f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."
Expand Down

0 comments on commit 3cc870a

Please sign in to comment.