From 4c50f1ec04a52b01ac1acb53519b2c566a97d8a7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:46:22 +0100 Subject: [PATCH] [BugFix] Multiagent "auto" entropy fix in SAC (#1494) Signed-off-by: Matteo Bettini Co-authored-by: Vincent Moens --- examples/multiagent/sac.py | 5 ++++- torchrl/objectives/cql.py | 13 ++++++++++++- torchrl/objectives/decision_transformer.py | 17 ++++++++++++++--- torchrl/objectives/redq.py | 14 ++++++++++++-- torchrl/objectives/sac.py | 16 +++++++++++++--- torchrl/objectives/td3.py | 13 ++++++++++++- 6 files changed, 67 insertions(+), 11 deletions(-) diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index 72e45b05a1e..e9aea20e282 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -189,7 +189,10 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.env.continuous_actions: loss_module = SACLoss( - actor_network=policy, qvalue_network=value_module, delay_qvalue=True + actor_network=policy, + qvalue_network=value_module, + delay_qvalue=True, + action_spec=env.unbatched_action_spec, ) loss_module.set_keys( state_action_value=("agents", "state_action_value"), diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 06e8de28bf7..b24d4498106 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -366,8 +366,19 @@ def target_entropy(self): ) if not isinstance(action_spec, CompositeSpec): action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 768e05de9c6..24f6c184d7d 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -7,8 +7,6 @@ from dataclasses import dataclass from typing import Union -import numpy as np - import torch from tensordict.nn import dispatch from tensordict.tensordict import TensorDict, TensorDictBase @@ -127,7 +125,20 @@ def __init__( "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = actor_network.spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = actor_network.spec.shape + target_entropy = -float( + actor_network.spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() + ) self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 039b5c65b9d..afafcbfd446 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -8,7 +8,6 @@ from numbers import Number from typing import Union -import numpy as np import torch from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential @@ -352,8 +351,19 @@ def target_entropy(self): ) if not isinstance(action_spec, CompositeSpec): action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index a1ebf0e5873..a82795ab1bb 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -10,15 +10,14 @@ import numpy as np import torch + from tensordict.nn import dispatch, make_functional, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor - from torchrl.data import CompositeSpec, TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type - from torchrl.modules import ProbabilisticActor from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule @@ -401,8 +400,19 @@ def target_entropy(self): ) if not isinstance(action_spec, CompositeSpec): action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape target_entropy = -float( - np.prod(action_spec[self.tensor_keys.action].shape) + action_spec[self.tensor_keys.action].shape[ + len(action_container_shape) : + ].numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index cebdeb2e669..68d63fbaa47 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -270,7 +270,18 @@ def __init__( ) elif action_spec is not None: if isinstance(action_spec, CompositeSpec): - action_spec = action_spec[self.tensor_keys.action] + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape + action_spec = action_spec[self.tensor_keys.action][ + (0,) * len(action_container_shape) + ] if not isinstance(action_spec, BoundedTensorSpec): raise ValueError( f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."