[BugFix] Set exploration mode to MODE in all losses by default (#1123)
vmoens authored May 4, 2023
1 parent 257f152 commit 714d645
Showing 4 changed files with 20 additions and 19 deletions.
8 changes: 8 additions & 0 deletions torchrl/objectives/common.py
@@ -18,6 +18,7 @@
 from torch.nn import Parameter
 
 from torchrl._utils import RL_WARNINGS
+from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.utils import Buffer
 from torchrl.objectives.utils import ValueEstimators
 from torchrl.objectives.value import ValueEstimatorBase
@@ -53,11 +54,18 @@ class LossModule(nn.Module):
         pointer. This class attribute indicates which value estimator will be
         used if none other is specified.
         The value estimator can be changed using the :meth:`~.make_value_estimator` method.
+
+    By default, the forward method is always decorated with a
+    :class:`torchrl.envs.ExplorationType.MODE` exploration context.
     """
 
     default_value_estimator: ValueEstimators = None
     SEP = "_sep_"
 
+    def __new__(cls, *args, **kwargs):
+        cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
+        return super().__new__(cls)
+
     def __init__(self):
         super().__init__()
         self._param_maps = {}
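To make the change in common.py concrete, here is a minimal, torchrl-independent sketch of the pattern: the base class wraps `forward` in `__new__` with a context-manager decorator, so every subclass's `forward` runs under a fixed mode without repeating a `with` block. All names below (`set_mode`, `_STATE`, `ToyLoss`, `ToyDDPGLoss`) are illustrative stand-ins rather than torchrl APIs; the real implementation uses `set_exploration_type(ExplorationType.MODE)` exactly as shown in the diff above.

```python
import contextlib
import threading

# Thread-local storage standing in for torchrl's global exploration-type state.
_STATE = threading.local()


@contextlib.contextmanager
def set_mode(mode):
    # Toy stand-in for set_exploration_type: works as a context manager and,
    # because @contextmanager yields a ContextDecorator, as a decorator too.
    previous = getattr(_STATE, "mode", None)
    _STATE.mode = mode
    try:
        yield
    finally:
        _STATE.mode = previous


class ToyLoss:
    # Same trick as LossModule.__new__ in the diff: wrap the subclass's
    # forward so it always executes with the mode set to "MODE".
    def __new__(cls, *args, **kwargs):
        cls.forward = set_mode("MODE")(cls.forward)
        return super().__new__(cls)

    def forward(self):
        raise NotImplementedError


class ToyDDPGLoss(ToyLoss):
    def forward(self):
        # No per-loss `with set_mode("MODE"):` wrapper is needed anymore.
        return getattr(_STATE, "mode", None)


if __name__ == "__main__":
    assert ToyDDPGLoss().forward() == "MODE"      # forward ran under "MODE"
    assert getattr(_STATE, "mode", None) is None  # mode restored afterwards
```

Because the wrapping happens once in the base class, the explicit `with set_exploration_type(ExplorationType.MODE):` blocks in ddpg.py, iql.py, and td3.py below become redundant and are removed.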
9 changes: 3 additions & 6 deletions torchrl/objectives/ddpg.py
@@ -14,8 +14,6 @@
 from tensordict.nn import make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -162,10 +160,9 @@ def _loss_value(
             batch_size=self.target_actor_network_params.batch_size,
             device=self.target_actor_network_params.device,
         )
-        with set_exploration_type(ExplorationType.MODE):
-            target_value = self.value_estimator.value_estimate(
-                tensordict, target_params=target_params
-            ).squeeze(-1)
+        target_value = self.value_estimator.value_estimate(
+            tensordict, target_params=target_params
+        ).squeeze(-1)
 
         # td_error = pred_val - target_value
         loss_value = distance_loss(
11 changes: 4 additions & 7 deletions torchrl/objectives/iql.py
@@ -10,8 +10,6 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import Tensor
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type
-
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -170,11 +168,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
         # KL loss
-        with set_exploration_type(ExplorationType.MODE):
-            dist = self.actor_network.get_dist(
-                tensordict,
-                params=self.actor_network_params,
-            )
+        dist = self.actor_network.get_dist(
+            tensordict,
+            params=self.actor_network_params,
+        )
 
         log_prob = dist.log_prob(tensordict["action"])
 
11 changes: 5 additions & 6 deletions torchrl/objectives/td3.py
@@ -9,7 +9,7 @@
 
 from tensordict.tensordict import TensorDict, TensorDictBase
 
-from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_WARNING,
@@ -127,11 +127,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
         tensordict_actor = tensordict_actor.contiguous()
 
-        with set_exploration_type(ExplorationType.MODE):
-            actor_output_td = vmap(self.actor_network)(
-                tensordict_actor,
-                actor_params,
-            )
+        actor_output_td = vmap(self.actor_network)(
+            tensordict_actor,
+            actor_params,
+        )
         # add noise to target policy
         noise = torch.normal(
             mean=torch.zeros(actor_output_td[1]["action"].shape),
