[Algorithm] RLHF end-to-end, clean #1597

Merged: 52 commits, merged on Oct 5, 2023

Changes from 1 commit

Commits (52)
f24281c
RLHF end2end example
apbard Jun 27, 2023
ef3f76f
add VmapModule and from_lmhead_model method
apbard Jun 27, 2023
02a909b
Update examples/rlhf/train_rlhf.py
apbard Jun 28, 2023
953e4af
addressing comments
apbard Jun 28, 2023
ffb8661
Merge remote-tracking branch 'origin/main' into rlhf-networks
vmoens Jun 28, 2023
f43faea
Update torchrl/modules/tensordict_module/common.py
vmoens Jun 28, 2023
69b0588
Update torchrl/modules/tensordict_module/actors.py
vmoens Jun 28, 2023
b6fecbb
Add RolloutFromModel class
tcbegley Jun 26, 2023
bd8fbb6
Add rollout tests
tcbegley Jun 26, 2023
6fbb603
Apply suggestions from code review
tcbegley Jun 26, 2023
3e80a55
Address comments
tcbegley Jun 26, 2023
385ac90
Docstring lint
tcbegley Jun 26, 2023
8d0a152
Apply suggestions from code review
tcbegley Jun 27, 2023
fcddc97
Address comments
tcbegley Jun 27, 2023
5c7c72e
Fix tests
tcbegley Jun 28, 2023
92d5757
Handle missing transformers import
tcbegley Jun 28, 2023
eec0eaf
Import transformers locally
tcbegley Jun 28, 2023
87501ea
lint
vmoens Jun 28, 2023
043fcf6
Merge branch 'rlhf-rollout' into rlhf-example
tcbegley Jun 29, 2023
3f53046
Merge branch 'rlhf-networks' into rlhf-example
tcbegley Jun 29, 2023
8b69e41
lint
tcbegley Jun 29, 2023
24eaa3a
Example bugfixes
tcbegley Jun 29, 2023
fba43a1
Move KL controller logic
tcbegley Jun 29, 2023
20fa920
Merge branch 'main' into rlhf-example
vmoens Jul 4, 2023
c07ac93
amend
vmoens Jul 4, 2023
f463e0e
addressing comments about klcontroller
apbard Jul 4, 2023
eac5374
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Sep 5, 2023
8d2dde7
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Oct 1, 2023
a2ba045
Merge branch 'main' into rlhf-example
vmoens Oct 2, 2023
a9b94f0
amend
vmoens Oct 2, 2023
d983ebd
init
vmoens Oct 3, 2023
097c443
readme
vmoens Oct 3, 2023
0efd93a
amend
vmoens Oct 3, 2023
fba9f03
amend
vmoens Oct 3, 2023
cc535e5
amend
vmoens Oct 4, 2023
28c116f
amend
vmoens Oct 4, 2023
0f128a6
amend
vmoens Oct 4, 2023
e0ad043
amend
vmoens Oct 4, 2023
e8cad9b
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 4, 2023
c93c134
amend
vmoens Oct 4, 2023
56f7597
init
vmoens Oct 5, 2023
3fa6ea5
Merge branch 'refactor_ddpg_loss' into rlhf-example-refactor
vmoens Oct 5, 2023
c1c41dc
amend
vmoens Oct 5, 2023
880e5b4
amend
vmoens Oct 5, 2023
d36ce77
Update run_test.sh
vmoens Oct 5, 2023
942b311
amend
vmoens Oct 5, 2023
fca9f7b
amend
vmoens Oct 5, 2023
6362715
lint
vmoens Oct 5, 2023
e3b2d4f
amend
vmoens Oct 5, 2023
7918f86
amend
vmoens Oct 5, 2023
9658a44
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 5, 2023
eb041a4
lint
vmoens Oct 5, 2023
amend
vmoens committed Oct 4, 2023
commit cc535e5ec04b6526a198a0676b7a4d8ccd3232d2
2 changes: 1 addition & 1 deletion examples/rlhf/config/train_rlhf.yaml
@@ -35,4 +35,4 @@ sys:
   device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
   ref_device: cuda:1 # device of reference model
   dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
-  compile: True # use PyTorch 2.0 to compile the model to be faster
+  compile: False # use PyTorch 2.0 to compile the model to be faster
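For context, a minimal sketch (not the PR's code) of how a `compile` flag like the one above is typically consumed, assuming a PyTorch >= 2.0 runtime and a generic `nn.Module`; the helper name is hypothetical:

```python
import torch
import torch.nn as nn


def maybe_compile(model: nn.Module, compile_model: bool) -> nn.Module:
    """Illustrative helper (hypothetical name): only wrap the model when asked to."""
    if compile_model and hasattr(torch, "compile"):
        # torch.compile returns an optimized wrapper; otherwise we stay in eager mode,
        # which matches the new default of compile: False in the config above.
        model = torch.compile(model)
    return model


model = maybe_compile(nn.Linear(8, 8), compile_model=False)
```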
19 changes: 12 additions & 7 deletions examples/rlhf/models/reward.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings

 import torch
 from tensordict.nn import TensorDictModule
@@ -12,14 +13,18 @@
 def init_reward_model(
     transformer_path=None, reward_model_path=None, device=None, compile_model=False
 ):
-    if not ((transformer_path is None) ^ (reward_model_path is None)):
-        raise ValueError(
-            "Exactly one of transformer_path or reward_model_path should be specified"
-        )
-    if transformer_path is not None:
-        model = GPT2RewardModel(transformer_path)
+    if transformer_path is None and reward_model_path is None:
+        warnings.warn("You did not provide a path to the reward model, a naive reward model will be used instead.")
+        model = GPT2RewardModel()
     else:
-        model = GPT2RewardModel.from_pretrained(reward_model_path)
+        if not ((transformer_path is None) ^ (reward_model_path is None)):
+            raise ValueError(
+                "Exactly one of transformer_path or reward_model_path should be specified."
+            )
+        if transformer_path is not None:
+            model = GPT2RewardModel(transformer_path)
+        else:
+            model = GPT2RewardModel.from_pretrained(reward_model_path)

     model.to(device)
     if compile_model:
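A hedged usage sketch of the updated behaviour; the import path and checkpoint directory below are placeholders, and the return handling is simplified:

```python
# Hypothetical usage of init_reward_model after this change (run from examples/rlhf).
from models.reward import init_reward_model

# Neither path given: emits the warning and falls back to a freshly initialised GPT2RewardModel.
reward_model = init_reward_model(device="cpu")

# Exactly one of the two paths may be given; providing both still raises the ValueError.
reward_model = init_reward_model(reward_model_path="./my_reward_checkpoint", device="cpu")
```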
6 changes: 5 additions & 1 deletion examples/rlhf/train_rlhf.py
@@ -122,13 +122,17 @@ def main(cfg):
     # ================= Training loop ================= #
     #####################################################

+    stats_logger = TrainLogger(
+        collection_iters,
+        log_interval=cfg.io.log_interval,
+        logger=logger
+    )
     pbar = tqdm(total=max_epochs * collection_iters)
     for _ in range(max_epochs):
         # ----------------- 1. Collect data, fill replay buffer ----------------- #
         # it's possible we didn't fill the replay buffer in the last iteration if
         # generation stopped early, so we empty first before repopulating
         rb.empty()
-        stats_logger = TrainLogger(collection_iters)
         for i in range(collection_iters):
             batch = next(train_prompt_loader)
             td = rollout_from_model.rollout_from_data(batch)
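The point of the move is structural: the logger is now created once, outside the epoch loop, so its counters survive across epochs. A stubbed, runnable sketch of that control flow (names mirror the script, but everything below is simplified stand-in code, not the example itself):

```python
class StatsLogger:
    """Stand-in for utils.TrainLogger: same three-step lifecycle, no real logging."""

    def __call__(self, data):  # record one collection iteration
        pass

    def aggregate(self):  # average what was recorded this epoch
        pass

    def log(self):  # emit scalars every `log_interval` calls
        pass


max_epochs, collection_iters = 2, 3
stats_logger = StatsLogger()  # built once, so its internal step counter persists
for _ in range(max_epochs):
    replay_buffer = []  # stand-in for rb.empty() followed by rb.extend(...)
    for _ in range(collection_iters):
        rollout = {"reward": 0.0}  # stand-in for rollout_from_model.rollout_from_data(batch)
        stats_logger(rollout)
        replay_buffer.append(rollout)
    stats_logger.aggregate()
    stats_logger.log()
```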
32 changes: 15 additions & 17 deletions examples/rlhf/utils.py
@@ -33,6 +33,8 @@
 from torchrl.objectives.value import GAE
 from transformers import GenerationConfig, GPT2Tokenizer

+from torchrl.record.loggers import Logger
+

 class TestPromptLogger:
     def __init__(self, batch, reward_model, logger, episode_length):
@@ -86,39 +88,33 @@ def log(self, model):


 class TrainLogger:
-    def __init__(self, size):
+    def __init__(self, size: int, log_interval: int, logger: Logger):
         self.data = TensorDict({}, [size])
         self.counter = 0
+        self.log_interval = log_interval
+        self.logger = logger
+        self.it = -1

     def __call__(self, data):
         done = data.get(("next", "done"))
         td_done = data[done.view(data.shape)]
         next_reward = td_done.get(("next", "reward_raw"))
         next_kl = td_done.get(("next", "reward_kl"))
-        data[self.counter]["next_reward"] = next_reward.mean().cpu()
-        data[self.counter]["next_kl"] = next_kl.mean().cpu()
+        self.data[self.counter]["next_reward"] = next_reward.mean().cpu()
+        self.data[self.counter]["next_kl"] = next_kl.mean().cpu()
         self.counter += 1

     def aggregate(self):
         result = {}
         for key, item in self.data.items():
             result[key] = item.mean()
-        self.data = TensorDict(result, [])
+        self.aggregated_data = TensorDict(result, [])

     def log(self):
-        if it % log_interval == 0:
-            val_reward_logger.info(
-                f"TRAIN: {it=}: {rollout_reward=:.4f} {rollout_kl_reward=:.4f} {rollout_kl=:.4f}"
-            )
-            wandb.log(
-                {
-                    "rollout_reward": rollout_reward,
-                    "rollout_kl_reward": rollout_kl_reward,
-                    "rollout_kl": rollout_kl,
-                },
-                step=it,
-            )
-            pbar.set_description(f"TRAIN: {it=}: {rollout_reward=:.4f}")
+        self.it += 1
+        if self.it % self.log_interval == 0:
+            for key, item in self.aggregated_data.items():
+                self.logger.log_scalar(key, item)


 class Evaluator:
@@ -224,6 +220,8 @@ def __call__(self, model, dataloader):

 def resolve_name_or_path(name_or_path):
     """Hydra changes the working directory, so we need to absolutify paths."""
+    if not name_or_path:
+        return None
     if name_or_path.startswith("./") or name_or_path.startswith("/"):
         return to_absolute_path(name_or_path)
     return name_or_path
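A small illustration of the new guard in resolve_name_or_path (assuming it is imported from the example's utils module; the checkpoint path in the comment is hypothetical):

```python
from utils import resolve_name_or_path  # assumes running from examples/rlhf

assert resolve_name_or_path(None) is None  # None no longer hits .startswith and raises AttributeError
assert resolve_name_or_path("") is None    # empty strings now also map to "no path"
assert resolve_name_or_path("gpt2") == "gpt2"  # bare model names still pass through unchanged
# "./my_checkpoint" would still be absolutised via hydra's to_absolute_path
```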
23 changes: 21 additions & 2 deletions torchrl/data/rlhf/utils.py
@@ -312,11 +312,30 @@ def _get_rollout_generated(self, generated, batch):
    def _get_done_status(self, generated, batch):
        # done is True when we either first sample an EOS token or reach the maximum number
        # of generated tokens
        terminated = generated == self.EOS_TOKEN_ID
        done_idx = torch.minimum(
            (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
            torch.tensor(self.max_new_tokens) - 1,
        )
        truncated_idx = torch.tensor(self.max_new_tokens).expand_as(done_idx) - 1
        zeros = torch.zeros(
            done_idx.numel(),
            self.max_new_tokens,
            dtype=torch.bool,
            device=generated.device,
        )
        truncated = zeros.scatter(-1, truncated_idx.unsqueeze(-1), 1).unsqueeze(-1)
        done = zeros.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1)
        terminated = done & ~truncated  # we assume that if it's not truncated, it was terminated
        return truncated | terminated, terminated

        print('batch.prompt_rindex', batch.prompt_rindex)
        print('generated', generated.shape)
        terminated = (generated == self.EOS_TOKEN_ID)[..., -batch.prompt_rindex:]
        terminated = terminated.int().cumsum(-1).bool()
        done = terminated.clone()
        done[..., self.max_new_tokens - 1] = 1
        return done, terminated
        print('self.max_new_tokens', self.max_new_tokens)
        return done.unsqueeze(-1), terminated.unsqueeze(-1)

    def _get_action(self, generated, batch):
        # the sequence of actions for each trajectory is just the generated token ids
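One of the two implementations visible in this hunk builds the termination mask with a cumulative-sum trick; here is a toy illustration of that idea (not the library code; GPT-2's EOS id and a 5-token horizon are assumed, and the real method also offsets by the prompt length and adds a trailing dimension):

```python
import torch

EOS, max_new_tokens = 50256, 5
generated = torch.tensor(
    [[11, 22, EOS, 0, 0],     # EOS sampled at step 2
     [11, 22, 33, 44, 55]]    # never samples EOS
)
terminated = (generated == EOS).int().cumsum(-1).bool()
done = terminated.clone()
done[..., max_new_tokens - 1] = True  # force "done" at the generation horizon (truncation)

print(terminated)
# tensor([[False, False,  True,  True,  True],
#         [False, False, False, False, False]])
print(done)
# tensor([[False, False,  True,  True,  True],
#         [False, False, False, False,  True]])
```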
2 changes: 1 addition & 1 deletion torchrl/modules/models/rlhf.py
@@ -41,7 +41,7 @@ def __init__(self, model_path=None):
         from transformers import GPT2LMHeadModel, GPT2TokenizerFast

         super().__init__()
-        if model_path:
+        if model_path is not None:
             model = GPT2LMHeadModel.from_pretrained(model_path, return_dict=False)
         else:
             model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
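The change above matters because truthiness and "is not None" differ for falsy-but-provided values; a small, purely illustrative check (the helper below is hypothetical, not part of the module):

```python
def pick_branch(model_path=None):
    """Mirror of the dispatch above: pretrained checkpoint vs. fresh GPT-2 config."""
    if model_path is not None:
        return "from_pretrained"
    return "fresh GPT2LMHeadModel"


assert pick_branch(None) == "fresh GPT2LMHeadModel"
assert pick_branch("gpt2") == "from_pretrained"
assert pick_branch("") == "from_pretrained"  # the old truthiness check silently built a fresh model here
```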