Skip to content

Commit

Permalink
[BugFix] Refactor reductions (#1968)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 27, 2024
1 parent b8ad113 commit db4ad23
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 84 deletions.
179 changes: 133 additions & 46 deletions test/test_cost.py

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import functools
import warnings
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -472,13 +471,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out = TensorDict({"loss_objective": loss}, batch_size=[])
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down
6 changes: 2 additions & 4 deletions torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import functools
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
Expand Down Expand Up @@ -296,9 +295,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
batch_size=[],
)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
)
return td_out

def loss_actor(
Expand All @@ -314,6 +310,7 @@ def loss_actor(
td_copy = self.value_network(td_copy)
loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
metadata = {}
loss_actor = _reduce(loss_actor, self.reduction)
return loss_actor, metadata

def loss_value(
Expand Down Expand Up @@ -352,6 +349,7 @@ def loss_value(
"target_value_max": target_value.max(),
"pred_value_max": pred_val.max(),
}
loss_value = _reduce(loss_value, self.reduction)
return loss_value, metadata

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_log_likelihood": -log_likelihood,
"loss_entropy": -entropy_bonus,
"loss_alpha": loss_alpha,
"entropy": entropy.detach(),
"entropy": entropy.detach().mean(),
"alpha": self.alpha.detach(),
}
return TensorDict(out, [])
Expand Down
10 changes: 6 additions & 4 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
from dataclasses import dataclass
from numbers import Number
Expand Down Expand Up @@ -312,12 +311,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_qvalue": loss_qval,
"loss_alpha": loss_alpha,
"alpha": self.alpha,
"entropy": -sample_log_prob.detach(),
"entropy": -sample_log_prob.detach().mean(),
},
[],
)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down
20 changes: 13 additions & 7 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
Expand Down Expand Up @@ -799,15 +799,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)

td_out.set("ESS", _reduce(ess, self.reduction) / batch)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down Expand Up @@ -1061,13 +1064,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)

return td_out
Expand Down
10 changes: 6 additions & 4 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
from dataclasses import dataclass
from numbers import Number
Expand Down Expand Up @@ -564,16 +563,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_qvalue": loss_qval,
"loss_alpha": loss_alpha,
"alpha": self.alpha.detach(),
"entropy": -sample_log_prob.detach(),
"entropy": -sample_log_prob.detach().mean(),
"state_action_value_actor": state_action_value_actor.detach(),
"action_log_prob_actor": action_log_prob_actor.detach(),
"next.state_value": next_state_value.detach(),
"target_value": target_value.detach(),
},
[],
)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down
8 changes: 5 additions & 3 deletions torchrl/objectives/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import contextlib
import functools
import warnings
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -402,8 +401,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out = TensorDict({"loss_actor": loss_actor}, batch_size=[])

td_out.set("loss_value", self.loss_critic(tensordict))
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)

return td_out
Expand Down
19 changes: 12 additions & 7 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
import warnings
from dataclasses import dataclass
Expand Down Expand Up @@ -573,13 +572,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_qvalue": loss_qvalue,
"loss_alpha": loss_alpha,
"alpha": self._alpha,
"entropy": entropy,
"entropy": entropy.detach().mean(),
}
if self._version == 1:
out["loss_value"] = loss_value
td_out = TensorDict(out, [])
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down Expand Up @@ -1134,11 +1136,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_qvalue": loss_value,
"loss_alpha": loss_alpha,
"alpha": self._alpha,
"entropy": entropy,
"entropy": entropy.detach().mean(),
}
td_out = TensorDict(out, [])
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction)
if name.startswith("loss_")
else value,
batch_size=[],
)
return td_out

Expand Down
6 changes: 2 additions & 4 deletions torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
from dataclasses import dataclass
from typing import Optional, Tuple

Expand Down Expand Up @@ -375,6 +374,7 @@ def actor_loss(self, tensordict):
metadata = {
"state_action_value_actor": state_action_value_actor.detach(),
}
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata

def value_loss(self, tensordict):
Expand Down Expand Up @@ -449,6 +449,7 @@ def value_loss(self, tensordict):
"pred_value": current_qvalue.detach(),
"target_value": target_value.detach(),
}
loss_qval = _reduce(loss_qval, reduction=self.reduction)
return loss_qval, metadata

@dispatch
Expand All @@ -472,9 +473,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
},
batch_size=[],
)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
)
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
Expand Down

0 comments on commit db4ad23

Please sign in to comment.