accelerate integration #58
Commit message:

- added comments on example
- fixes CI test
- rewards should be a list of tensors
- clearer error messages
- remove build model method
- refactor log stats method

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
```diff
@@ -1,3 +1,17 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch
 import time
 from tqdm import tqdm

@@ -7,12 +21,29 @@
 from transformers import pipeline, AutoTokenizer
 from datasets import load_dataset

-from trl import PPOTrainer
+from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 from trl.trainer import LengthSampler

+########################################################################
+# This is a fully working, simple example of how to use trl with accelerate.
+#
+# This example fine-tunes a GPT2 model on the IMDB dataset using PPO
+# (proximal policy optimization) in any of the following settings
+# (with the same script):
+#   - single CPU or single GPU
+#   - multiple GPUs (using PyTorch distributed mode)
+#   - multiple GPUs (using DeepSpeed ZeRO-Offload stages 1 & 2)
+#   - fp16 (mixed precision) or fp32 (normal precision)
+#
+# To run it in each of these various modes, first initialize the
+# accelerate configuration with `accelerate config`.
+#
+########################################################################

+# We first define the configuration of the experiment: the model, the dataset,
+# the training parameters, and the PPO parameters.
 config = {
     "model_name": "lvwerra/gpt2-imdb",
     # "model_name": "facebook/opt-350m",
     "dataset_name": "imdb",
     "cls_model_name": "lvwerra/distilbert-imdb",
     "steps": 20000,
```
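The banner above says to initialize the accelerate configuration first; as a usage sketch, running the script in any of the listed modes would then look like the following (the script filename here is hypothetical):

```bash
# Answer the interactive prompts once to choose CPU/GPU, distributed mode,
# DeepSpeed, and mixed precision.
accelerate config

# Launch the example with the saved configuration (hypothetical filename).
accelerate launch ppo_sentiment_example.py
```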
```diff
@@ -34,12 +65,17 @@
     "vf_coef": .1,
 }

+# We then define the arguments to pass to the sentiment analysis pipeline.
+# We set `return_all_scores` to True to get the score for each class label.
 sent_kwargs = {
     "return_all_scores": True,
     "function_to_apply": "none",
     "batch_size": config["forward_batch_size"]
 }
```

> **Review comment** (on `sent_kwargs`): Same here re data classes
```diff
+# Below is an example function to build the dataset. In our case, we use the
+# IMDB dataset from the `datasets` library. One should customize this function
+# to train the model on one's own dataset.
 def build_dataset(config):
     """
     Build dataset for training. This builds the dataset from `load_dataset`, one should

@@ -61,7 +97,6 @@ def build_dataset(config):
     ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
     ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

     input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])

     def tokenize(sample):

@@ -77,20 +112,36 @@ def collater(data):
     dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)
     return dataloader

+# We retrieve the dataloader by calling the `build_dataset` function.
 dataloader = build_dataset(config)
-ppo_trainer = PPOTrainer(dataloader, **config)
-dataloader = ppo_trainer.dataloader

+# Now let's build the model, the reference model, and the tokenizer.
+model = AutoModelForCausalLMWithValueHead.from_pretrained(config["model_name"])
+ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config["model_name"])
+tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+tokenizer.pad_token = tokenizer.eos_token

-tokenizer = ppo_trainer.tokenizer
+# We then build the PPOTrainer, passing the model, the reference model, and the tokenizer.
+ppo_trainer = PPOTrainer(model, ref_model, tokenizer, dataloader, **config)

+# The PPOTrainer has a `dataloader` attribute, which we can use to get the dataloader -
+# this step is important in a distributed setting, as the dataloader needs to be
+# converted to a distributed dataloader.
+dataloader = ppo_trainer.dataloader

+# We then build the sentiment analysis pipeline, passing the model name and the
+# sentiment analysis pipeline arguments. Let's also make sure to set the device
+# to the same device as the PPOTrainer.
 device = ppo_trainer.accelerator.device
 if device.index is None:
     # single GPU - maybe introduce this hack inside PPOTrainer?
     device = 0
 sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=device)

+# We then define the arguments to pass to the `generate` function. These arguments
+# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
+# the `generate` function of the trained model.
 gen_kwargs = {
     "min_length": -1,
     "top_k": 0.0,
```

> **Review comment** (on `gen_kwargs`): Might be cleaner to use the new
>
> **Reply:** Agreed!
>
> **Reply:** Yes, let's not commit to very new features, otherwise we need very hard
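The hunk cuts off before the end of the dict, but as a hedged usage sketch, these kwargs would typically flow into generation as below; `query_tensor`, standing in for one tokenized query from the batch, is an assumption and is not shown in this diff:

```python
# Minimal sketch: generate one response for a single tokenized query.
# `query_tensor` is assumed to be a 1-D LongTensor of input token ids;
# the diff's own comment says PPOTrainer exposes a `generate` wrapper.
response_tensor = ppo_trainer.generate(query_tensor, **gen_kwargs)
response_text = tokenizer.decode(response_tensor.squeeze())
```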
```diff
@@ -116,13 +167,13 @@ def collater(data):
     t = time.time()
     texts = [q + r for q, r in zip(batch['query'], batch['response'])]
     pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
-    rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
+    rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs]
     timing['time/get_sentiment_preds'] = time.time() - t

     #### Run PPO step
     t = time.time()
     stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
-    ppo_trainer.log_stats(stats, timing, batch, rewards, t0, logs)
+    ppo_trainer.log_stats(stats, batch, rewards, logs, timing, t0)
     # Log the timing of the whole optimization step.
     timing['time/optimization'] = time.time() - t
```

> **Review comment** (on the `texts` line): with the remove columns method inside the trainer the
>
> **Reply:** The queries are kept here: https://github.com/younesbelkada/trl/blob/d2c363fe4018c74df829ed6c067fad50ecaaf479/trl/trainer/ppo_trainer.py#L152 but maybe we can change that, wdyt?
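For context on the reward change above: with `return_all_scores=True` the pipeline returns, for each input text, a list with one dict per class label, and the PPO step now expects the rewards as a plain Python list of scalar tensors rather than one stacked tensor. A minimal sketch of that mapping, assuming `lvwerra/distilbert-imdb` orders its labels as NEGATIVE then POSITIVE (scores illustrative):

```python
import torch

# Hypothetical pipeline output for two texts; `function_to_apply="none"`
# means the scores are raw logits rather than softmax probabilities.
pipe_outputs = [
    [{"label": "NEGATIVE", "score": -1.2}, {"label": "POSITIVE", "score": 2.3}],
    [{"label": "NEGATIVE", "score": 0.4}, {"label": "POSITIVE", "score": -0.7}],
]

# output[1]["score"] picks the POSITIVE logit as the reward for each sample.
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

assert isinstance(rewards, list)
assert all(r.dim() == 0 for r in rewards)  # one scalar tensor per sample
```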
> **Review comment:** Suggestion: use data classes instead of dicts for the config (easier to refactor in the future), like we do in transformers: https://github.com/huggingface/transformers/blob/bbcd961897aa6cc439ef4cca5cef6db4283c5b76/examples/pytorch/text-classification/run_glue.py#L70
>
> **Reply:** Added a simple dataclass for now: 747d5f0. Maybe we can refactor it as it is done in transformers in a follow-up PR!
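
A minimal sketch of what such a dataclass-based config could look like, using field names from the dict in this example; this is an illustration, not the actual dataclass added in 747d5f0:

```python
from dataclasses import dataclass

@dataclass
class ScriptConfig:
    # Field names mirror the config dict used in the example script.
    model_name: str = "lvwerra/gpt2-imdb"
    dataset_name: str = "imdb"
    cls_model_name: str = "lvwerra/distilbert-imdb"
    steps: int = 20000
    vf_coef: float = 0.1
    # `forward_batch_size` appears in the script but its value is not shown
    # in the hunks above; 16 here is an assumed placeholder.
    forward_batch_size: int = 16

config = ScriptConfig()
print(config.model_name)  # "lvwerra/gpt2-imdb"
```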