Skip to content

Commit

Permalink
[BugFix] Fix SAC (pytorch#1189)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 23, 2023
1 parent 4ece06c commit ae10bb8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 16 deletions.
11 changes: 0 additions & 11 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import uuid
from datetime import datetime

import hydra
import torch.cuda
Expand Down Expand Up @@ -77,15 +75,6 @@ def main(cfg: "DictConfig"): # noqa: F821
else torch.device("cuda:0")
)

exp_name = "_".join(
[
"SAC",
cfg.exp_name,
str(uuid.uuid4())[:8],
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)

exp_name = generate_exp_name("SAC", cfg.exp_name)
logger = get_logger(
logger_type=cfg.logger, logger_name="sac_logging", experiment_name=exp_name
Expand Down
8 changes: 4 additions & 4 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ class SACLoss(LossModule):
Default is ``False``.
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used for data collection.
Default is ``False``.
Default is ``True``.
delay_value (bool, optional): Whether to separate the target value
networks from the value networks used for data collection.
Default is ``False``.
Default is ``True``.
"""

default_value_estimator = ValueEstimators.TD0
Expand All @@ -105,8 +105,8 @@ def __init__(
fixed_alpha: bool = False,
target_entropy: Union[str, float] = "auto",
delay_actor: bool = False,
delay_qvalue: bool = False,
delay_value: bool = False,
delay_qvalue: bool = True,
delay_value: bool = True,
gamma: float = None,
) -> None:
if not _has_functorch:
Expand Down
5 changes: 4 additions & 1 deletion torchrl/trainers/helpers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def make_sac_loss(model, cfg) -> Tuple[SACLoss, Optional[TargetNetUpdater]]:
**loss_kwargs,
)
loss_module.make_value_estimator(gamma=cfg.gamma)
target_net_updater = make_target_updater(cfg, loss_module)
if cfg.loss == "double":
target_net_updater = make_target_updater(cfg, loss_module)
else:
target_net_updater = None
return loss_module, target_net_updater


Expand Down

0 comments on commit ae10bb8

Please sign in to comment.