[Refactor] Deprecate interaction_mode (pytorch#1067)
vmoens authored Apr 18, 2023
1 parent cbb0c2f commit 3390db0
Showing 44 changed files with 278 additions and 209 deletions.
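Note: the commit applies three renames consistently across the files listed below: set_exploration_mode("random") / set_exploration_mode("mean") become set_exploration_type(ExplorationType.RANDOM) / set_exploration_type(ExplorationType.MEAN); default_interaction_mode="random" becomes default_interaction_type=InteractionType.RANDOM (with InteractionType imported from tensordict.nn); and rewards are read from the nested ("next", "reward") entry instead of a top-level "reward" key. A minimal sketch of the exploration-type switch follows; env and policy are placeholders for a user's own environment and policy, not objects defined in this diff.

    import torch
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    # `env` and `policy` are assumed to exist (placeholders for illustration).
    # Before this commit (now deprecated):
    #     with torch.no_grad(), set_exploration_mode("random"):
    #         env.rollout(1000, policy)

    # After this commit, exploration is selected through an enum:
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        env.rollout(1000, policy)  # stochastic actions, e.g. for data collection

    with torch.no_grad(), set_exploration_type(ExplorationType.MEAN):
        env.rollout(1000, policy)  # mean/deterministic actions, e.g. for evaluation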
2 changes: 1 addition & 1 deletion docs/source/reference/collectors.rst
@@ -54,7 +54,7 @@ Besides those compute parameters, users may choose to configure the following pa
- reset_at_each_iter: if :obj:`True`, the environment(s) will be reset after each batch collection
- split_trajs: if :obj:`True`, the trajectories will be split and delivered in a padded tensordict
along with a :obj:`"mask"` key that will point to a boolean mask representing the valid values.
- exploration_mode: the exploration strategy to be used with the policy.
- exploration_type: the exploration strategy to be used with the policy.
- reset_when_done: whether environments should be reset when reaching a done state.


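For context, a hedged sketch of how the renamed argument is passed to a collector. It assumes SyncDataCollector accepts an exploration_type keyword after this change (mirroring the former exploration_mode); make_env and policy stand in for a user's own environment factory and policy module.

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.utils import ExplorationType

    # make_env and policy are placeholders, not objects defined in this diff.
    collector = SyncDataCollector(
        make_env,
        policy,
        frames_per_batch=200,
        total_frames=10_000,
        exploration_type=ExplorationType.RANDOM,  # formerly exploration_mode="random"
        reset_at_each_iter=False,
        split_trajs=True,  # padded trajectories delivered with a "mask" key
    )
    for batch in collector:
        ...  # train on each collected batch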
4 changes: 2 additions & 2 deletions examples/a2c/a2c.py
@@ -9,7 +9,7 @@
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives.value import TD0Estimator
from torchrl.record.loggers import generate_exp_name, get_logger
from torchrl.trainers.helpers.collectors import (
@@ -95,7 +95,7 @@ def main(cfg: "DictConfig"): # noqa: F821

loss_module = make_a2c_loss(model, cfg)
if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = model(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
4 changes: 2 additions & 2 deletions examples/bandits/dqn.py
@@ -10,7 +10,7 @@
from torch import nn

from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor
from torchrl.objectives import DistributionalDQNLoss, DQNLoss

@@ -94,7 +94,7 @@
init_r = None
init_loss = None
for i in pbar:
with set_exploration_mode("random"):
with set_exploration_type(ExplorationType.RANDOM):
data = env.step(policy(env.reset()))
loss_vals = loss(data)
loss_val = sum(
4 changes: 2 additions & 2 deletions examples/ddpg/ddpg.py
@@ -10,7 +10,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -122,7 +122,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
15 changes: 8 additions & 7 deletions examples/discrete_sac/discrete_sac.py
@@ -9,6 +9,7 @@
import torch
import torch.cuda
import tqdm
from tensordict.nn import InteractionType

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
@@ -22,7 +23,7 @@
from torchrl.envs import EnvCreator, ParallelEnv

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, SafeModule
from torchrl.modules.distributions import OneHotCategorical

@@ -134,7 +135,7 @@ def env_factory(num_workers):
out_keys=["action"],
distribution_class=OneHotCategorical,
distribution_kwargs={},
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
return_log_prob=False,
).to(device)

@@ -224,7 +225,7 @@ def env_factory(num_workers):
new_collected_epochs = len(np.unique(tensordict["collector"]["traj_ids"]))
if r0 is None:
r0 = (
tensordict["reward"].sum().item()
tensordict["next", "reward"].sum().item()
/ new_collected_epochs
/ cfg.env_per_collector
)
@@ -284,7 +285,7 @@ def env_factory(num_workers):
rewards.append(
(
i,
tensordict["reward"].sum().item()
tensordict["next", "reward"].sum().item()
/ cfg.env_per_collector
/ new_collected_epochs,
)
@@ -307,16 +308,16 @@ def env_factory(num_workers):
}
)

with set_exploration_mode(
"random"
with set_exploration_type(
ExplorationType.RANDOM
), torch.no_grad(): # TODO: exploration mode to mean causes nans

eval_rollout = test_env.rollout(
max_steps=cfg.max_frames_per_traj,
policy=actor,
auto_cast_to_device=True,
).clone()
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
metrics.update({"test_reward": rewards_eval[-1][1]})
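The interaction-type side of the deprecation lives in tensordict.nn: the per-module default is now an InteractionType enum member rather than a string. The sketch below uses tensordict.nn.ProbabilisticTensorDictModule directly as a stand-in for the TorchRL actor classes edited above (which forward default_interaction_type to it); it assumes tensordict also exposes a set_interaction_type context manager for local overrides.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import (
        InteractionType,
        ProbabilisticTensorDictModule,
        set_interaction_type,
    )

    # A distribution head that samples by default, mirroring
    # default_interaction_type=InteractionType.RANDOM in the diffs above.
    dist_module = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        default_interaction_type=InteractionType.RANDOM,
    )

    td = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=[4])
    dist_module(td)  # writes a sampled "action" entry

    # The module default can be overridden locally, e.g. for evaluation:
    with set_interaction_type(InteractionType.MEAN):
        dist_module(td)  # writes the distribution mean instead of a sample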
8 changes: 5 additions & 3 deletions examples/distributed/collectors/multi_nodes/ray_train.py
@@ -196,7 +196,7 @@
optim.step()
optim.zero_grad()

logs["reward"].append(tensordict_data["reward"].mean().item())
logs["reward"].append(tensordict_data["next", "reward"].mean().item())
pbar.update(tensordict_data.numel() * frame_skip)
cum_reward_str = f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
logs["step_count"].append(tensordict_data["step_count"].max().item())
@@ -206,8 +206,10 @@
with set_exploration_mode("mean"), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["reward"].mean().item())
logs["eval reward (sum)"].append(eval_rollout["reward"].sum().item())
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
logs["eval reward (sum)"].append(
eval_rollout["next", "reward"].sum().item()
)
logs["eval step_count"].append(eval_rollout["step_count"].max().item())
eval_str = f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} (init: {logs['eval reward (sum)'][0]: 4.4f}), eval step-count: {logs['eval step_count'][-1]}"
del eval_rollout
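The reward-key change reflects TorchRL's stepping convention: the reward produced by a step is stored under the "next" sub-tensordict of the collected data, so the scripts now read ("next", "reward") instead of a top-level "reward" entry. A small sketch with synthetic data (not produced by any environment in this diff) showing the nested-key access and the per-trajectory aggregation used above:

    import torch
    from tensordict import TensorDict

    # Synthetic rollout shaped like collector output: [batch, time].
    rollout = TensorDict(
        {
            "action": torch.randn(8, 16, 1),
            "next": TensorDict(
                {"reward": torch.randn(8, 16, 1)},
                batch_size=[8, 16],
            ),
        },
        batch_size=[8, 16],
    )

    # Old access pattern (removed in this commit): rollout["reward"]
    mean_reward = rollout["next", "reward"].mean().item()
    # Cumulative reward per trajectory: sum over time, then average over the batch.
    avg_return = rollout["next", "reward"].sum(-2).mean().item()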
2 changes: 1 addition & 1 deletion examples/dreamer/dreamer.py
@@ -235,7 +235,7 @@ def main(cfg: "DictConfig"): # noqa: F821
replay_buffer.extend(tensordict.cpu())
logger.log_scalar(
"r_training",
tensordict["reward"].mean().detach().item(),
tensordict["next", "reward"].mean().detach().item(),
step=collected_frames,
)

14 changes: 8 additions & 6 deletions examples/iql/iql_online.py
@@ -20,7 +20,7 @@
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal

@@ -147,7 +147,7 @@ def env_factory(num_workers):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
default_interaction_type=ExplorationType.RANDOM,
return_log_prob=False,
)

@@ -247,7 +247,7 @@ def env_factory(num_workers):
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["reward"].sum(-1).mean().item()
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

if "mask" in tensordict.keys():
@@ -293,7 +293,9 @@ def env_factory(num_workers):
if cfg.prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
rewards.append(
(i, tensordict["next", "reward"].sum().item() / cfg.env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
@@ -309,13 +311,13 @@ def env_factory(num_workers):
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

with set_exploration_mode("mean"), torch.no_grad():
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
eval_rollout = test_env.rollout(
max_steps=cfg.max_frames_per_traj,
policy=model[0],
auto_cast_to_device=True,
).clone()
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
4 changes: 2 additions & 2 deletions examples/ppo/ppo.py
@@ -10,7 +10,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives.value import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -98,7 +98,7 @@ def main(cfg: "DictConfig"): # noqa: F821

loss_module = make_ppo_loss(model, cfg)
if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = model(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
4 changes: 2 additions & 2 deletions examples/redq/redq.py
@@ -12,7 +12,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -134,7 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
4 changes: 2 additions & 2 deletions examples/sac/sac.py
@@ -12,7 +12,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -132,7 +132,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
17 changes: 10 additions & 7 deletions examples/td3/td3.py
@@ -10,6 +10,7 @@
import torch
import torch.cuda
import tqdm
from tensordict.nn import InteractionType

from torch import nn, optim
from torchrl.collectors import MultiSyncDataCollector
@@ -26,7 +27,7 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
MLP,
@@ -168,7 +169,7 @@ def main(cfg: "DictConfig"): # noqa: F821
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
return_log_prob=False,
)

@@ -191,7 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821
model = nn.ModuleList([actor, qvalue]).to(device)

# init nets
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = eval_env.reset()
td = td.to(device)
for net in model:
@@ -272,7 +273,7 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["reward"].sum(-1).mean().item()
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

# extend the replay buffer with the new data
@@ -321,7 +322,9 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
rewards.append(
(i, tensordict["next", "reward"].sum().item() / cfg.env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
@@ -336,13 +339,13 @@ def main(cfg: "DictConfig"): # noqa: F821
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

with set_exploration_mode("mean"), torch.no_grad():
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
eval_rollout = eval_env.rollout(
cfg.max_frames_per_traj // cfg.frame_skip,
actor_model_explore,
auto_cast_to_device=True,
)
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
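The "init nets" block above runs every module once under no_grad and RANDOM exploration before training starts. A plausible reading, sketched below under the assumption that the networks use lazy layers (TorchRL's MLP builds a LazyLinear when in_features is not given): one forward pass is needed to materialize parameters before they can be shared or copied. The fake observation here is a stand-in for eval_env.reset(), not data from this diff.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.envs.utils import ExplorationType, set_exploration_type
    from torchrl.modules import MLP

    # MLP without in_features uses lazy linear layers; shapes are inferred on
    # the first call, as in the "init nets" step of the scripts above.
    value_net = TensorDictModule(
        MLP(out_features=1, num_cells=[64, 64]),
        in_keys=["observation"],
        out_keys=["state_value"],
    )
    fake_td = TensorDict({"observation": torch.randn(3)}, batch_size=[])

    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        value_net(fake_td)  # materializes the lazy parameters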
3 changes: 2 additions & 1 deletion test/test_cost.py
@@ -7,6 +7,7 @@
from copy import deepcopy

from packaging import version as pack_version
from tensordict.nn import InteractionType

_has_functorch = True
try:
@@ -3112,7 +3113,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys="action",
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
),
)