[Feature] RLHF networks (pytorch#1319)
Co-authored-by: vmoens <vincentmoens@gmail.com>
apbard and vmoens authored Jul 3, 2023
1 parent fa4fe1d commit a6d76e2
Showing 4 changed files with 171 additions and 1 deletion.
64 changes: 64 additions & 0 deletions test/test_actors.py
@@ -18,13 +18,15 @@
    MultiOneHotDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.data.rlhf.dataset import _has_transformers
from torchrl.modules import MLP, SafeModule
from torchrl.modules.tensordict_module.actors import (
    _process_action_space_spec,
    ActorValueOperator,
    DistributionalQValueActor,
    DistributionalQValueHook,
    DistributionalQValueModule,
    LMHeadActorValueOperator,
    ProbabilisticActor,
    QValueActor,
    QValueHook,
@@ -561,6 +563,68 @@ def test_actorcritic(device):
    ) == len(policy_params)


@pytest.mark.skipif(not _has_transformers, reason="missing dependencies")
@pytest.mark.parametrize("device", get_default_devices())
def test_lmhead_actorvalueoperator(device):
    from transformers import AutoModelForCausalLM

    base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False)
    aco = LMHeadActorValueOperator(base_model)

    # check common
    assert aco.module[0][0].module is base_model.transformer
    assert aco.module[0][1].in_keys == ["x"]
    assert aco.module[0][1].out_keys == ["x"]

    # check actor
    assert aco.module[1].in_keys == ["x"]
    assert aco.module[1].out_keys == ["logits", "action", "sample_log_prob"]
    assert aco.module[1][0].module is base_model.lm_head

    # check critic
    assert aco.module[2].in_keys == ["x"]
    assert aco.module[2].out_keys == ["state_value"]
    assert isinstance(aco.module[2].module, nn.Linear)
    assert aco.module[2].module.in_features == base_model.transformer.embed_dim
    assert aco.module[2].module.out_features == 1

    td = TensorDict(
        source={
            "input_ids": torch.randint(50257, (4, 3)),
            "attention_mask": torch.ones((4, 3)),
        },
        batch_size=[
            4,
        ],
    ).to(device)
    td_total = aco(td.clone())
    policy_op = aco.get_policy_operator()
    td_policy = policy_op(td.clone())
    value_op = aco.get_value_operator()
    td_value = value_op(td)
    torch.testing.assert_close(td_total.get("action"), td_policy.get("action"))
    torch.testing.assert_close(
        td_total.get("sample_log_prob"), td_policy.get("sample_log_prob")
    )
    torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value"))

    value_params = set(
        list(aco.get_value_operator().parameters()) + list(aco.module[0].parameters())
    )
    value_params2 = set(value_op.parameters())
    assert len(value_params.difference(value_params2)) == 0 and len(
        value_params.intersection(value_params2)
    ) == len(value_params)

    policy_params = set(
        list(aco.get_policy_operator().parameters()) + list(aco.module[0].parameters())
    )
    policy_params2 = set(policy_op.parameters())
    assert len(policy_params.difference(policy_params2)) == 0 and len(
        policy_params.intersection(policy_params2)
    ) == len(policy_params)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
12 changes: 12 additions & 0 deletions test/test_tensordictmodules.py
@@ -20,13 +20,15 @@
from torchrl.modules.tensordict_module.common import (
    ensure_tensordict_compatible,
    is_tensordict_compatible,
    VmapModule,
)
from torchrl.modules.tensordict_module.probabilistic import (
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
)
from torchrl.modules.tensordict_module.sequence import SafeSequential


_has_functorch = False
try:
    try:
@@ -1727,6 +1729,16 @@ def test_multi_consecutive(self, shape):
        )


def test_vmapmodule():
    lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
    sample_in = torch.ones((10, 3, 2))
    sample_in_td = TensorDict({"x": sample_in}, batch_size=[10])
    lam(sample_in)
    vm = VmapModule(lam, 0)
    vm(sample_in_td)
    assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
47 changes: 47 additions & 0 deletions torchrl/modules/tensordict_module/actors.py
@@ -6,14 +6,17 @@
from typing import List, Optional, Sequence, Tuple, Union

import torch

from tensordict import TensorDictBase
from tensordict.nn import (
    dispatch,
    TensorDictModule,
    TensorDictModuleBase,
    TensorDictModuleWrapper,
    TensorDictSequential,
)
from torch import nn
from torch.distributions import Categorical

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.models.models import DistributionalDQNnet
@@ -1748,3 +1751,47 @@ def forward(self, tensordict):
            feature = low + (high - low) * (feature + 1) / 2
            tensordict.set(out_key, feature)
        return tensordict


class LMHeadActorValueOperator(ActorValueOperator):
    """Builds an Actor-Value operator from a huggingface-like *LMHeadModel.

    This class:
      - takes as input a huggingface-like *LMHeadModel;
      - extracts the final linear layer and uses it as the base layer of the
        actor_head, adding a sampling layer on top;
      - uses the common transformer as the common model;
      - adds a linear critic.

    Args:
        base_model (nn.Module): a torch model composed of a `.transformer`
            model and an `.lm_head` linear layer.

    .. note:: For more details regarding the class construction, please refer to
        :class:`~.ActorValueOperator`.
    """

    def __init__(self, base_model):
        actor_head = base_model.lm_head
        value_head = nn.Linear(actor_head.in_features, 1, bias=False)
        common = TensorDictSequential(
            TensorDictModule(
                base_model.transformer,
                in_keys={"input_ids": "input_ids", "attention_mask": "attention_mask"},
                out_keys=["x"],
            ),
            TensorDictModule(lambda x: x[:, -1, :], in_keys=["x"], out_keys=["x"]),
        )
        actor_head = TensorDictModule(actor_head, in_keys=["x"], out_keys=["logits"])
        actor_head = SafeProbabilisticTensorDictSequential(
            actor_head,
            SafeProbabilisticModule(
                in_keys=["logits"],
                out_keys=["action"],
                distribution_class=Categorical,
                return_log_prob=True,
            ),
        )
        value_head = TensorDictModule(
            value_head, in_keys=["x"], out_keys=["state_value"]
        )

        return super().__init__(common, actor_head, value_head)
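
For reference, a minimal usage sketch of LMHeadActorValueOperator, mirroring the test added above in test_actors.py. It assumes the optional transformers dependency is installed and the GPT-2 weights can be downloaded; the variable names are illustrative and not part of the commit.

import torch
from tensordict import TensorDict
from transformers import AutoModelForCausalLM
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator

# wrap a huggingface-like *LMHeadModel into common / actor / critic heads
base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False)
aco = LMHeadActorValueOperator(base_model)

td = TensorDict(
    {
        "input_ids": torch.randint(50257, (4, 3)),  # GPT-2 vocabulary size
        "attention_mask": torch.ones((4, 3)),
    },
    batch_size=[4],
)
aco(td)  # writes "logits", "action", "sample_log_prob" and "state_value"
policy_op = aco.get_policy_operator()  # actor branch only
value_op = aco.get_value_operator()    # critic branch only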
49 changes: 48 additions & 1 deletion torchrl/modules/tensordict_module/common.py
@@ -12,8 +12,9 @@

import torch

from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase

from torch import nn

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
@@ -401,3 +402,49 @@ def ensure_tensordict_compatible(
    if out_keys is not None:
        kwargs["out_keys"] = out_keys
    return wrapper_type(module, **kwargs)


class VmapModule(TensorDictModuleBase):
    """A TensorDictModule wrapper to vmap over the input.

    It is intended to be used with modules that accept data with one less batch
    dimension than the one provided. By using this wrapper, one can hide a
    batch dimension and satisfy the wrapped module.

    Args:
        module (TensorDictModuleBase): the module to vmap over.
        vmap_dim (int, optional): the vmap input and output dim.
            If none is provided, the last batch dimension of the tensordict is
            used.

    .. note::
        Since vmap requires control over the batch size of the input,
        this module does not support dispatched arguments.

    Example:
        >>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
        >>> sample_in = torch.ones((10, 3, 2))
        >>> sample_in_td = TensorDict({"x": sample_in}, batch_size=[10])
        >>> lam(sample_in)
        >>> vm = VmapModule(lam, 0)
        >>> vm(sample_in_td)
        >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
    """

    def __init__(self, module: TensorDictModuleBase, vmap_dim=None):
        super().__init__()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        self.module = module
        self.vmap_dim = vmap_dim

    def forward(self, tensordict):
        # TODO: there is a risk of segfault if the input is not a tensordict.
        # We should investigate (possibly prevent it on the C++ side?)
        vmap_dim = self.vmap_dim
        if vmap_dim is None:
            ndim = tensordict.ndim
            vmap_dim = ndim - 1
        td = torch.vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict)
        return tensordict.update(td)
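
As a complement to the docstring example, here is a minimal sketch of the default vmap_dim behaviour (illustrative, not part of the commit; it assumes torch.vmap and the tensordict package are available). With a tensordict that has a single batch dimension, vmap_dim=None resolves to tensordict.ndim - 1 == 0:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.tensordict_module.common import VmapModule

# same wrapped module as in the docstring example: it expects one batch dim less
lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
td = TensorDict({"x": torch.ones(10, 3, 2)}, batch_size=[10])
vm = VmapModule(lam)  # vmap_dim=None -> resolves to tensordict.ndim - 1 == 0 here
vm(td)  # updates td in place with the vmapped output
assert (td["y"] == td["x"][:, 0]).all()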
