[Refactor] Remove loss device (#1395)
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini authored Jul 26, 2023
1 parent d41e960 commit c06ed70
Showing 4 changed files with 24 additions and 38 deletions.
14 changes: 8 additions & 6 deletions torchrl/objectives/a2c.py
@@ -232,12 +232,14 @@ def __init__(
         self.convert_to_functional(critic, "critic", compare_against=policy_params)
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus and entropy_coef
-        self.register_buffer(
-            "entropy_coef", torch.tensor(entropy_coef, device=self.device)
-        )
-        self.register_buffer(
-            "critic_coef", torch.tensor(critic_coef, device=self.device)
-        )
+
+        try:
+            device = next(self.parameters()).device
+        except AttributeError:
+            device = torch.device("cpu")
+
+        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
+        self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
             self.gamma = gamma
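For context, a minimal standalone sketch of the pattern the hunk above adopts: buffers are created on the device of whatever parameters the module already holds, falling back to CPU otherwise. ToyLoss is a hypothetical class, not the torchrl LossModule, and the except clause here is broadened to also guard StopIteration for the parameter-less case.

import torch
from torch import nn


class ToyLoss(nn.Module):
    """Toy example of device-agnostic buffer creation at construction time."""

    def __init__(self, actor: nn.Module, entropy_coef: float = 0.01):
        super().__init__()
        self.actor = actor
        try:
            # Place the buffer alongside the module's existing parameters.
            device = next(self.parameters()).device
        except (AttributeError, StopIteration):
            # No parameters available yet: default to CPU.
            device = torch.device("cpu")
        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))


loss = ToyLoss(nn.Linear(4, 2))
print(loss.entropy_coef.device)  # cpu until the whole module is moved, e.g. loss.to("cuda")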
15 changes: 4 additions & 11 deletions torchrl/objectives/common.py
@@ -8,7 +8,7 @@
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple

 import torch

@@ -24,6 +24,7 @@
 from torch.nn import Parameter

 from torchrl._utils import RL_WARNINGS
+from torchrl.data import DEVICE_TYPING
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.utils import Buffer
 from torchrl.objectives.utils import _cache_values, ValueEstimators
@@ -252,7 +253,6 @@ def convert_to_functional(
         # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device

         def create_buffers(tensor):
-
             if isinstance(tensor, torch.Tensor) and not isinstance(
                 tensor, (Buffer, nn.Parameter)
             ):
@@ -285,7 +285,6 @@ def create_buffers(tensor):
         # For buffers, a cloned expansion (or equivalently a repeat) is returned.

         def _compare_and_expand(param):
-
             if param in compare_against:
                 expanded_param = param.data.expand(expand_dim, *param.shape)
                 # the expanded parameter must be sent to device when to()
@@ -376,7 +375,7 @@ def _compare_and_expand(param):
         target_params = params_and_buffers.apply(_make_target_param(clone=True))
         target_params_items = target_params.items(True, True)
         target_params_list = []
-        for (key, val) in target_params_items:
+        for key, val in target_params_items:
             if not isinstance(key, tuple):
                 key = (key,)
             name = sep.join([name_params_target, *key])
@@ -493,12 +492,6 @@ def _networks(self) -> Iterator[nn.Module]:
             if isinstance(item, nn.Module):
                 yield item

-    @property
-    def device(self) -> torch.device:
-        for p in self.parameters():
-            return p.device
-        return torch.device("cpu")
-
     def register_buffer(
         self, name: str, tensor: Optional[Tensor], persistent: bool = True
     ) -> None:
@@ -530,7 +523,7 @@ def to(self, *args, **kwargs):
         out._cache = {}
         return out

-    def cuda(self, device: Optional[Union[int, device]] = None) -> LossModule:
+    def cuda(self, device: Optional[DEVICE_TYPING] = None) -> LossModule:
         if device is None:
             return self.to("cuda")
         else:
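With the LossModule.device property removed above, downstream code can no longer read a device off the loss object. A hedged sketch of the conventional parameters-first lookup that code outside the library might use instead (module_device is an illustrative helper, not part of torchrl):

import torch
from torch import nn


def module_device(module: nn.Module) -> torch.device:
    """Report the device of a module's first parameter, falling back to CPU
    for parameter-less modules."""
    try:
        return next(module.parameters()).device
    except StopIteration:
        return torch.device("cpu")


print(module_device(nn.Linear(3, 3)))  # cpu unless the module was moved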
19 changes: 4 additions & 15 deletions torchrl/objectives/dqn.py
@@ -278,23 +278,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             a tensor containing the DQN loss.

         """
-        if self.device is not None:
-            warnings.warn(
-                "The use of a device for the objective function will soon be deprecated",
-                category=DeprecationWarning,
-            )
-            device = self.device
-        else:
-            device = tensordict.device
-        tddevice = tensordict.to(device)
-
-        td_copy = tddevice.clone(False)
+        td_copy = tensordict.clone(False)
         self.value_network(
             td_copy,
             params=self.value_network_params,
         )

-        action = tddevice.get(self.tensor_keys.action)
+        action = tensordict.get(self.tensor_keys.action)
         pred_val = td_copy.get(self.tensor_keys.action_value)

         if self.action_space == "categorical":
@@ -433,10 +423,9 @@ def _log_ps_a_categorical(action, action_log_softmax):

     def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
         # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py
-        device = self.device
         tensordict = TensorDict(
             source=input_tensordict, batch_size=input_tensordict.batch_size
-        ).to(device)
+        )

         if tensordict.batch_dims != 1:
             raise RuntimeError(
@@ -543,7 +532,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
             m.view(-1).index_add_(0, index, tensor)

         # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
-        loss = -torch.sum(m.to(device) * log_ps_a, 1)
+        loss = -torch.sum(m.to(input_tensordict.device) * log_ps_a, 1)
         input_tensordict.set(
             self.tensor_keys.priority,
             loss.detach().unsqueeze(1).to(input_tensordict.device),
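The dqn.py hunks drop the module-level device in favour of computing on whatever device the incoming data already lives on, casting only the intermediates that need to match. A small sketch of that convention with plain tensors (function name and shapes are illustrative only):

import torch


def distributional_ce(m: torch.Tensor, log_ps_a: torch.Tensor) -> torch.Tensor:
    # Align the target distribution with the device of the predictions
    # instead of moving everything to a loss-level device.
    m = m.to(log_ps_a.device)
    return -torch.sum(m * log_ps_a, dim=1)


log_ps_a = torch.log_softmax(torch.randn(8, 51), dim=1)
m = torch.rand(8, 51)
loss = distributional_ce(m, log_ps_a)  # result stays on log_ps_a's device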
14 changes: 8 additions & 6 deletions torchrl/objectives/ppo.py
@@ -279,12 +279,14 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus
         self.separate_losses = separate_losses
-        self.register_buffer(
-            "entropy_coef", torch.tensor(entropy_coef, device=self.device)
-        )
-        self.register_buffer(
-            "critic_coef", torch.tensor(critic_coef, device=self.device)
-        )
+
+        try:
+            device = next(self.parameters()).device
+        except AttributeError:
+            device = torch.device("cpu")
+
+        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
+        self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
         self.loss_critic_type = loss_critic_type
         self.normalize_advantage = normalize_advantage
         if gamma is not None:
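As in a2c.py, the coefficients remain registered buffers, so they track later device moves automatically and a construction-time device guess suffices. A quick sketch of that behaviour (hypothetical module; the CUDA move is shown only as an example):

import torch
from torch import nn


class Coeffs(nn.Module):
    def __init__(self, entropy_coef: float = 0.01):
        super().__init__()
        # Buffers registered at construction time follow .to()/.cuda() later on.
        self.register_buffer("entropy_coef", torch.tensor(entropy_coef))


coeffs = Coeffs()
if torch.cuda.is_available():
    coeffs = coeffs.to("cuda")
print(coeffs.entropy_coef.device)  # cuda:0 when available, otherwise cpu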
