
Commit

update code
Joyce94 committed Aug 13, 2023
1 parent a26748e commit 13b3873
Showing 10 changed files with 465 additions and 311 deletions.
397 changes: 218 additions & 179 deletions script/ppo/ppo_trainer_with_peft.py

Large diffs are not rendered by default.

26 changes: 15 additions & 11 deletions script/ppo/run_ppo.sh
@@ -17,42 +17,46 @@ accelerate launch --config_file default_config.yaml run_ppo_with_peft.py \
--sft_model_path ${sft_model_path} \
--reward_lora_path ${reward_lora_path} \
--dataset_dir ${dataset_dir} \
--pretrain_dataset_dir ${pretrain_dataset_dir} \
--extra_dataset_dir ${extra_dataset_dir} \
--per_device_train_batch_size 2 \
--per_device_mini_train_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_accumulation_steps 8 \
--do_train \
--num_train_epochs 1 \
--seed 512 \
--lr_scheduler_type cosine \
--actor_lr 1e-4 \
--critic_lr 1e-4 \
--logging_steps 10 \
--save_steps 10 \
--logging_steps 100 \
--save_steps 100 \
--dataloader_num_workers 16 \
--block_size 256 \
--max_prompt_length 256 \
--max_response_length 512 \
--output_dir ${actor_output_dir} \
--critic_output_dir ${critic_output_dir} \
--actor_lora_rank 8 \
--actor_lora_rank 64 \
--actor_lora_alpha 32 \
--actor_lora_target ${actor_lora_trainable} \
--actor_lora_dropout 0.05 \
--critic_lora_rank 8 \
--critic_lora_rank 64 \
--critic_lora_alpha 32 \
--critic_lora_target ${critic_lora_trainable} \
--critic_lora_dropout 0.05 \
--max_prompt_length 256 \
--max_response_length 256 \
--ppo_epochs 1 \
--gamma 1 \
--lam 0.95 \
--kl_penalty_beta 0.02 \
--use_last_reward \
--reward_score_clip 10 \
--value_clip 0.2 \
--ratio_clip 0.2 \
--actor_loss_weight 1 \
--critic_loss_weight 2 \
--pretrain_loss_weight 0.1 \
--pretrain_warmup_steps 500 \
--critic_loss_weight 1 \
--extra_loss_weight 0.2 \
--extra_warmup_steps_ratio 0.2 \
--entropy_beta 0.0 \
--kl_loss_alpha 0.0 \
--report_to "wandb" \
--torch_dtype float16 \
--fp16
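
A rough sketch, not the repo's code, of how the loss-weight flags above might combine per PPO optimization step. The trainer that actually consumes them (ppo_trainer_with_peft.py) is collapsed above, so the function below is illustrative, and the assumption that the extra loss only kicks in after the warmup fraction of total steps is a guess based on --extra_warmup_steps_ratio.

    def combined_ppo_loss(actor_loss, critic_loss, extra_loss,
                          step, total_steps,
                          actor_loss_weight=1.0, critic_loss_weight=1.0,
                          extra_loss_weight=0.2, extra_warmup_steps_ratio=0.2):
        # Weighted sum of the PPO actor and critic losses; the extra
        # (SFT or pretrain) loss is assumed to be enabled only once
        # training has passed the warmup fraction of total steps.
        loss = actor_loss_weight * actor_loss + critic_loss_weight * critic_loss
        if step >= extra_warmup_steps_ratio * total_steps:
            loss = loss + extra_loss_weight * extra_loss
        return loss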
122 changes: 90 additions & 32 deletions script/ppo/run_ppo_with_peft.py
@@ -12,17 +12,20 @@
from pathlib import Path
from datasets import load_dataset,concatenate_datasets
from itertools import chain
from utils.data_collator import PPODataCollatorWithPadding
from utils.data_collator import PPODataCollatorWithPadding,DataCollatorForSupervisedDataset
from utils.models import PPOEngine
from ppo.ppo_trainer_with_peft import PPOPeftTrainer

from torch.utils.data import DataLoader, RandomSampler
from accelerate import Accelerator
from torch.optim import AdamW

import os

logger = logging.getLogger(__name__)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

logger = logging.getLogger(__name__)
IGNORE_INDEX = -100

def main():

@@ -49,13 +52,17 @@ def process_tokenize(examples):

input_ids = source_ids + [tokenizer.bos_token_id]
labels = target_ids + [tokenizer.bos_token_id]

if len(input_ids) > training_args.max_prompt_length:
input_ids = input_ids[:training_args.max_prompt_length]
labels = labels[:training_args.max_prompt_length]

model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs

logger.info("process rlhf datasets")
with training_args.main_process_first(desc="process rlhf datasets"):
logger.info("process prompt datasets")
with training_args.main_process_first(desc="process prompt datasets"):
if data_args.dataset_dir is not None:
all_datasets = []
path = Path(data_args.dataset_dir)
@@ -68,7 +75,7 @@ def process_tokenize(examples):
cache_dir=data_args.data_cache_dir
)

tokenized_data = raw_dataset.shuffle().map(
tokenized_data = raw_dataset.map(
process_tokenize,
batched=True,
num_proc=training_args.dataloader_num_workers,
@@ -94,39 +101,90 @@ def process_tokenize_for_pt(examples):
return {"input_ids": result, "labels": result.copy()}


if data_args.pretrain_dataset_dir is not None:
logger.info("process pretrain data")
with training_args.main_process_first(desc="process pretrain data"):
pt_datasets = []
path = Path(data_args.pretrain_dataset_dir)
files = [file.name for file in path.glob("*.txt")]
for file in files:
data_path = os.path.join(path, file)
raw_dataset = load_dataset(
"text",
data_files=data_path
)
def process_tokenize_for_sft(examples):
PROMPT_TEMPLATE = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: "
)
model_inputs = {"input_ids": [], "labels": []}
for instruction, input, output in zip(examples['instruction'], examples['input'], examples['output']):
if input is not None and input != "":
instruction = instruction + '\n' + input
source = PROMPT_TEMPLATE.format_map({'instruction':instruction})
source_ids = tokenizer.encode(text=source, add_special_tokens=False)
target_ids = tokenizer.encode(text=output, add_special_tokens=False)

tokenized_data = raw_dataset.shuffle().map(
process_tokenize_for_pt,
batched=True,
num_proc=training_args.dataloader_num_workers,
remove_columns="text"
)
pt_datasets.append(tokenized_data['train'])
if len(pt_datasets) == 1:
pt_datasets = pt_datasets[0]
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]

if len(input_ids) > training_args.max_length:
input_ids = input_ids[:training_args.max_length]
labels = labels[:training_args.max_length]

model_inputs["input_ids"].append(torch.LongTensor(input_ids))
model_inputs["labels"].append(torch.LongTensor(labels))

return model_inputs


if data_args.extra_dataset_dir is not None:
logger.info("process extra data")
with training_args.main_process_first(desc="process extra data"):
extra_datasets = []
path = Path(data_args.extra_dataset_dir)
if training_args.extra_dataset_type == 'sft':
files = [file.name for file in path.glob("*.json")]
for file in files:
data_path = os.path.join(path, file)
raw_dataset = load_dataset(
"json",
data_files=data_path,
)
tokenized_data = raw_dataset.map(
process_tokenize_for_sft,
batched=True,
num_proc=training_args.dataloader_num_workers,
remove_columns=["instruction","input","output"],
)
extra_datasets.append(tokenized_data['train'])

else:
files = [file.name for file in path.glob("*.txt")]
for file in files:
data_path = os.path.join(path, file)
raw_dataset = load_dataset(
"text",
data_files=data_path
)

tokenized_data = raw_dataset.map(
process_tokenize_for_pt,
batched=True,
num_proc=training_args.dataloader_num_workers,
remove_columns="text"
)
extra_datasets.append(tokenized_data['train'])

if len(extra_datasets) == 1:
extra_datasets = extra_datasets[0]
else:
pt_datasets = concatenate_datasets(pt_datasets)
# pt_datasets = pt_datasets.train_test_split(test_size=data_args.split_ratio)
extra_datasets = concatenate_datasets(extra_datasets)


## load model
logger.info("load model")

data_collator = PPODataCollatorWithPadding(tokenizer)
ppo_engine = PPOEngine(model_args, training_args)


data_collator = PPODataCollatorWithPadding(tokenizer)
if data_args.extra_dataset_dir is not None:
if training_args.extra_dataset_type == 'sft':
extra_data_collator = DataCollatorForSupervisedDataset(tokenizer)
else:
extra_data_collator = default_data_collator


logger.info("training")

trainer = PPOPeftTrainer(
@@ -136,8 +194,8 @@ def process_tokenize_for_pt(examples):
train_dataset = all_datasets,
data_collator = data_collator,
tokenizer = tokenizer,
pretrain_train_dataset = pt_datasets if data_args.pretrain_dataset_dir is not None else None,
pretrain_data_collator = default_data_collator if data_args.pretrain_dataset_dir is not None else None,
extra_train_dataset = extra_datasets if data_args.extra_dataset_dir is not None else None,
extra_data_collator = extra_data_collator if data_args.extra_dataset_dir is not None else None,

)

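
DataCollatorForSupervisedDataset is imported from utils/data_collator.py, which is not part of this commit's rendered diff. A minimal sketch of what such a collator typically does, assuming it pads input_ids with the pad token and labels with IGNORE_INDEX (-100) so prompt and padding positions are ignored by the loss; this is not the repo's actual class.

    import torch
    from dataclasses import dataclass

    IGNORE_INDEX = -100

    @dataclass
    class SupervisedCollatorSketch:
        tokenizer: object  # a transformers PreTrainedTokenizer

        def __call__(self, instances):
            # Each instance carries LongTensor "input_ids" and "labels",
            # as produced by process_tokenize_for_sft above.
            input_ids = [ins["input_ids"] for ins in instances]
            labels = [ins["labels"] for ins in instances]
            input_ids = torch.nn.utils.rnn.pad_sequence(
                input_ids, batch_first=True,
                padding_value=self.tokenizer.pad_token_id)
            labels = torch.nn.utils.rnn.pad_sequence(
                labels, batch_first=True, padding_value=IGNORE_INDEX)
            return {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": input_ids.ne(self.tokenizer.pad_token_id),
            }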
25 changes: 11 additions & 14 deletions script/rm/run_rm.sh
@@ -10,35 +10,32 @@ torchrun --nnodes 1 --nproc_per_node 1 run_rm_with_peft.py \
--model_name_or_path ${pretrained_model} \
--dataset_dir ${dataset_dir} \
--split_ratio 0.01 \
--data_cache_dir ${data_cache_dir} \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--dataloader_num_workers 16 \
--gradient_accumulation_steps 8 \
--do_train \
--do_eval \
--seed 512 \
--fp16 \
--num_train_epochs 1 \
--max_length 1024 \
--max_length 512 \
--clm_loss_weight 1.0 \
--use_last_reward \
--learning_rate 1e-5 \
--warmup_ratio 0.05 \
--weight_decay 0.01 \
--logging_strategy steps \
--evaluation_strategy steps \
--logging_steps 10 \
--logging_steps 100 \
--save_strategy steps \
--save_total_limit 1 \
--eval_steps 10 \
--save_steps 10 \
--block_size 512 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 3 \
--output_dir ${output_dir} \
--overwrite_output_dir \
--logging_first_step True \
--lora_rank 8 \
--lora_rank 128 \
--lora_alpha 32 \
--lora_target ${lora_trainable} \
--lora_dropout 0.05 \
--torch_dtype float16 \
--report_to "wandb"

--report_to "wandb"
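
The --use_last_reward flag appears in both the reward-model and PPO scripts. A plausible reading, sketched below under that assumption (this is not the repo's trainer code), is that the scalar reward for a sequence is taken from the last token's score rather than an average over response tokens, combined with the usual pairwise ranking loss on chosen vs. rejected responses.

    import torch.nn.functional as F

    def pairwise_rm_loss(chosen_scores, rejected_scores, use_last_reward=True):
        # chosen_scores / rejected_scores: (batch, seq_len) per-token scores
        # from the reward head for the preferred and rejected responses.
        if use_last_reward:
            chosen, rejected = chosen_scores[:, -1], rejected_scores[:, -1]
        else:
            chosen, rejected = chosen_scores.mean(-1), rejected_scores.mean(-1)
        # Bradley-Terry style objective: rank the chosen response above
        # the rejected one.
        return -F.logsigmoid(chosen - rejected).mean()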
