Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Update all instances of exploration *Wrapper to *Module #2298

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Update all instances of exploration *Wrapper to *Module
  • Loading branch information
kurtamohler committed Jul 19, 2024
commit b40430b24dd9d5456132df0c071cf6bcabc924b9
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
# Update exploration policy
exploration_policy.step(tensordict.numel())
exploration_policy[1].step(tensordict.numel())
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not completely sure if this is the best way to do this. Does there happen to be some alternative to TensorDictSequential which does essentially the same thing but also provides a step function?

Copy link
Collaborator Author

@kurtamohler kurtamohler Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it looks like the same thing was done when EGreedyWrapper was updated to EGreedyModule, so I guess it's alright:

policy = TensorDictSequential(
actor,
EGreedyModule(

policy[1].step()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep either that or

def update_exploration(module):
    """Call ``set()`` on ``module`` when it is an ``ExplorationModule``.

    Written as a callback for ``policy.apply`` so that exploration modules
    are updated through one shared function instead of indexing the policy.
    """
    # NOTE(review): ``ExplorationModule`` appears to be the proposed common
    # parent class for exploration modules, not an existing name -- confirm.
    if isinstance(module, ExplorationModule):
        # NOTE(review): elsewhere in this change exploration is advanced with
        # ``step(...)``; ``set()`` is presumably meant to be ``step()`` -- confirm.
        module.set()
# Presumably ``torch.nn.Module.apply``, which would run the callback on every
# submodule of ``policy`` -- confirm.
policy.apply(update_exploration)

We could make sure that all exploration modules have the same parent class and use that update function across examples.


# Update weights of the inference policy
collector.update_policy_weights_()
Expand Down
30 changes: 19 additions & 11 deletions sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import torch

from tensordict.nn import TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
Expand All @@ -25,9 +27,9 @@
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessWrapper,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
Expand Down Expand Up @@ -227,18 +229,24 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):

# Exploration modules:
if cfg.network.noise_type == "ou":
actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore = TensorDictSequential(
model[0],
annealing_num_steps=1_000_000,
).to(device)
OrnsteinUhlenbeckProcessModule(
spec=action_spec,
annealing_num_steps=1_000_000,
).to(device),
)
elif cfg.network.noise_type == "gaussian":
actor_model_explore = AdditiveGaussianWrapper(
actor_model_explore = TensorDictSequential(
model[0],
sigma_end=1.0,
sigma_init=1.0,
mean=0.0,
std=0.1,
).to(device)
AdditiveGaussianModule(
spec=action_spec,
sigma_end=1.0,
sigma_init=1.0,
mean=0.0,
std=0.1,
).to(device),
)
else:
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def compile_rssms(module):
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)

policy.step(current_frames)
policy[1].step(current_frames)
collector.update_policy_weights_()
# Evaluation
if (i % eval_iter) == 0:
Expand Down
17 changes: 10 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
DreamerActor,
IndependentNormal,
MLP,
Expand Down Expand Up @@ -266,13 +266,16 @@ def make_dreamer(
test_env=test_env,
)
# Exploration noise to be added to the actor_realworld
actor_realworld = AdditiveGaussianWrapper(
actor_realworld = TensorDictSequential(
actor_realworld,
sigma_init=1.0,
sigma_end=1.0,
annealing_num_steps=1,
mean=0.0,
std=cfg.networks.exploration_noise,
AdditiveGaussianModule(
spec=test_env.action_spec,
sigma_init=1.0,
sigma_end=1.0,
annealing_num_steps=1,
mean=0.0,
std=cfg.networks.exploration_noise,
),
)

# Make Critic
Expand Down
15 changes: 9 additions & 6 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import hydra
import torch

from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import SyncDataCollector
Expand All @@ -18,7 +18,7 @@
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
ProbabilisticActor,
TanhDelta,
ValueOperator,
Expand Down Expand Up @@ -102,10 +102,13 @@ def train(cfg: "DictConfig"): # noqa: F821
return_log_prob=False,
)

policy_explore = AdditiveGaussianWrapper(
policy_explore = TensorDictSequential(
policy,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
AdditiveGaussianModule(
spec=env.unbatched_action_spec,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
),
)

# Critic
Expand Down Expand Up @@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821
optim.zero_grad()
target_net_updater.step()

policy_explore.step(frames=current_frames) # Update exploration annealing
policy_explore[1].step(frames=current_frames) # Update exploration annealing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the example I gave above, `update_exploration` (or `step_exploration`) should be turned into a class so that we can pass `current_frames` to it.

collector.update_policy_weights_()

training_time = time.time() - training_start
Expand Down
16 changes: 10 additions & 6 deletions sota-implementations/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import hydra
import torch.cuda
from tensordict.nn import TensorDictSequential
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.modules import OrnsteinUhlenbeckProcessModule
from torchrl.record import VideoRecorder
from torchrl.record.loggers import get_logger
from utils import (
Expand Down Expand Up @@ -111,12 +112,15 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.exploration.ou_exploration:
if cfg.exploration.gSDE:
raise RuntimeError("gSDE and ou_exploration are incompatible")
actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore = TensorDictSequential(
actor_model_explore,
annealing_num_steps=cfg.exploration.annealing_frames,
sigma=cfg.exploration.ou_sigma,
theta=cfg.exploration.ou_theta,
).to(device)
OrnsteinUhlenbeckProcessModule(
spec=actor_model_explore.spec,
annealing_num_steps=cfg.exploration.annealing_frames,
sigma=cfg.exploration.ou_sigma,
theta=cfg.exploration.ou_theta,
).to(device),
)
if device == torch.device("cpu"):
# mostly for debugging
actor_model_explore.share_memory()
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
exploration_policy.step(tensordict.numel())
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()
Expand Down
19 changes: 11 additions & 8 deletions sota-implementations/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import nullcontext

import torch
from tensordict.nn import TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
Expand All @@ -27,7 +28,7 @@
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
MLP,
SafeModule,
SafeSequential,
Expand Down Expand Up @@ -233,14 +234,16 @@ def make_td3_agent(cfg, train_env, eval_env, device):
eval_env.close()

# Exploration modules:
actor_model_explore = AdditiveGaussianWrapper(
actor_model_explore = TensorDictSequential(
model[0],
sigma_init=1,
sigma_end=1,
mean=0,
std=0.1,
spec=action_spec,
).to(device)
AdditiveGaussianModule(
sigma_init=1,
sigma_end=1,
mean=0,
std=0.1,
spec=action_spec,
).to(device),
)
return model, actor_model_explore


Expand Down
19 changes: 11 additions & 8 deletions sota-implementations/td3_bc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools

import torch
from tensordict.nn import TensorDictSequential

from torch import nn, optim
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
Expand All @@ -24,7 +25,7 @@
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
MLP,
SafeModule,
SafeSequential,
Expand Down Expand Up @@ -174,14 +175,16 @@ def make_td3_agent(cfg, train_env, device):
del td

# Exploration modules:
actor_model_explore = AdditiveGaussianWrapper(
actor_model_explore = TensorDictSequential(
model[0],
sigma_init=1,
sigma_end=1,
mean=0,
std=0.1,
spec=action_spec,
).to(device)
AdditiveGaussianModule(
sigma_init=1,
sigma_end=1,
mean=0,
std=0.1,
spec=action_spec,
).to(device),
)
return model, actor_model_explore


Expand Down
15 changes: 11 additions & 4 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
PARTIAL_MISSING_ERR,
RandomPolicy,
)
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule

# torch.set_default_dtype(torch.double)
IS_WINDOWS = sys.platform == "win32"
Expand Down Expand Up @@ -1291,8 +1291,13 @@ def make_env():
policy_module, in_keys=["observation"], out_keys=["action"]
)
copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key])
policy = TensorDictSequential(policy, copier)
policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
policy_explore = TensorDictSequential(
policy,
copier,
OrnsteinUhlenbeckProcessModule(
spec=CompositeSpec({key: None for key in policy.out_keys})
),
)

collector_kwargs = {
"create_env_fn": make_env,
Expand Down Expand Up @@ -2472,7 +2477,9 @@ def make_env():
obs_spec = dummy_env.observation_spec["observation"]
policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
policy = Actor(policy_module, spec=dummy_env.action_spec)
policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
policy_explore = TensorDictSequential(
policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec)
)

collector_kwargs = {
"create_env_fn": make_env,
Expand Down
10 changes: 6 additions & 4 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import (
AdditiveGaussianWrapper,
AdditiveGaussianModule,
DecisionTransformerInferenceWrapper,
DTActor,
GRUModule,
Expand Down Expand Up @@ -1363,17 +1363,19 @@ def test_actor_critic_specs():
out_keys=[action_key],
)
original_spec = spec.clone()
module = AdditiveGaussianWrapper(policy_module, spec=spec, action_key=action_key)
module = TensorDictSequential(
policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key)
)
value_module = ValueOperator(
module=module,
in_keys=[("agents", "observation"), action_key],
out_keys=[("agents", "state_action_value")],
)
assert original_spec == spec
assert module.spec == spec
assert module[1].spec == spec
DDPGLoss(actor_network=module, value_network=value_module)
assert original_spec == spec
assert module.spec == spec
assert module[1].spec == spec


def test_vmapmodule():
Expand Down
17 changes: 10 additions & 7 deletions tutorials/sphinx-tutorials/coding_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
# Later, we will see how the target parameters should be updated in TorchRL.
#

from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictSequential


def _init(
Expand Down Expand Up @@ -722,7 +722,7 @@ def get_env_stats():
ActorCriticWrapper,
DdpgMlpActor,
DdpgMlpQNet,
OrnsteinUhlenbeckProcessWrapper,
OrnsteinUhlenbeckProcessModule,
ProbabilisticActor,
TanhDelta,
ValueOperator,
Expand Down Expand Up @@ -781,15 +781,18 @@ def make_ddpg_actor(
# Exploration
# ~~~~~~~~~~~
#
# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
# The policy is chained with an :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
# exploration module, as suggested in the original paper.
# Let's define the number of frames before OU noise reaches its minimum value
annealing_frames = 1_000_000

actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore = TensorDictSequential(
actor,
annealing_num_steps=annealing_frames,
).to(device)
OrnsteinUhlenbeckProcessModule(
spec=actor.spec.clone(),
annealing_num_steps=annealing_frames,
).to(device),
)
if device == torch.device("cpu"):
actor_model_explore.share_memory()

Expand Down Expand Up @@ -1173,7 +1176,7 @@ def ceil_div(x, y):
)

# update the exploration strategy
actor_model_explore.step(current_frames)
actor_model_explore[1].step(current_frames)

collector.shutdown()
del collector
Expand Down
Loading
Loading