fix style, typos, license (huggingface#103)
* fix style, typos, license
* quality
* minor updates, remove Mapping error for 3.10+
* fix style
* update per feedback

Co-authored-by: leandro <leandro.vonwerra@spoud.io>
Authored by Nathan Lambert and leandro on Jan 27, 2023
1 parent 99c6ff2 · commit ef5aaa7
Showing 13 changed files with 304 additions and 101 deletions.
@@ -0,0 +1,155 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tqdm import tqdm

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

import bitsandbytes as bnb

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
########################################################################
# This is a fully working, simple example of how to use trl with accelerate.
#
# This example fine-tunes a GPT2 model on the IMDB dataset using PPO
# (proximal policy optimization) in any of the following settings
# (with the same script):
#   - single CPU or single GPU
#   - multiple GPUs (using PyTorch distributed mode)
#   - multiple GPUs (using DeepSpeed ZeRO-Offload stages 1 & 2)
#   - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in any of these modes, first initialize the accelerate
# configuration with `accelerate config`.
#
########################################################################
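
# For reference, after configuring accelerate you would typically launch this script
# with something like (assuming it is saved as `gpt2-sentiment.py`, a hypothetical name):
#
#   accelerate launch gpt2-sentiment.py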

# We first define the configuration of the experiment, defining the model, the dataset,
# the training parameters, and the PPO parameters.
# Check the default arguments in the `PPOConfig` class for more details.
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-6,
)

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True so the pipeline returns a score for every
# sentiment label rather than only the top one.
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": config.forward_batch_size}

# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. Customize this function to train the model on your own dataset.
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build the dataset for training. This builds the dataset from `load_dataset`; customize
    this function to train the model on your own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        ds (`datasets.Dataset`):
            The tokenized dataset used for training.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds
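
# A single element of the resulting dataset looks roughly like
#   {"review": "...", "input_ids": tensor([...]), "query": "..."}
# where "query" is the short prompt the model will be asked to continue
# (the exact fields depend on the columns present in the source dataset).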

# We retrieve the dataset by calling the `build_dataset` function.
dataset = build_dataset(config)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
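
# For illustration, collator([{"query": "a", "input_ids": x}, {"query": "b", "input_ids": y}])
# returns {"query": ["a", "b"], "input_ids": [x, y]}: it turns a list of dicts into a
# dict of lists without padding or stacking the tensors.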

# Now let's build the model, the reference model, and the tokenizer.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)

# The GPT-2 tokenizer does not define a pad token by default, so we set it to the
# eos token. This is only needed for this model.
tokenizer.pad_token = tokenizer.eos_token

# We then build the PPOTrainer, passing the model, the reference model, the tokenizer,
# the dataset, the data collator, and the optimizer.
ppo_trainer = PPOTrainer(
    config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
)

# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
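
# With `return_all_scores=True` and `function_to_apply="none"`, the pipeline returns,
# for each input text, a list of per-label dicts holding raw logits, roughly like
#   [{"label": "NEGATIVE", "score": -2.3}, {"label": "POSITIVE", "score": 2.3}]
# (the numbers above are purely illustrative).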

# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)
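
# Each call to `output_length_sampler()` draws a random response length between
# output_min_length and output_max_length, so the generated continuations vary in size.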

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs]
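    # Note: `output[1]["score"]` assumes the second entry is the POSITIVE label of
    # `lvwerra/distilbert-imdb`; if you swap in a different reward model, check its
    # label order first.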

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)