[BugFix] Fix MARL-DDPG tutorial and other MODE usages #2373

Merged (12 commits) on Aug 6, 2024
Changes from 1 commit

Commit: init
vmoens committed Aug 6, 2024
commit 85ee3fee3c9ef316284c777c4e6be59aef9016c5

README.md (2 changes: 1 addition & 1 deletion)

@@ -478,7 +478,7 @@ And it is `functorch` and `torch.compile` compatible!
 policy_explore = EGreedyWrapper(policy)
 with set_exploration_type(ExplorationType.RANDOM):
     tensordict = policy_explore(tensordict)  # will use eps-greedy
-with set_exploration_type(ExplorationType.MODE):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
     tensordict = policy_explore(tensordict)  # will not use eps-greedy
 ```
 </details>
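
For context, here is a minimal, self-contained sketch of the pattern the README snippet documents, using the renamed enum value. The toy Q-network, the one-hot action spec, and the input tensordict are illustrative additions, not part of this PR.

```python
import torch
from tensordict import TensorDict
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyWrapper, QValueActor

# Toy Q-network over a 4-dimensional observation with 2 discrete actions.
policy = QValueActor(
    torch.nn.Linear(4, 2),
    in_keys=["observation"],
    spec=OneHotDiscreteTensorSpec(2),
)
policy_explore = EGreedyWrapper(policy)

tensordict = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
with set_exploration_type(ExplorationType.RANDOM):
    policy_explore(tensordict)  # eps-greedy: random actions with probability eps
with set_exploration_type(ExplorationType.DETERMINISTIC):
    policy_explore(tensordict)  # greedy: argmax over the predicted action values
```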

docs/source/reference/modules.rst (2 changes: 1 addition & 1 deletion)

@@ -319,7 +319,7 @@ Regular modules
     Conv3dNet
     SqueezeLayer
     Squeeze2dLayer
-    BatchRenorm
+    BatchRenorm1d
 
 Algorithm-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/reference/objectives.rst (2 changes: 1 addition & 1 deletion)

@@ -157,7 +157,7 @@ CrossQ
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
-    CrossQ
+    CrossQLoss
 
 IQL
 ----
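
The docs entry now matches the importable name. As a quick sanity check (not part of the diff), the class is exposed under that name in the objectives package:

```python
# CrossQLoss is the exported class; plain "CrossQ" is only the docs section title.
from torchrl.objectives import CrossQLoss
```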

sota-implementations/crossq/crossq.py (2 changes: 1 addition & 1 deletion)

@@ -203,7 +203,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
                 eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,

sota-implementations/td3_bc/td3_bc.py (2 changes: 1 addition & 1 deletion)

@@ -128,7 +128,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         # evaluation
         if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                 )
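
Both SOTA scripts now share the same evaluation idiom: deterministic action selection combined with disabled gradient tracking for the eval rollout. A standalone sketch of that idiom follows; the env, policy, and step budget are placeholders rather than values taken from either script.

```python
import torch
from torchrl.envs.utils import ExplorationType, set_exploration_type


def evaluate(eval_env, policy, max_steps: int = 1000) -> float:
    """Roll the policy out greedily, without gradients, and return the summed reward."""
    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        rollout = eval_env.rollout(max_steps, policy, auto_cast_to_device=True)
    return rollout.get(("next", "reward")).sum().item()
```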

test/test_exploration.py (7 changes: 5 additions & 2 deletions)

@@ -644,7 +644,7 @@ def test_no_spec_error(self, device):
 @pytest.mark.parametrize("safe", [True, False])
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize(
-    "exploration_type", [InteractionType.RANDOM, InteractionType.MODE]
+    "exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC]
 )
 def test_gsde(
     state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1

@@ -708,7 +708,10 @@ def test_gsde(
     with set_exploration_type(exploration_type):
         action1 = module(td).get("action")
         action2 = actor(td.exclude("action")).get("action")
-        if gSDE or exploration_type == InteractionType.MODE:
+        if gSDE or exploration_type in (
+            InteractionType.DETERMINISTIC,
+            InteractionType.MODE,
+        ):
             torch.testing.assert_close(action1, action2)
         else:
             with pytest.raises(AssertionError):

test/test_tensordictmodules.py (2 changes: 1 addition & 1 deletion)

@@ -189,7 +189,7 @@ def test_stateful(self, safe, spec_type, lazy):
     @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
     @pytest.mark.parametrize("lazy", [True, False])
     @pytest.mark.parametrize(
-        "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None]
+        "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None]
     )
     def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys):
         torch.manual_seed(0)

torchrl/modules/__init__.py (1 change: 1 addition & 0 deletions)

@@ -20,6 +20,7 @@
     TruncatedNormal,
 )
 from .models import (
+    BatchRenorm1d,
     Conv3dNet,
     ConvNet,
     DdpgCnnActor,
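
A hypothetical usage sketch for the newly exported name, assuming BatchRenorm1d follows the nn.BatchNorm1d calling convention (number of features as the first positional argument):

```python
import torch
from torchrl.modules import BatchRenorm1d

layer = BatchRenorm1d(32)   # 32 feature channels
x = torch.randn(16, 32)     # (batch, features)
y = layer(x)
print(y.shape)              # torch.Size([16, 32])
```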

torchrl/objectives/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -29,5 +29,3 @@
     SoftUpdate,
     ValueEstimators,
 )
-
-# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate

tutorials/sphinx-tutorials/coding_dqn.py (2 changes: 1 addition & 1 deletion)

@@ -672,7 +672,7 @@ def get_loss_module(actor, gamma):
     frame_skip=1,
     policy_exploration=actor_explore,
     environment=test_env,
-    exploration_type=ExplorationType.MODE,
+    exploration_type=ExplorationType.DETERMINISTIC,
     log_keys=[("next", "reward")],
     out_keys={("next", "reward"): "rewards"},
     log_pbar=True,

tutorials/sphinx-tutorials/dqn_with_rnn.py (2 changes: 1 addition & 1 deletion)

@@ -440,7 +440,7 @@
     exploration_module.step(data.numel())
     updater.step()
 
-    with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
         rollout = env.rollout(10000, stoch_policy)
         traj_lens.append(rollout.get(("next", "step_count")).max().item())

tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py (4 changes: 2 additions & 2 deletions)

@@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
         target_updaters[group].step()
 
         # Exploration sigma anneal update
-        exploration_policies[group].step(current_frames)
+        exploration_policies[group][-1].step(current_frames)
 
         # Stop training a certain group when a condition is met (e.g., number of training iterations)
         if iteration == iteration_when_stop_training_evaders:

@@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
 env_with_render = env_with_render.append_transform(
     VideoRecorder(logger=video_logger, tag="vmas_rendered")
 )
-with set_exploration_type(ExplorationType.MODE):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
     print("Rendering rollout...")
     env_with_render.rollout(100, policy=agents_exploration_policy)
     print("Saving the video...")
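
The added `[-1]` index suggests that each group's exploration policy is a module sequence whose last element is the exploration (noise) module: the sigma-annealing `step()` lives on that module, not on the sequence itself. Below is a rough sketch of that structure, assuming a TensorDictSequential with an AdditiveGaussianModule as its last element; the actor and action spec are placeholders, not taken from the tutorial.

```python
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import BoundedTensorSpec
from torchrl.modules import AdditiveGaussianModule

# Placeholder actor and action spec (the tutorial builds these per agent group).
action_spec = BoundedTensorSpec(-1.0, 1.0, (2,))
policy = TensorDictModule(
    torch.nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
)

# The exploration policy is a sequence; the exploration module comes last.
exploration_policy = TensorDictSequential(
    policy,
    AdditiveGaussianModule(spec=action_spec, sigma_init=0.9, sigma_end=0.1),
)

# step() anneals sigma and is defined on the exploration module, i.e. the last element.
exploration_policy[-1].step(1000)
```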

tutorials/sphinx-tutorials/torchrl_demo.py (2 changes: 1 addition & 1 deletion)

@@ -652,7 +652,7 @@ def exec_sequence(params, data):
     td_module(td)
     print("random:", td["action"])
 
-with set_exploration_type(ExplorationType.MODE):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
     td_module(td)
     print("mode:", td["action"])