From 23bf3151a5bb1377d6ba65c5edb5659618af0da3 Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Wed, 21 Feb 2024 15:03:51 +0000
Subject: [PATCH] [BugFix] Fix multiple context syntax in multiagent examples
 (#1943)

---
 examples/multiagent/iql.py          | 2 +-
 examples/multiagent/maddpg_iddpg.py | 2 +-
 examples/multiagent/mappo_ippo.py   | 3 ++-
 examples/multiagent/qmix_vdn.py     | 2 +-
 examples/multiagent/sac.py          | 2 +-
 5 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py
index 1408e47e915..bb374b99941 100644
--- a/examples/multiagent/iql.py
+++ b/examples/multiagent/iql.py
@@ -206,7 +206,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             and cfg.logger.backend
         ):
             evaluation_start = time.time()
-            with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 env_test.frames = []
                 rollouts = env_test.rollout(
                     max_steps=cfg.env.max_steps,
diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py
index d4ed03ad3c6..bed6240d244 100644
--- a/examples/multiagent/maddpg_iddpg.py
+++ b/examples/multiagent/maddpg_iddpg.py
@@ -230,7 +230,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             and cfg.logger.backend
         ):
             evaluation_start = time.time()
-            with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 env_test.frames = []
                 rollouts = env_test.rollout(
                     max_steps=cfg.env.max_steps,
diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py
index 8f4a2356c35..908d0bdc106 100644
--- a/examples/multiagent/mappo_ippo.py
+++ b/examples/multiagent/mappo_ippo.py
@@ -17,6 +17,7 @@
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.envs import RewardSum, TransformedEnv
 from torchrl.envs.libs.vmas import VmasEnv
+from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import ClipPPOLoss, ValueEstimators
@@ -235,7 +236,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             and cfg.logger.backend
         ):
             evaluation_start = time.time()
-            with torch.no_grad():
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 env_test.frames = []
                 rollouts = env_test.rollout(
                     max_steps=cfg.env.max_steps,
diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py
index e814ce8f79f..008e01b28b9 100644
--- a/examples/multiagent/qmix_vdn.py
+++ b/examples/multiagent/qmix_vdn.py
@@ -241,7 +241,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             and cfg.logger.backend
         ):
             evaluation_start = time.time()
-            with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 env_test.frames = []
                 rollouts = env_test.rollout(
                     max_steps=cfg.env.max_steps,
diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py
index 28317dba728..d76ddd1f913 100644
--- a/examples/multiagent/sac.py
+++ b/examples/multiagent/sac.py
@@ -300,7 +300,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             and cfg.logger.backend
         ):
             evaluation_start = time.time()
-            with torch.no_grad() and set_exploration_type(ExplorationType.MODE):
+            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 env_test.frames = []
                 rollouts = env_test.rollout(
                     max_steps=cfg.env.max_steps,
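
Background on the fix: `with a() and b():` does not stack two context
managers. Python first evaluates `a() and b()` as an ordinary expression,
and since a `torch.no_grad()` instance is truthy, `and` yields only the
`set_exploration_type(...)` manager. As a result, `no_grad` was never
entered and gradients were still tracked during the evaluation rollouts.
The comma form `with a(), b():` enters both managers. The patch also
standardizes evaluation on ExplorationType.MODE and adds the import that
mappo_ippo.py was missing.

A minimal, self-contained sketch of the pitfall (placeholder names `ctx`
and `entered`, not code from this patch):

    from contextlib import contextmanager

    entered = []

    @contextmanager
    def ctx(name):
        # Record which contexts are currently active.
        entered.append(name)
        try:
            yield
        finally:
            entered.remove(name)

    # Buggy form: `and` returns its second (truthy) operand, so only
    # ctx("b") is entered; ctx("a") is created but never entered.
    with ctx("a") and ctx("b"):
        print(entered)  # ['b']

    # Fixed form: the comma stacks both context managers.
    with ctx("a"), ctx("b"):
        print(entered)  # ['a', 'b']

The same reasoning applies to the examples above: the comma syntax (or,
equivalently, nested `with` blocks) is the only way to activate both
`torch.no_grad()` and the exploration-type context at once.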