Commit 87f66e8
[Refactor] Update all instances of exploration *Wrapper to *Module (pytorch#2298)

Co-authored-by: Vincent Moens <vmoens@meta.com>
kurtamohler and vmoens authored Jul 22, 2024
1 parent bdc9784 commit 87f66e8
Showing 14 changed files with 113 additions and 74 deletions.
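
The pattern repeated across all of these files: the exploration *Wrapper classes, which took the policy as a constructor argument, give way to standalone *Module classes chained after the policy in a TensorDictSequential. A minimal sketch of the before/after, with a toy actor and spec standing in for the real models (the names here are illustrative, not taken from the diff):

    import torch
    from tensordict.nn import TensorDictSequential
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

    # Toy actor and action spec standing in for the real models below.
    action_spec = BoundedTensorSpec(-1.0, 1.0, shape=(2,))
    actor = Actor(torch.nn.Linear(3, 2), spec=action_spec)

    # Old (removed): the wrapper owned both the policy and the noise process.
    #   actor_explore = OrnsteinUhlenbeckProcessWrapper(actor, annealing_num_steps=1_000_000)

    # New: the noise process is its own module, appended to the policy.
    actor_explore = TensorDictSequential(
        actor,
        OrnsteinUhlenbeckProcessModule(spec=action_spec, annealing_num_steps=1_000_000),
    )

    # Annealing is now stepped on the noise module (index 1 in the sequence),
    # not on the composed policy.
    actor_explore[1].step(1)

This is why every call site below changes policy.step(...) to policy[1].step(...).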
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/ddpg.py

@@ -108,7 +108,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     for _, tensordict in enumerate(collector):
         sampling_time = time.time() - sampling_start
         # Update exploration policy
-        exploration_policy.step(tensordict.numel())
+        exploration_policy[1].step(tensordict.numel())
 
         # Update weights of the inference policy
         collector.update_policy_weights_()
30 changes: 19 additions & 11 deletions sota-implementations/ddpg/utils.py

@@ -6,6 +6,8 @@
 
 import torch
 
+from tensordict.nn import TensorDictSequential
+
 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer

@@ -25,9 +27,9 @@
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     MLP,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     SafeModule,
     SafeSequential,
     TanhModule,

@@ -227,18 +229,24 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
 
     # Exploration wrappers:
     if cfg.network.noise_type == "ou":
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+        actor_model_explore = TensorDictSequential(
             model[0],
-            annealing_num_steps=1_000_000,
-        ).to(device)
+            OrnsteinUhlenbeckProcessModule(
+                spec=action_spec,
+                annealing_num_steps=1_000_000,
+            ).to(device),
+        )
     elif cfg.network.noise_type == "gaussian":
-        actor_model_explore = AdditiveGaussianWrapper(
+        actor_model_explore = TensorDictSequential(
             model[0],
-            sigma_end=1.0,
-            sigma_init=1.0,
-            mean=0.0,
-            std=0.1,
-        ).to(device)
+            AdditiveGaussianModule(
+                spec=action_spec,
+                sigma_end=1.0,
+                sigma_init=1.0,
+                mean=0.0,
+                std=0.1,
+            ).to(device),
+        )
     else:
         raise NotImplementedError
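
For context (not part of the diff): at rollout time the composed policy behaves as the old wrapper did. The actor writes "action" into the tensordict and the noise module then perturbs it in place, clamping to the spec bounds; as I read the exploration-type gating, noise is only injected under the random exploration type, so deterministic evaluation is unaffected. A self-contained sketch:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictSequential
    from torchrl.data import BoundedTensorSpec
    from torchrl.envs.utils import ExplorationType, set_exploration_type
    from torchrl.modules import Actor, AdditiveGaussianModule

    action_spec = BoundedTensorSpec(-1.0, 1.0, shape=(2,))
    actor = Actor(torch.nn.Linear(3, 2), spec=action_spec)
    actor_explore = TensorDictSequential(
        actor,
        AdditiveGaussianModule(spec=action_spec, sigma_init=0.1, sigma_end=0.1),
    )

    td = TensorDict({"observation": torch.randn(3)}, batch_size=[])
    with set_exploration_type(ExplorationType.RANDOM):
        noisy = actor_explore(td.clone())["action"]  # action + Gaussian noise
    with set_exploration_type(ExplorationType.MODE):
        clean = actor_explore(td.clone())["action"]  # noise is skipped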
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer.py

@@ -279,7 +279,7 @@ def compile_rssms(module):
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
 
-        policy.step(current_frames)
+        policy[1].step(current_frames)
         collector.update_policy_weights_()
         # Evaluation
         if (i % eval_iter) == 0:
17 changes: 10 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py

@@ -52,7 +52,7 @@
 )
 from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     DreamerActor,
     IndependentNormal,
     MLP,

@@ -266,13 +266,16 @@ def make_dreamer(
         test_env=test_env,
     )
     # Exploration noise to be added to the actor_realworld
-    actor_realworld = AdditiveGaussianWrapper(
+    actor_realworld = TensorDictSequential(
         actor_realworld,
-        sigma_init=1.0,
-        sigma_end=1.0,
-        annealing_num_steps=1,
-        mean=0.0,
-        std=cfg.networks.exploration_noise,
+        AdditiveGaussianModule(
+            spec=test_env.action_spec,
+            sigma_init=1.0,
+            sigma_end=1.0,
+            annealing_num_steps=1,
+            mean=0.0,
+            std=cfg.networks.exploration_noise,
+        ),
     )
 
     # Make Critic
15 changes: 9 additions & 6 deletions sota-implementations/multiagent/maddpg_iddpg.py

@@ -7,7 +7,7 @@
 import hydra
 import torch
 
-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential
 from torch import nn
 from torchrl._utils import logger as torchrl_logger
 from torchrl.collectors import SyncDataCollector

@@ -18,7 +18,7 @@
 from torchrl.envs.libs.vmas import VmasEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,

@@ -102,10 +102,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
         return_log_prob=False,
     )
 
-    policy_explore = AdditiveGaussianWrapper(
+    policy_explore = TensorDictSequential(
         policy,
-        annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
-        action_key=env.action_key,
+        AdditiveGaussianModule(
+            spec=env.unbatched_action_spec,
+            annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
+            action_key=env.action_key,
+        ),
     )
 
     # Critic

@@ -200,7 +203,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         optim.zero_grad()
         target_net_updater.step()
 
-        policy_explore.step(frames=current_frames)  # Update exploration annealing
+        policy_explore[1].step(frames=current_frames)  # Update exploration annealing
         collector.update_policy_weights_()
 
         training_time = time.time() - training_start
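
The multi-agent call site above is the one spot where the extra arguments matter. The same construction, annotated (policy, env and cfg are the script's own names, not new code):

    policy_explore = TensorDictSequential(
        policy,
        AdditiveGaussianModule(
            spec=env.unbatched_action_spec,  # per-agent spec, without the batch dims
            annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
            action_key=env.action_key,  # a nested key such as ("agents", "action")
        ),
    )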
16 changes: 10 additions & 6 deletions sota-implementations/redq/redq.py

@@ -8,10 +8,11 @@
 
 import hydra
 import torch.cuda
+from tensordict.nn import TensorDictSequential
 from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
+from torchrl.modules import OrnsteinUhlenbeckProcessModule
 from torchrl.record import VideoRecorder
 from torchrl.record.loggers import get_logger
 from utils import (

@@ -111,12 +112,15 @@ def main(cfg: "DictConfig"):  # noqa: F821
     if cfg.exploration.ou_exploration:
         if cfg.exploration.gSDE:
             raise RuntimeError("gSDE and ou_exploration are incompatible")
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+        actor_model_explore = TensorDictSequential(
             actor_model_explore,
-            annealing_num_steps=cfg.exploration.annealing_frames,
-            sigma=cfg.exploration.ou_sigma,
-            theta=cfg.exploration.ou_theta,
-        ).to(device)
+            OrnsteinUhlenbeckProcessModule(
+                spec=actor_model_explore.spec,
+                annealing_num_steps=cfg.exploration.annealing_frames,
+                sigma=cfg.exploration.ou_sigma,
+                theta=cfg.exploration.ou_theta,
+            ).to(device),
+        )
     if device == torch.device("cpu"):
         # mostly for debugging
         actor_model_explore.share_memory()
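
A side effect visible across these call sites: the *Module classes cannot read the action spec off the policy they follow, so each one is passed a spec explicitly. The sources vary by script: action_spec in the ddpg/td3/td3_bc utils, env.unbatched_action_spec in the multi-agent script, test_env.action_spec in dreamer, actor_model_explore.spec here in redq, and actor.spec.clone() in the tutorial below.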
2 changes: 1 addition & 1 deletion sota-implementations/td3/td3.py

@@ -109,7 +109,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     sampling_start = time.time()
     for tensordict in collector:
         sampling_time = time.time() - sampling_start
-        exploration_policy.step(tensordict.numel())
+        exploration_policy[1].step(tensordict.numel())
 
         # Update weights of the inference policy
         collector.update_policy_weights_()
19 changes: 11 additions & 8 deletions sota-implementations/td3/utils.py

@@ -7,6 +7,7 @@
 from contextlib import nullcontext
 
 import torch
+from tensordict.nn import TensorDictSequential
 
 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector

@@ -27,7 +28,7 @@
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     MLP,
     SafeModule,
     SafeSequential,

@@ -233,14 +234,16 @@ def make_td3_agent(cfg, train_env, eval_env, device):
     eval_env.close()
 
     # Exploration wrappers:
-    actor_model_explore = AdditiveGaussianWrapper(
+    actor_model_explore = TensorDictSequential(
         model[0],
-        sigma_init=1,
-        sigma_end=1,
-        mean=0,
-        std=0.1,
-        spec=action_spec,
-    ).to(device)
+        AdditiveGaussianModule(
+            sigma_init=1,
+            sigma_end=1,
+            mean=0,
+            std=0.1,
+            spec=action_spec,
+        ).to(device),
+    )
     return model, actor_model_explore
19 changes: 11 additions & 8 deletions sota-implementations/td3_bc/utils.py

@@ -5,6 +5,7 @@
 import functools
 
 import torch
+from tensordict.nn import TensorDictSequential
 
 from torch import nn, optim
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay

@@ -24,7 +25,7 @@
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     MLP,
     SafeModule,
     SafeSequential,

@@ -174,14 +175,16 @@ def make_td3_agent(cfg, train_env, device):
     del td
 
     # Exploration wrappers:
-    actor_model_explore = AdditiveGaussianWrapper(
+    actor_model_explore = TensorDictSequential(
         model[0],
-        sigma_init=1,
-        sigma_end=1,
-        mean=0,
-        std=0.1,
-        spec=action_spec,
-    ).to(device)
+        AdditiveGaussianModule(
+            sigma_init=1,
+            sigma_end=1,
+            mean=0,
+            std=0.1,
+            spec=action_spec,
+        ).to(device),
+    )
     return model, actor_model_explore
15 changes: 11 additions & 4 deletions test/test_collector.py

@@ -92,7 +92,7 @@
     PARTIAL_MISSING_ERR,
     RandomPolicy,
 )
-from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule
+from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
 
 # torch.set_default_dtype(torch.double)
 IS_WINDOWS = sys.platform == "win32"

@@ -1291,8 +1291,13 @@ def make_env():
             policy_module, in_keys=["observation"], out_keys=["action"]
         )
         copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key])
-        policy = TensorDictSequential(policy, copier)
-        policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
+        policy_explore = TensorDictSequential(
+            policy,
+            copier,
+            OrnsteinUhlenbeckProcessModule(
+                spec=CompositeSpec({key: None for key in policy.out_keys})
+            ),
+        )
 
         collector_kwargs = {
             "create_env_fn": make_env,

@@ -2472,7 +2477,9 @@ def make_env():
         obs_spec = dummy_env.observation_spec["observation"]
         policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
         policy = Actor(policy_module, spec=dummy_env.action_spec)
-        policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
+        policy_explore = TensorDictSequential(
+            policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec)
+        )
 
         collector_kwargs = {
             "create_env_fn": make_env,
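
When the sequence writes more keys than just the action (as with copier in the first test above), the noise module's spec is built as a CompositeSpec over the inner policy's output keys, with None for entries that need no bounds. A self-contained sketch of that pattern, using toy modules rather than the test's own:

    import torch
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.data import CompositeSpec
    from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

    policy = Actor(torch.nn.Linear(3, 2))  # out_keys defaults to ["action"]
    copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=["obs_copy"])

    policy_explore = TensorDictSequential(
        policy,
        copier,
        OrnsteinUhlenbeckProcessModule(
            # one entry per key the inner policy writes; None means "no bounds to clamp to"
            spec=CompositeSpec({key: None for key in policy.out_keys})
        ),
    )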
10 changes: 6 additions & 4 deletions test/test_tensordictmodules.py

@@ -27,7 +27,7 @@
 )
 from torchrl.envs.utils import set_exploration_type, step_mdp
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     DecisionTransformerInferenceWrapper,
     DTActor,
     GRUModule,

@@ -1363,17 +1363,19 @@ def test_actor_critic_specs():
         out_keys=[action_key],
     )
     original_spec = spec.clone()
-    module = AdditiveGaussianWrapper(policy_module, spec=spec, action_key=action_key)
+    module = TensorDictSequential(
+        policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key)
+    )
     value_module = ValueOperator(
         module=module,
         in_keys=[("agents", "observation"), action_key],
         out_keys=[("agents", "state_action_value")],
     )
     assert original_spec == spec
-    assert module.spec == spec
+    assert module[1].spec == spec
     DDPGLoss(actor_network=module, value_network=value_module)
     assert original_spec == spec
-    assert module.spec == spec
+    assert module[1].spec == spec
 
 
 def test_vmapmodule():
17 changes: 10 additions & 7 deletions tutorials/sphinx-tutorials/coding_ddpg.py

@@ -188,7 +188,7 @@
 # Later, we will see how the target parameters should be updated in TorchRL.
 #
 
-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential
 
 
 def _init(

@@ -722,7 +722,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,

@@ -781,15 +781,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000
 
-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()

@@ -1173,7 +1176,7 @@ def ceil_div(x, y):
     )
 
     # update the exploration strategy
-    actor_model_explore.step(current_frames)
+    actor_model_explore[1].step(current_frames)
 
 collector.shutdown()
 del collector
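
One usage note on the [1] indexing adopted throughout: it assumes the noise module sits at position 1 of the sequence. An equivalent pattern (an alternative, not what this commit does) is to keep a named reference to the module and step it directly:

    ou = OrnsteinUhlenbeckProcessModule(
        spec=actor.spec.clone(),
        annealing_num_steps=annealing_frames,
    )
    actor_model_explore = TensorDictSequential(actor, ou)
    # ... later, in the training loop:
    ou.step(current_frames)  # same effect as actor_model_explore[1].step(current_frames)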
(2 more changed files not shown.)