From a41da21d8ab7bfb6a65f7ccb98f9cdd5771bedbe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 23:30:10 +0100 Subject: [PATCH] [BugFix] Fix MARL-DDPG tutorial and other MODE usages (#2373) --- README.md | 2 +- docs/source/reference/modules.rst | 2 +- docs/source/reference/objectives.rst | 2 +- sota-implementations/crossq/crossq.py | 2 +- sota-implementations/td3_bc/td3_bc.py | 2 +- test/test_exploration.py | 7 +++++-- test/test_tensordictmodules.py | 2 +- torchrl/modules/__init__.py | 1 + torchrl/objectives/__init__.py | 2 -- tutorials/sphinx-tutorials/coding_dqn.py | 2 +- tutorials/sphinx-tutorials/dqn_with_rnn.py | 2 +- tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py | 4 ++-- tutorials/sphinx-tutorials/torchrl_demo.py | 2 +- 13 files changed, 17 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index f82a8ff0c4c..9b812a21aa0 100644 --- a/README.md +++ b/README.md @@ -478,7 +478,7 @@ And it is `functorch` and `torch.compile` compatible! policy_explore = EGreedyWrapper(policy) with set_exploration_type(ExplorationType.RANDOM): tensordict = policy_explore(tensordict) # will use eps-greedy - with set_exploration_type(ExplorationType.MODE): + with set_exploration_type(ExplorationType.DETERMINISTIC): tensordict = policy_explore(tensordict) # will not use eps-greedy ``` diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index c73ed5083fd..5b05fc32194 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -319,7 +319,7 @@ Regular modules Conv3dNet SqueezeLayer Squeeze2dLayer - BatchRenorm + BatchRenorm1d Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index db0c58409e2..b3f8e242a9e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -157,7 +157,7 @@ CrossQ :toctree: generated/ :template: rl_template_noinherit.rst - CrossQ + CrossQLoss IQL ---- diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index df34d4ae68d..c5a1b88eea3 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -203,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 7c43fdc1a12..b3e8ed3b880 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/test/test_exploration.py b/test/test_exploration.py index 83ee4bc4220..b2fd97d986f 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -644,7 +644,7 @@ def test_no_spec_error(self, device): @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize( - "exploration_type", [InteractionType.RANDOM, InteractionType.MODE] + "exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC] ) def test_gsde( state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1 @@ -708,7 +708,10 @@ def test_gsde( with set_exploration_type(exploration_type): action1 = module(td).get("action") action2 = actor(td.exclude("action")).get("action") - if gSDE or exploration_type == InteractionType.MODE: + if gSDE or exploration_type in ( + InteractionType.DETERMINISTIC, + InteractionType.MODE, + ): torch.testing.assert_close(action1, action2) else: with pytest.raises(AssertionError): diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 38360a464e0..42e0880e6a4 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -189,7 +189,7 @@ def test_stateful(self, safe, spec_type, lazy): @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize( - "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None] + "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None] ) def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys): torch.manual_seed(0) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0a06e5844a0..c246b553e95 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -20,6 +20,7 @@ TruncatedNormal, ) from .models import ( + BatchRenorm1d, Conv3dNet, ConvNet, DdpgCnnActor, diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 60701cb0121..1ea9ebb5998 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -30,5 +30,3 @@ SoftUpdate, ValueEstimators, ) - -# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index e9f2085d3df..2da1967e5ad 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -672,7 +672,7 @@ def get_loss_module(actor, gamma): frame_skip=1, policy_exploration=actor_explore, environment=test_env, - exploration_type=ExplorationType.MODE, + exploration_type=ExplorationType.DETERMINISTIC, log_keys=[("next", "reward")], out_keys={("next", "reward"): "rewards"}, log_pbar=True, diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 28a9638c6f6..8931f483384 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -440,7 +440,7 @@ exploration_module.step(data.numel()) updater.step() - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): rollout = env.rollout(10000, stoch_policy) traj_lens.append(rollout.get(("next", "step_count")).max().item()) diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index 77574b765e7..fc1a22d50cf 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: target_updaters[group].step() # Exploration sigma anneal update - exploration_policies[group].step(current_frames) + exploration_policies[group][-1].step(current_frames) # Stop training a certain group when a condition is met (e.g., number of training iterations) if iteration == iteration_when_stop_training_evaders: @@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: env_with_render = env_with_render.append_transform( VideoRecorder(logger=video_logger, tag="vmas_rendered") ) - with set_exploration_type(ExplorationType.MODE): + with set_exploration_type(ExplorationType.DETERMINISTIC): print("Rendering rollout...") env_with_render.rollout(100, policy=agents_exploration_policy) print("Saving the video...") diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 9d25da0a4cd..29192d1c10e 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -652,7 +652,7 @@ def exec_sequence(params, data): td_module(td) print("random:", td["action"]) -with set_exploration_type(ExplorationType.MODE): +with set_exploration_type(ExplorationType.DETERMINISTIC): td_module(td) print("mode:", td["action"])