Skip to content

Commit

Permalink
[Refactor] Refactor functional calls in losses (pytorch#1707)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 24, 2023
1 parent fa149e4 commit bc7595f
Show file tree
Hide file tree
Showing 24 changed files with 273 additions and 449 deletions.
10 changes: 8 additions & 2 deletions test/assets/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,20 @@ def get_minibatch():
batch_size=16,
block_size=33,
tensorclass_type=PromptData,
dataset_name="test/datasets_mini/openai_summarize_tldr",
dataset_name="CarperAI/openai_summarize_tldr",
device="cpu",
infinite=False,
prefetch=0,
split="train",
from_disk=True,
from_disk=False,
root_dir=tmpdir,
)
for data in dl:
data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
break
print("done")


if __name__ == "__main__":
# generate_small_dataset()
get_minibatch()
25 changes: 18 additions & 7 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv
from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule
from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer

# from torchrl.data.postprocs.utils import expand_as_right
Expand Down Expand Up @@ -4967,6 +4967,18 @@ def test_cql(
else:
raise NotImplementedError(k)
loss_fn.zero_grad()
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.actor_network_params.values(
include_nested=True, leaves_only=True
)
)
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.qvalue_network_params.values(
include_nested=True, leaves_only=True
)
)

sum([item for _, item in loss.items()]).backward()
named_parameters = list(loss_fn.named_parameters())
Expand Down Expand Up @@ -6500,6 +6512,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est):
assert ("critic" not in name) or ("target_" in name)

value.zero_grad()
for n, p in loss_fn.named_parameters():
assert p.grad is None or p.grad.norm() == 0, n
loss_objective.backward()
named_parameters = loss_fn.named_parameters()
for name, p in named_parameters:
Expand Down Expand Up @@ -6900,20 +6914,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
advantage = GAE(
gamma=gamma,
lmbda=0.9,
value_network=get_functional(value_net),
value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td":
advantage = TD1Estimator(
gamma=gamma,
value_network=get_functional(value_net),
value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimator(
gamma=0.9,
lmbda=0.9,
value_network=get_functional(value_net),
value_network=value_net,
differentiable=gradient_mode,
)
elif advantage is None:
Expand Down Expand Up @@ -9829,9 +9843,6 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):
next_state_value = torch.randn(*N, T, 1, device=device)

gamma_tensor = torch.full((*N, T, 1), gamma, device=device)
# if len(N) == 2:
# print(terminated[4, 0, -10:])
# print(done[4, 0, -10:])
v1 = vec_td_lambda_advantage_estimate(
gamma,
lmbda,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def load(self):
data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
data_dir_total = data_dir / split / str(max_length)
# search for data
print(data_dir_total)
if os.path.exists(data_dir_total):
dataset = TensorDict.load_memmap(data_dir_total)
return dataset
Expand Down
25 changes: 9 additions & 16 deletions torchrl/envs/transforms/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
from copy import copy, deepcopy

import torch
from tensordict import TensorDictBase, unravel_key
from tensordict.nn import (
make_functional,
ProbabilisticTensorDictModule,
repopulate_module,
TensorDictParams,
)
from tensordict import TensorDict, TensorDictBase, unravel_key
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.utils import is_seq_of_nested_key
from torch import nn
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import Transform
from torchrl.envs.transforms.utils import _set_missing_tolerance
from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param


class KLRewardTransform(Transform):
Expand Down Expand Up @@ -116,11 +111,10 @@ def __init__(
self.in_keys = self.in_keys + actor.in_keys

# check that the model has parameters
params = make_functional(
actor, keep_params=False, funs_to_decorate=["forward", "get_dist"]
)
self.functional_actor = deepcopy(actor)
repopulate_module(actor, params)
params = TensorDict.from_module(actor)
with params.apply(_stateless_param).to_module(actor):
# copy a stateless actor
self.__dict__["functional_actor"] = deepcopy(actor)
# we need to register these params as buffer to have `to` and similar
# methods work properly

Expand Down Expand Up @@ -170,9 +164,8 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.out_keys[0] != ("reward",) and self.parent is not None:
tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
return tensordict
dist = self.functional_actor.get_dist(
tensordict.clone(False), params=self.frozen_params
)
with self.frozen_params.to_module(self.functional_actor):
dist = self.functional_actor.get_dist(tensordict.clone(False))
# get the log_prob given the original model
log_prob = dist.log_prob(action)
reward_key = self.in_keys[0]
Expand Down
9 changes: 9 additions & 0 deletions torchrl/envs/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


import torch
from torch import nn


def check_finite(tensor: torch.Tensor):
Expand Down Expand Up @@ -59,3 +60,11 @@ def _get_reset(reset_key, tensordict):
if _reset.ndim > parent_td.ndim:
_reset = _reset.flatten(parent_td.ndim, -1).any(-1)
return _reset


def _stateless_param(param):
is_param = isinstance(param, nn.Parameter)
param = param.data.to("meta")
if is_param:
return nn.Parameter(param, requires_grad=False)
return param
21 changes: 10 additions & 11 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, make_functional
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
Expand All @@ -197,8 +197,9 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
... in_keys=["loc", "scale"],
... distribution_class=TanhNormal,
... )
>>> params = make_functional(td_module)
>>> td = td_module(td, params=params)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... td = td_module(td)
>>> td
TensorDict(
fields={
Expand Down Expand Up @@ -319,7 +320,6 @@ class ValueOperator(TensorDictModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> from torch import nn
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import ValueOperator
Expand All @@ -334,8 +334,9 @@ class ValueOperator(TensorDictModule):
>>> td_module = ValueOperator(
... in_keys=["observation", "action"], module=module
... )
>>> params = make_functional(td_module)
>>> td = td_module(td, params=params)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... td = td_module(td)
>>> print(td)
TensorDict(
fields={
Expand Down Expand Up @@ -792,7 +793,6 @@ class QValueHook:
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
Expand Down Expand Up @@ -878,7 +878,6 @@ class DistributionalQValueHook(QValueHook):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
Expand All @@ -893,12 +892,13 @@ class DistributionalQValueHook(QValueHook):
... return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
...
>>> module = CustomDistributionalQval()
>>> params = make_functional(module)
>>> params = TensorDict.from_module(module)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> qvalue_actor(td, params=params)
>>> with params.to_module(module):
... qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
Expand Down Expand Up @@ -992,7 +992,6 @@ class QValueActor(SafeSequential):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueActor
Expand Down
6 changes: 3 additions & 3 deletions torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class SafeModule(TensorDictModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
Expand All @@ -150,8 +149,9 @@ class SafeModule(TensorDictModule):
... in_keys=["input", "hidden"],
... out_keys=["output"],
... )
>>> params = make_functional(td_fmodule)
>>> td_functional = td_fmodule(td.clone(), params=params)
>>> params = TensorDict.from_module(td_fmodule)
>>> with params.to_module(td_module):
... td_functional = td_fmodule(td.clone())
>>> print(td_functional)
TensorDict(
fields={
Expand Down
6 changes: 3 additions & 3 deletions torchrl/modules/tensordict_module/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class SafeSequential(TensorDictSequential, SafeModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamWrapper
>>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
Expand All @@ -58,8 +57,9 @@ class SafeSequential(TensorDictSequential, SafeModule):
... out_keys=["output"],
... )
>>> td_module = SafeSequential(td_module1, td_module2)
>>> params = make_functional(td_module)
>>> td_module(td, params=params)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... td_module(td)
>>> print(td)
TensorDict(
fields={
Expand Down
14 changes: 8 additions & 6 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def _log_probs(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(*self.actor.in_keys).clone()

dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
with self.actor_params.to_module(self.actor):
dist = self.actor.get_dist(tensordict_clone)
log_prob = dist.log_prob(action)
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist
Expand All @@ -339,10 +339,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
# overhead that we could easily reduce.
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
state_value = self.critic(
tensordict_select,
params=self.critic_params,
).get(self.tensor_keys.value)
with self.critic_params.to_module(self.critic):
state_value = self.critic(
tensordict_select,
).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
Expand Down Expand Up @@ -374,6 +374,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
target_params=self.target_critic_params,
)
advantage = tensordict.get(self.tensor_keys.advantage)
assert not advantage.requires_grad
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
Expand All @@ -392,6 +393,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self.value_type = value_type
hp = dict(default_value_kwargs(value_type))
hp.update(hyperparams)

if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
if value_type == ValueEstimators.TD1:
Expand Down
Loading

0 comments on commit bc7595f

Please sign in to comment.