forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
131 lines (113 loc) · 5.13 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
def make_target_updater(
    cfg: "DictConfig", loss_module: LossModule  # noqa: F821
) -> Optional[TargetNetUpdater]:
    """Builds a target network weight update object.

    Args:
        cfg: config exposing ``loss``, ``hard_update`` and
            ``value_network_update_interval``.
        loss_module: the loss module handed to the updater.

    Returns:
        ``None`` unless ``cfg.loss == "double"``. For the double loss, this is a
        ``HardUpdate`` when ``cfg.hard_update`` is set, otherwise a
        ``SoftUpdate`` whose decay is ``1 - 1 / value_network_update_interval``.

    Raises:
        RuntimeError: if ``cfg.hard_update`` is set while ``cfg.loss`` is not
            ``"double"`` (no updater is built in that case, so the flag would
            be silently ignored).
        ValueError: if soft updates are requested with
            ``cfg.value_network_update_interval < 1``; the decay formula would
            divide by zero or yield a decay outside ``[0, 1)``.
    """
    if cfg.loss != "double":
        if cfg.hard_update:
            raise RuntimeError(
                "hard/soft-update are supposed to be used with the double loss. "
                "Consider using --loss=double or discarding the hard_update flag."
            )
        return None

    if cfg.hard_update:
        return HardUpdate(
            loss_module,
            value_network_update_interval=cfg.value_network_update_interval,
        )

    interval = cfg.value_network_update_interval
    if interval < 1:
        raise ValueError(
            "value_network_update_interval must be >= 1 when using soft updates, "
            f"got {interval}."
        )
    # Translate the update interval into a moving-average decay factor.
    return SoftUpdate(loss_module, eps=1 - 1 / interval)
def make_dqn_loss(model, cfg) -> Tuple[LossModule, Optional[TargetNetUpdater]]:
    """Builds the DQN loss module and its optional target-network updater.

    Args:
        model: the model passed as first argument to the loss constructor
            (presumably a Q-value actor — confirm against callers).
        cfg: config exposing ``loss`` ("single" or "double"), ``distributional``,
            ``loss_function`` (read only when not distributional), ``gamma``,
            and the fields read by ``make_target_updater``.

    Returns:
        ``(loss_module, target_net_updater)``, where ``loss_module`` is a
        ``DistributionalDQNLoss`` when ``cfg.distributional`` is truthy and a
        ``DQNLoss`` otherwise. ``target_net_updater`` is whatever
        ``make_target_updater`` returns (``None`` unless ``cfg.loss == "double"``).

    Raises:
        NotImplementedError: if ``cfg.loss`` is neither "single" nor "double".
    """
    # Validate before building anything so a bad config fails fast.
    if cfg.loss not in ("single", "double"):
        raise NotImplementedError(
            f"Unsupported DQN loss variant {cfg.loss!r}; "
            "expected 'single' or 'double'."
        )

    if cfg.distributional:
        loss_class = DistributionalDQNLoss
        loss_kwargs = {}
    else:
        loss_class = DQNLoss
        # Only the non-distributional loss receives a configurable loss function.
        loss_kwargs = {"loss_function": cfg.loss_function}

    # delay_value is requested only for the double variant, which is also the
    # only variant for which make_target_updater builds an updater.
    loss_kwargs["delay_value"] = cfg.loss == "double"

    loss_module = loss_class(model, **loss_kwargs)
    loss_module.make_value_estimator(gamma=cfg.gamma)
    target_net_updater = make_target_updater(cfg, loss_module)
    return loss_module, target_net_updater
@dataclass
class LossConfig:
    """Generic Loss config struct.

    Field order is significant: it defines the positional ``__init__`` order of
    this dataclass, so new fields should be appended rather than inserted.
    """
    loss: str = "double"
    # Loss variant: "double" (default) or "single". make_dqn_loss rejects any
    # other value, and make_target_updater builds an updater only for "double".
    hard_update: bool = False
    # With loss="double": True selects HardUpdate, False (default) selects SoftUpdate.
    # Setting it with any other loss raises a RuntimeError in make_target_updater.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    # Only forwarded to the non-distributional DQNLoss by make_dqn_loss.
    value_network_update_interval: int = 1000
    # how often the target value network weights are updated (in number of updates).
    # If soft-updates are used, the value is translated into a moving average decay by using
    # the formula decay=1-1/cfg.value_network_update_interval. Default=1000
    gamma: float = 0.99
    # Decay factor for return computation, passed to make_value_estimator. Default=0.99.
    num_q_values: int = 2
    # Number of independently trained Q-value networks. As suggested in the original SAC
    # paper and in https://arxiv.org/abs/1802.09477, the lowest predicted value can be used
    # as the state-action value estimate. REDQ uses an arbitrary number of Q-value functions
    # to speed up learning in MF contexts.
    # NOTE(review): this is an int, not a flag; presumably a value of 1 disables the
    # ensemble — confirm in the consuming loss (not read in this file).
    target_entropy: Any = None
    # Target entropy for the policy distribution. Default None, which is documented as
    # auto-calculated as `target_entropy = -action_dim` by the consuming loss (not read here).
@dataclass
class A2CLossConfig:
    """A2C Loss config struct.

    Field order is significant: it defines the positional ``__init__`` order of
    this dataclass. None of these fields is read within this file.
    """
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    entropy_coef: float = 1e-3
    # Coefficient of the entropy term in the A2C loss. Default=1e-3.
    critic_coef: float = 1.0
    # Multiplier of the critic (value) loss term in the A2C loss. Default=1.0.
    critic_loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
@dataclass
class PPOLossConfig:
    """PPO Loss config struct.

    Groups the parameters shared by all PPO losses with those specific to the
    clipped and KL-penalty variants. Field order defines the positional
    ``__init__`` order, so new fields should be appended.
    """
    loss: str = "clip"
    # PPO loss class, either clip or kl or base/<empty>. Default=clip
    # PPOLoss base parameters:
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    lmbda: float = 0.95
    # lambda factor in GAE (spelled 'lmbda' because 'lambda' is a reserved keyword in Python).
    entropy_bonus: bool = True
    # whether to add an entropy term to the PPO loss.
    entropy_coef: float = 1e-3
    # Coefficient of the entropy term in the PPO loss.
    samples_mc_entropy: int = 1
    # Number of samples for a Monte-Carlo entropy estimate when the policy distribution
    # has no closed-form entropy.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    critic_coef: float = 1.0
    # Critic loss multiplier when computing the total loss.
    # ClipPPOLoss parameters:
    clip_epsilon: float = 0.2
    # weight clipping threshold in the clipped PPO loss equation.
    # KLPENPPOLoss parameters:
    dtarg: float = 0.01
    # target KL divergence.
    beta: float = 1.0
    # initial KL divergence multiplier.
    increment: float = 2.0
    # how much beta should be incremented if KL > dtarg. Valid range: increment >= 1.0
    # (float literal so the default matches the annotation, like the sibling fields).
    decrement: float = 0.5
    # how much beta should be decremented if KL < dtarg. Valid range: decrement <= 1.0
    samples_mc_kl: int = 1
    # Number of samples to use for a Monte-Carlo estimate of KL if necessary