
fix DS for peft ref_model in ppo trainer #309

Merged: 1 commit into huggingface:main on Apr 25, 2023
Conversation

@halfrot (Contributor) commented Apr 18, 2023

`self.ref_model` can be `None`: with peft, the reference model is obtained by calling the `disable_adapter` method instead, e.g.

```
with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
    ref_logprobs, _, _, _ = self.batched_forward_pass(self.model, queries, responses, model_inputs)
```
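For context, in the peft case the trainer is built without a separate reference model, so `self.ref_model` really is `None`. A minimal sketch of that construction pattern (the full reproduction script below does exactly this):

```
# Sketch: with a peft-wrapped policy there is no frozen reference copy.
# PPOTrainer is built with ref_model=None, and the reference logprobs come
# from the adapter-disabled base model, as in the snippet above.
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer, dataset=dataset)
```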
@HuggingFaceDocBuilderDev commented Apr 18, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor)

Hi @halfrot,
Thanks a lot for your PR! I think it looks really good. Could you just attach a reproducible script to the PR? Thanks!

@halfrot (Contributor, author) commented Apr 25, 2023

> Could you just attach a reproducible script to the PR? Thanks!

Sure, I just reproduced this error by running `accelerate launch --config_file config.yaml example.py`. Here `example.py` is almost entirely copied from `gpt2-sentiment_peft.py`:

```
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pprint import pprint
import wandb

wandb.login()
tqdm.pandas()


########################################################################
# This is a fully working, simple example of how to use trl with accelerate.
#
# This example fine-tunes a GPT-2 model on the IMDB dataset using PPO
# (proximal policy optimization) in any of the following settings (with the
# same script):
#   - single CPU or single GPU
#   - multiple GPUs (using PyTorch distributed mode)
#   - multiple GPUs (using DeepSpeed ZeRO-Offload stages 1 & 2)
#   - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, first initialize the accelerate
# configuration with `accelerate config`
#
########################################################################

########################################################################
# NOTE: to train with an 8-bit model, a more recent version of
# transformers is required; full dependencies for this example:
# pip install  bitsandbytes datasets accelerate loralib
# pip install  git+https://github.com/huggingface/transformers.git@main
# pip install git+https://github.com/huggingface/peft.git
########################################################################

# We first define the configuration of the experiment, defining the model, the dataset,
# the training parameters, and the PPO parameters.
# Check the default arguments in the `PPOConfig` class for more details.
# If you want to log with tensorboard, add the kwarg
# `accelerator_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The name of the Causal LM model we wish to fine-tune with PPO.
    """

    # NOTE: gpt2 models use Conv1D instead of Linear layers, which are not yet
    # supported in 8-bit mode; models like gpt-neo* are more suitable.
    model_name: Optional[str] = field(default="edbeeching/gpt-neo-125M-imdb", metadata={
        "help": "the model name"
    })
    log_with: Optional[str] = field(default="wandb", metadata={
        "help": "use 'wandb' to log with wandb"
    })
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={
        "help": "the learning rate"
    })
    mini_batch_size: Optional[int] = field(default=16, metadata={
        "help": "the PPO minibatch size"
    })
    batch_size: Optional[int] = field(default=256, metadata={
        "help": "the batch size"
    })
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={
            "help": "the number of gradient accumulation steps"
        }
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

run = wandb.init(
    project="PPO_training",
    # name=script_args.run_name,
    config={
        "model_name": script_args.model_name,
        "lr": script_args.learning_rate,
        "batch_size": script_args.batch_size
    }
)
config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    log_with=script_args.log_with,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
)

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the score for every sentiment label.
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config.mini_batch_size
}


# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# its own dataset.
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({
        "text": "review"
    })
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(config)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])


# set seed before initializing value head for deterministic eval
set_seed(config.seed)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    load_in_8bit=True,
    # torch_dtype=torch.bfloat16,
    peft_config=lora_config,
    layer_norm_names=[],
)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)


# Apply LoRA
# Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (
# LoRA) using `get_peft_model` utility function from `peft`.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: "
        f"{100 * trainable_params / all_param}"
    )


print_trainable_parameters(model)

# The GPT-2 tokenizer does not have a pad token by default, so we set it to
# the eos_token (only for this model).
tokenizer.pad_token = tokenizer.eos_token
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer, dataset=dataset, data_collator=collator,
                         optimizer=DeepSpeedCPUAdam(filter(lambda p: p.requires_grad, model.parameters()),
                                                    lr=config.learning_rate))

# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = model.current_device if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
    "min_length": -1,
    # "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": -1,
}
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# print(len(ppo_trainer.dataloader))
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # cache and gradient checkpointing are not compatible, so we switch them on and off here
    model.gradient_checkpointing_disable()
    model.pretrained_model.config.use_cache = True
    # Get response from Causal LM
    response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False, length_sampler=output_length_sampler, **generation_kwargs
    )
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    model.gradient_checkpointing_enable()
    model.pretrained_model.config.use_cache = False

    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

# model.push_to_hub(f"{script_args.model_name}-ppo-sentiment")
```

Here is `config.yaml`:

```
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config: {}
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp8
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

And here is the error log:

```
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/PPOTrain/test.py:223 in <module>                                                           │
│                                                                                                  │
│   220 # only for this model.                                                                     │
│   221 tokenizer.pad_token = tokenizer.eos_token                                                  │
│   222 # We then build the PPOTrainer, passing the model, the reference model, the tokenizer      │
│ ❱ 223 ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer, dataset=dat   │
│   224 │   │   │   │   │   │    optimizer=DeepSpeedCPUAdam(filter(lambda p: p.requires_grad, mo   │
│   225 │   │   │   │   │   │   │   │   │   │   │   │   │   lr=config.learning_rate))              │
│   226                                                                                            │
│                                                                                                  │
│ /home/trl/trl/trl/trainer/ppo_trainer.py:284 in __init__                                         │
│                                                                                                  │
│    281 │   │   )                                                                                 │
│    282 │   │   if is_deepspeed_used:                                                             │
│    283 │   │   │   # 8 bit models are already set on the correct device                          │
│ ❱  284 │   │   │   if not getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False):  │
│    285 │   │   │   │   # DS integration only allows for single model and as `ref_model` is only  │
│    286 │   │   │   │   # `KL devergence loss`,i.e, in eval model, just have it be on the respec  │
│    287 │   │   │   │   # there is no need to pass it to the `accelerator.prepare` call           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'NoneType' object has no attribute 'pretrained_model'
```
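For reference, a minimal sketch of the kind of guard the fix adds around the failing check (reconstructed from the traceback above, not the verbatim diff):

```
# Sketch of the guarded check in PPOTrainer.__init__; the names come from
# the traceback above, and the exact body of the branch is assumed.
if is_deepspeed_used:
    # With peft there is no separate ref_model at all (it is None), and
    # 8-bit models are already set on the correct device.
    if self.ref_model is not None and not getattr(
        self.ref_model.pretrained_model, "is_loaded_in_8bit", False
    ):
        # DS integration only allows a single model; ref_model is only used
        # for the KL-divergence loss (eval only), so place it on the device
        # directly rather than passing it to accelerator.prepare.
        self.ref_model = self.ref_model.to(self.accelerator.device)
```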

@younesbelkada (Contributor)

Awesome! Can you confirm this fixes your issue, i.e. that after the fix you can train correctly with DS + peft + int8?

@halfrot (Contributor, author) commented Apr 25, 2023

I got quite normal curves running the same script after the fix:
[three screenshots of training curves]

@younesbelkada (Contributor) left a review comment

Thanks a lot for fixing the PEFT + DeepSpeed + int8 issue! Awesome contribution!

@younesbelkada merged commit 23a06c9 into huggingface:main on Apr 25, 2023