accelerate integration #58

Merged: 48 commits from `accelerate-ppo` into `master` on Dec 30, 2022.
The diff below shows the changes from a single commit (f47b907).
Commits (all authored by younesbelkada):

9c977d0  working v1 (Dec 27, 2022)
1971cea  add `accelerate` on requirements (Dec 27, 2022)
45cad09  add `accelerate` on `setup.py` (Dec 27, 2022)
a0ebdaa  add `datasets` on `setup.py` (Dec 27, 2022)
dec21f3  small updates (Dec 27, 2022)
4254292  rm unneeded file (Dec 27, 2022)
19f4d92  replace with `generate` (Dec 27, 2022)
35330a9  Update trl/trainer/accelerate_ppo.py (Dec 27, 2022)
34773de  correct return (Dec 27, 2022)
b810d8a  add dataloader support (Dec 27, 2022)
e4c57b2  add `wandb` to `setup.py` (Dec 27, 2022)
7516b37  refactor (Dec 27, 2022)
40f81e0  test (Dec 27, 2022)
b1638e5  fix test (Dec 27, 2022)
e2e7a90  rename file (Dec 27, 2022)
96b4115  refactor (Dec 27, 2022)
5eb46ad  remove unneeded device assignment (Dec 27, 2022)
609f718  fix correct device assignment (Dec 27, 2022)
4d57b47  standardize docstrings (Dec 27, 2022)
fac85b5  add `wandb` on `dev` (Dec 27, 2022)
c1b166b  fix slow convergence (Dec 28, 2022)
9495f2a  oops (Dec 28, 2022)
c813857  revert fix (Dec 28, 2022)
157eca6  revert patch (Dec 28, 2022)
2efb961  Merge remote-tracking branch 'origin/master' into accelerate-ppo (Dec 28, 2022)
0a1c9a2  remove unneeded reshape (Dec 28, 2022)
b6004f0  add input safety checker (Dec 28, 2022)
f47b907  refactor (Dec 28, 2022)
2918a8e  Apply suggestions from code review (Dec 29, 2022)
747d5f0  refactor (Dec 29, 2022)
7615994  some refactor (Dec 29, 2022)
65be5bd  remove unneeded hack (Dec 29, 2022)
edd5ea3  adapt dataset (Dec 29, 2022)
76c2afd  fix test (Dec 29, 2022)
5d41170  remove rollout (Dec 29, 2022)
7843a34  remove timing (Dec 29, 2022)
6cd89d5  remove `shuffle=True` (Dec 29, 2022)
4e802e8  remove `LengthSampler` from trainer (Dec 29, 2022)
6012a9b  refactor (Dec 29, 2022)
d2c363f  remove text length sampler args from config (Dec 29, 2022)
d048bbe  change collate_fn (Dec 29, 2022)
66f23b1  fix silent bug (Dec 29, 2022)
e318307  rename (Dec 29, 2022)
31d12d6  move file (Dec 29, 2022)
48c1070  refactor base trainer (Dec 29, 2022)
e9cec71  fix collate (Dec 29, 2022)
9a987d4  Merge remote-tracking branch 'origin/master' into accelerate-ppo (Dec 29, 2022)
244f001  final bug (Dec 29, 2022)
Commit f47b907f6a828433d7aeee516b53cd1ad09d66a1 ("refactor"), committed by younesbelkada and lvwerra on Dec 28, 2022
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

- added comments on example
- fixes CI test
- rewards should be a list of tensors
- clearer error messages
- remove build model method
- refactor log stats method
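(A quick illustration of the "rewards should be a list of tensors" bullet; the values below are made up, and only the shape matters: one scalar tensor per query/response pair.)

```python
import torch

# One scalar reward tensor per (query, response) pair in the batch.
rewards = [torch.tensor(1.0), torch.tensor(-0.5)]
```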
examples/scripts/04-ppo-sentiment-accelerate.py (60 additions, 9 deletions)
@@ -1,3 +1,17 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import torch
import time
from tqdm import tqdm
@@ -7,12 +21,29 @@
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

-from trl import PPOTrainer
+from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.trainer import LengthSampler

+########################################################################
+# This is a fully working simple example to use trl with accelerate.
+#
+# This example fine-tunes a GPT2 model on the IMDB dataset using PPO
+# (proximal policy optimization), and it can run in any of the following
+# settings (with the same script):
+# - single CPU or single GPU
+# - multi GPU (using PyTorch distributed mode)
+# - multi GPU (using DeepSpeed ZeRO-Offload stages 1 & 2)
+# - fp16 (mixed-precision) or fp32 (normal precision)
+#
+# To run it in each of these various modes, first initialize the accelerate
+# configuration with `accelerate config`.
+#
+########################################################################
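(For reference, the two-step launch implied by the banner above; `accelerate config` is interactive and determines the mode, and the script path is taken from this diff's file header:)

```bash
# One-time interactive setup: choose single/multi GPU, DeepSpeed, fp16, ...
accelerate config
# Then launch the example; accelerate reads the saved configuration.
accelerate launch examples/scripts/04-ppo-sentiment-accelerate.py
```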

+# We first define the configuration of the experiment, defining the model, the dataset,
+# the training parameters, and the PPO parameters.
config = {
[Review comment, Member]
Suggestion: use data classes instead of dicts for the config (easier to refactor in the future), like we do in transformers: https://github.com/huggingface/transformers/blob/bbcd961897aa6cc439ef4cca5cef6db4283c5b76/examples/pytorch/text-classification/run_glue.py#L70

[Reply, Contributor (author)]
Added a simple dataclass for now: 747d5f0. Maybe we can refactor it the way it's done in transformers in a follow-up PR!
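(A minimal sketch of what the suggested dataclass-based config could look like; the class name is hypothetical and only a few of the dict entries below are mirrored:)

```python
from dataclasses import dataclass

@dataclass
class ScriptConfig:
    # Hypothetical dataclass mirroring part of the config dict below.
    model_name: str = "lvwerra/gpt2-imdb"
    dataset_name: str = "imdb"
    cls_model_name: str = "lvwerra/distilbert-imdb"
    steps: int = 20000

config = ScriptConfig()
print(config.model_name)  # attribute access instead of config["model_name"]
```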
"model_name": "lvwerra/gpt2-imdb",
# "model_name": "facebook/opt-350m",
"dataset_name": "imdb",
"cls_model_name": "lvwerra/distilbert-imdb",
"steps": 20000,
@@ -34,12 +65,17 @@
"vf_coef":.1,
}

+# We then define the arguments to pass to the sentiment analysis pipeline.
+# We set `return_all_scores` to True to get the score for every class
+# (positive and negative), not just the predicted one.
sent_kwargs = {
[Review comment, Member]
Same here re data classes.
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": config["forward_batch_size"]
}

+# Below is an example function to build the dataset. In our case, we use the IMDB dataset
+# from the `datasets` library. One should customize this function to train the model on
+# one's own dataset.
def build_dataset(config):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
@@ -61,7 +97,6 @@ def build_dataset(config):
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)


input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])

def tokenize(sample):
@@ -77,20 +112,36 @@ def collater(data):
dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)
return dataloader

+# We retrieve the dataloader by calling the `build_dataset` function.
dataloader = build_dataset(config)
-ppo_trainer = PPOTrainer(dataloader, **config)

-dataloader = ppo_trainer.dataloader
+# Now let's build the model, the reference model, and the tokenizer.
+model = AutoModelForCausalLMWithValueHead.from_pretrained(config["model_name"])
+ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config["model_name"])
+tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+
+tokenizer.pad_token = tokenizer.eos_token

-tokenizer = ppo_trainer.tokenizer
+# We then build the PPOTrainer, passing the model, the reference model, and the tokenizer.
+ppo_trainer = PPOTrainer(model, ref_model, tokenizer, dataloader, **config)

+# The PPOTrainer has a dataloader attribute, which we can use to get the dataloader -
+# this step is important in a distributed setting, as the dataloader needs to be
+# converted to a distributed dataloader.
+dataloader = ppo_trainer.dataloader
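(An aside on why the re-fetch above matters: presumably the trainer wraps the dataloader with accelerate internally, roughly along these lines. This is an assumption about TRL's internals, not code from this PR:)

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
dataloader = torch.utils.data.DataLoader(list(range(8)), batch_size=2)
# prepare() wraps the dataloader so that, in distributed mode, each process
# iterates over its own shard of the data.
dataloader = accelerator.prepare(dataloader)
```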

+# We then build the sentiment analysis pipeline, passing the model name and the
+# sentiment analysis pipeline arguments. Let's also make sure to set the device
+# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if device.index is None:
# single GPU - maybe introduce this hack inside PPOTrainer?
device = 0
sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=device)


+# We then define the arguments to pass to the `generate` function. These arguments
+# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
+# the `generate` function of the trained model.
gen_kwargs = {
[Review comment, Member: comment body not shown in this view]

[Reply, Contributor (author), younesbelkada, Dec 29, 2022]
Agreed! However, I tried to run generate with GenerationConfig and didn't manage to make it work, since the feature currently seems to be available only on the main branch. The PR huggingface/transformers#20388 was merged 2 weeks ago, so maybe let's address this in a follow-up PR!

[Reply, Member]
Yes, let's not commit to very new features, otherwise we need a very hard transformers dependency.
"min_length":-1,
"top_k": 0.0,
@@ -116,13 +167,13 @@ def collater(data):
t = time.time()
texts = [q + r for q,r in zip(batch['query'], batch['response'])]
[Review comment, Member]
With the remove-columns method inside the trainer, shouldn't the query no longer be there? Since we don't pass the data through the model internally, we don't need to remove the columns?

[Reply, Contributor (author): reply body not shown in this view]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
-    rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
+    rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs]
timing['time/get_sentiment_preds'] = time.time()-t

#### Run PPO step
t = time.time()
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
-    ppo_trainer.log_stats(stats, timing, batch, rewards, t0, logs)
+    ppo_trainer.log_stats(stats, batch, rewards, logs, timing, t0)
# Log the timing of the whole optimization step.
timing['time/optimization'] = time.time()-t

tests/test_gpt2_model.py (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ def test_gpt2_model():
for query_tensor, response_tensor in dummy_dataloader:
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
-        reward = [torch.Tensor([1.0]* 2)]
+        reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
break
trl/models/__init__.py (6 additions, 1 deletion)
@@ -11,4 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .modeling_vhead import AutoModelForCausalLMWithValueHead
\ No newline at end of file
+from .modeling_vhead import AutoModelForCausalLMWithValueHead
+from .modeling_base import PreTrainedModelWrapper
+
+SUPPORTED_ARCHITECTURES = (
+    AutoModelForCausalLMWithValueHead,
+)
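(Presumably this tuple backs the "clearer error messages" bullet from the commit message. A plausible, hypothetical validation inside the trainer, not shown in this diff:)

```python
# isinstance accepts a tuple of classes, so the check stays in sync with
# SUPPORTED_ARCHITECTURES as new architectures are added.
if not isinstance(model, SUPPORTED_ARCHITECTURES):
    raise ValueError(
        f"model must be an instance of one of {SUPPORTED_ARCHITECTURES}, got {type(model)}"
    )
```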