[Algorithm] RLHF end-to-end, clean (#1597)
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
1 parent f09b0c8 · commit fe19cf5
Showing 26 changed files with 1,402 additions and 38 deletions.
@@ -0,0 +1,4 @@
*.png
*.bin
*.pt
*.json
@@ -0,0 +1,57 @@
# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

## Training the models
### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

Default configuration can be found in `config/train.yaml`, and any option can
be overridden with command-line arguments, for example to run the training
script with a different batch size:

```sh
python train.py --batch_size=128
```

> **_NOTE:_** Apple Silicon MacBook users should make sure to use `--device=mps`
> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback.

### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out`, or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point
`reward_model.name_or_path` at the relevant timestamped working directory or
copied the checkpoint to `./out_reward`.
You can then train the final model by running

```sh
python train_rlhf.py
```
@@ -0,0 +1,30 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint
  out_dir: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 5000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir
  decay_lr: True # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 5000 # maximum number of iterations
    eta_min: 1.0e-6 # minimum learning rate
sys:
  device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks
  dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
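The training scripts consume this file as a nested, attribute-style configuration (the model helpers below receive `model_cfg` and `sys_cfg` sections and read fields such as `name_or_path` and `device`). As a minimal sketch of inspecting and overriding it outside of Hydra — assuming the file is saved at `config/train.yaml`, the path the README references, and that the actual scripts are driven by hydra-core from the requirements:

```python
# Minimal sketch: load the nested YAML config with OmegaConf (the library Hydra
# builds on) and read a few values. config/train.yaml is the path from the README.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/train.yaml")

print(cfg.model.name_or_path)         # "gpt2"
print(cfg.train.optimizer.lr)         # 1e-05
print(cfg.sys.device, cfg.sys.dtype)  # "cuda" "bfloat16"

# Command-line overrides can be emulated by merging a dotlist, similar to what
# Hydra does with `python train.py --batch_size=128`-style arguments:
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["data.batch_size=128"]))
print(cfg.data.batch_size)            # 128
```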
@@ -0,0 +1,32 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  out_dir: ./out_reward
  init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 20000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each eval
  decay_lr: False # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 20000
    eta_min: 1.0e-6
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
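As the comments state, the `optimizer` and `scheduler` blocks are plain keyword arguments for `torch.optim.AdamW` and `torch.optim.lr_scheduler.CosineAnnealingLR`. A minimal sketch of how they might be unpacked; the file name `config/train_reward.yaml` and the `torch.nn.Linear` parameter source are illustrative assumptions, not the example's actual training script:

```python
# Sketch: build optimizer and LR scheduler directly from the config's kwargs blocks.
# A tiny Linear layer stands in for the real reward model's parameters.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/train_reward.yaml")  # assumed file name
model = torch.nn.Linear(8, 1)                     # placeholder parameter source

optimizer = torch.optim.AdamW(
    model.parameters(), **OmegaConf.to_container(cfg.train.optimizer)
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, **OmegaConf.to_container(cfg.train.scheduler)
)

for _ in range(10):   # one scheduler step per optimizer step
    optimizer.step()
    scheduler.step()
```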
@@ -0,0 +1,39 @@
io:
  eval_interval: 6
  log_interval: 1
  eval_iters: 10
  logger: wandb
data:
  batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
  num_workers: 1
model:
  name_or_path: ./out
  out_dir: ./out_rlhf
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  name_or_path: ./out_reward
train:
  grad_clip: 1.0
  max_epochs: 1000 # total number of training iterations
  always_save_checkpoint: True # if True, always save a checkpoint after each eval
  decay_lr: True
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 5.0e-5
    weight_decay: 0.0 # 01
    betas: [0.9, 0.999]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size
    eta_min: 5.0e-6
  ppo:
    episode_length: 50
    ppo_batch_size: 16
    ppo_num_epochs: 3
    num_rollouts_per_epoch: 32
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  ref_device: cuda:1 # device of reference model
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: False # use PyTorch 2.0 to compile the model to be faster
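The `ppo` block shapes the RLHF loop: each of `max_epochs` iterations collects `num_rollouts_per_epoch` generations of up to `episode_length` new tokens, then runs `ppo_num_epochs` optimization passes over that batch in minibatches of `ppo_batch_size`. The sketch below shows only that control flow; `collect_rollouts` and `ppo_update` are hypothetical placeholders, not functions from this example:

```python
# Schematic PPO control flow implied by the `ppo` block. collect_rollouts and
# ppo_update are hypothetical stand-ins for rollout collection and the clipped
# PPO loss/optimizer step.
import torch

def collect_rollouts(num_rollouts, episode_length):
    # Placeholder: would generate `num_rollouts` continuations of up to
    # `episode_length` tokens and score them with the reward model.
    return [torch.zeros(episode_length) for _ in range(num_rollouts)]

def ppo_update(minibatch):
    # Placeholder: would compute the clipped PPO objective and step the optimizer.
    pass

max_epochs, num_rollouts_per_epoch = 1000, 32
ppo_num_epochs, ppo_batch_size, episode_length = 3, 16, 50

for epoch in range(max_epochs):
    rollouts = collect_rollouts(num_rollouts_per_epoch, episode_length)
    for _ in range(ppo_num_epochs):  # reuse each batch of rollouts several times
        for start in range(0, len(rollouts), ppo_batch_size):
            ppo_update(rollouts[start:start + ppo_batch_size])
```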
@@ -0,0 +1,3 @@
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

__all__ = ["get_prompt_dataloader_tldr"]
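This re-exports TorchRL's TL;DR prompt dataloader. A usage sketch follows, with the caveat that the exact signature should be checked against the installed torchrl version; the argument names shown (batch size, block size, device, split) are assumptions mirroring the `data` sections of the configs, and the `input_ids` field is likewise assumed:

```python
# Sketch (argument names assumed; consult torchrl's docs for the installed
# version): iterate over tokenized TL;DR prompts.
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

dataloader = get_prompt_dataloader_tldr(
    batch_size=16,   # data.batch_size in config/train.yaml
    block_size=550,  # data.block_size in config/train.yaml
    device="cpu",
    split="train",
)
batch = next(iter(dataloader))   # a TensorDict/tensorclass-like batch of prompts
print(batch.input_ids.shape)     # assumed field: (batch_size, block_size)
```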
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
from torchrl.modules.tensordict_module.common import VmapModule

from .transformer import init_transformer

__all__ = ["init_actor_critic"]


def init_actor_critic(model_cfg, sys_cfg):

    transformer_name_or_path = model_cfg.name_or_path
    dropout = model_cfg.dropout

    device = sys_cfg.device
    compile_model = sys_cfg.compile
    base_model = init_transformer(
        transformer_name_or_path,
        dropout,
        device,
        as_tensordictmodule=False,
        compile_model=compile_model,
        inference=True,
    )
    model = LMHeadActorValueOperator(base_model)
    model.to(device)
    model.eval()
    actor = model.get_policy_operator()
    critic = model.get_value_operator()
    critic_head = model.get_value_head()

    return actor, VmapModule(critic), critic_head, base_model
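`init_actor_critic` only reads attributes from the two config sections, so any attribute-style object works in place of the Hydra config. A minimal sketch of wiring it up on CPU without compilation; the import path `models.actor_critic` is assumed for illustration only:

```python
# Sketch: call init_actor_critic with attribute-style stand-ins for the
# `model` and `sys` config sections. The module path below is assumed.
from types import SimpleNamespace

from models.actor_critic import init_actor_critic  # assumed import path

model_cfg = SimpleNamespace(name_or_path="gpt2", dropout=0.1)
sys_cfg = SimpleNamespace(device="cpu", compile=False)

actor, critic, critic_head, base_model = init_actor_critic(model_cfg, sys_cfg)

# actor: policy operator built on the shared GPT-2 backbone
# critic: value operator wrapped in VmapModule; critic_head: the raw value head
# base_model: the underlying GPT2LMHeadModel shared by actor and critic
```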
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings

import torch
from tensordict.nn import TensorDictModule

from torchrl.modules.models.rlhf import GPT2RewardModel


def init_reward_model(
    transformer_path=None, reward_model_path=None, device=None, compile_model=False
):
    if transformer_path is None and reward_model_path is None:
        warnings.warn(
            "You did not provide a path to the reward model, a naive reward model will be used instead."
        )
        model = GPT2RewardModel()
    else:
        if not ((transformer_path is None) ^ (reward_model_path is None)):
            raise ValueError(
                "Exactly one of transformer_path or reward_model_path should be specified."
            )
        if transformer_path is not None:
            model = GPT2RewardModel(transformer_path)
        else:
            model = GPT2RewardModel.from_pretrained(reward_model_path)

    model.to(device)
    if compile_model:
        print("Compiling the reward model...")
        model = torch.compile(model)

    model = TensorDictModule(
        model,
        in_keys=["input_ids", "attention_mask"],
        out_keys=["rewards", "end_scores"],
    )
    return model
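Because the returned module is a `TensorDictModule` that reads `input_ids`/`attention_mask` and writes `rewards`/`end_scores`, it can be exercised with a hand-built TensorDict. A minimal sketch using the naive (untrained) reward-model path; the import path `models.reward` is assumed for illustration:

```python
# Sketch: score a single post + summary with the naive, untrained reward model.
# The import path of init_reward_model is assumed for illustration.
from tensordict import TensorDict
from transformers import GPT2Tokenizer

from models.reward import init_reward_model  # assumed import path

reward_model = init_reward_model(device="cpu")  # warns, builds a naive GPT2RewardModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
enc = tokenizer("Reddit post ... TL;DR: a short summary.", return_tensors="pt")

td = TensorDict(
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
    batch_size=[1],
)
reward_model(td)         # writes "rewards" and "end_scores" into td
print(td["end_scores"])  # reward read at the final token, one value per sequence
```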
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from tensordict.nn import TensorDictModule
from transformers import GPT2LMHeadModel


def init_transformer(
    name_or_path,
    dropout,
    device,
    compile_model,
    as_tensordictmodule=True,
    inference=False,
):
    model_kwargs = {
        "resid_pdrop": dropout,
        "embd_pdrop": dropout,
        "attn_pdrop": dropout,
        "summary_first_dropout": dropout,
    }
    model = GPT2LMHeadModel.from_pretrained(
        name_or_path, return_dict=False, **model_kwargs
    )
    model.to(device)

    if compile_model:
        # TODO: logging instead of printing?
        print("Compiling transformer model...")
        model = torch.compile(model)

    if as_tensordictmodule:
        model = TensorDictModule(
            model,
            in_keys={
                "input_ids": "input_ids",
                "attention_mask": "attention_mask",
                "labels": "labels",
            },
            out_keys=["logits"] if inference else ["loss", "logits"],
        )
    return model
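With `as_tensordictmodule=True` and `inference=False`, the wrapped GPT-2 reads `input_ids`, `attention_mask` and `labels` from a TensorDict and writes back `loss` and `logits`, which is enough for a supervised fine-tuning step. A minimal sketch of one such step; the import path `models.transformer` is assumed for illustration, and `lr`, `weight_decay` and the clip value of 1.0 mirror the `train` section of the config:

```python
# Sketch of one supervised fine-tuning step with the TensorDictModule-wrapped GPT-2.
# The import path of init_transformer is assumed for illustration.
import torch
from tensordict import TensorDict
from transformers import GPT2Tokenizer

from models.transformer import init_transformer  # assumed import path

model = init_transformer("gpt2", dropout=0.1, device="cpu", compile_model=False)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-5, weight_decay=1.0e-1)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
enc = tokenizer("Reddit post ... TL;DR: a short summary.", return_tensors="pt")

batch = TensorDict(
    {
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": enc["input_ids"],  # causal-LM target; GPT-2 shifts labels internally
    },
    batch_size=[1],
)

model(batch)                      # writes "loss" and "logits" into the TensorDict
batch["loss"].backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # train.grad_clip
optimizer.step()
optimizer.zero_grad()
```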
@@ -0,0 +1,11 @@
datasets
hydra-core
matplotlib
numpy
PyYAML
requests
tiktoken
tqdm
transformers
git+https://github.com/pytorch/rl
git+https://github.com/pytorch-labs/tensordict