
Commit

modified: train_llm_from_scratch/dataset.py
	new file:   train_llm_from_scratch/dpo.png
	new file:   train_llm_from_scratch/dpo_train.py
wyf3 committed Dec 13, 2024
1 parent 9069f8a commit c453cb4
Showing 3 changed files with 207 additions and 1 deletion.
74 changes: 73 additions & 1 deletion train_llm_from_scratch/dataset.py
@@ -122,4 +122,76 @@ def __getitem__(self, index):
# yield {
# 'input_ids': torch.from_numpy(X),
# 'labels': torch.from_numpy(Y),
# }

class DPODataset(Dataset):
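    # Each sample yields three token-id lists: the prompt rendered with the chat
    # template, the chosen response, and the rejected response (each response ends with EOS).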
def __init__(self, data_path, tokenizer):
super().__init__()
self.data_path = data_path
self.tokenizer = tokenizer

with open(self.data_path, 'r', encoding='utf-8') as f:
self.datas = json.load(f)

def __getitem__(self, index):
sample = self.datas[index]
prompt = sample['prompt']
chosen = sample['chosen']
rejected = sample['rejected']
messages = [
{"role": "user", "content": prompt}
]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
prompt_inputs = self.tokenizer(text=text)['input_ids']
rejected_inputs = self.tokenizer(text=rejected)['input_ids'] + [self.tokenizer.eos_token_id]
chosen_inputs = self.tokenizer(text=chosen)['input_ids'] + [self.tokenizer.eos_token_id]
return [prompt_inputs, chosen_inputs, rejected_inputs]

def __len__(self):
return len(self.datas)


class DPODataCollator:
def __init__(self, tokenizer, max_seq_len):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __call__(self, features):
inputs_ids = []
labels = []
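        # Pack chosen pairs first, then rejected pairs: the first half of the batch
        # is prompt+chosen, the second half prompt+rejected. Prompt positions in the
        # labels are filled with 0 so they can be masked out later.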

for feature in features:
inputs_ids.append(feature[0] + feature[1])
labels.append([0]*len(feature[0]) + feature[1])
for feature in features:
inputs_ids.append(feature[0] + feature[2])
labels.append([0]*len(feature[0]) + feature[2])

def process(inputs_ids, labels):
inputs_ids = [input_ids[:self.max_seq_len] for input_ids in inputs_ids]
labels = [label[:self.max_seq_len] for label in labels]
max_len = max([len(input_ids) for input_ids in inputs_ids])
batch_input_ids = []
batch_labels = []

for input_ids, label in zip(inputs_ids, labels):
if len(input_ids) <= max_len:
input_ids = input_ids+[0]*(max_len-len(input_ids))
label = label+[0]*(max_len-len(label))
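                # Shift by one token: inputs drop the last position and labels drop
                # the first, so labels[t] is the next-token target for input_ids[t].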
batch_input_ids.append(input_ids[:-1])
batch_labels.append(label[1:])
return batch_input_ids, batch_labels

inputs_ids, labels = process(inputs_ids, labels)

return {
"input_ids": torch.tensor(inputs_ids),
"labels": torch.tensor(labels)
}
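
A minimal usage sketch for the two classes above; the tokenizer and data paths are placeholders and the batch size is arbitrary:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")          # placeholder path
dataset = DPODataset("path/to/dpo_data.json", tokenizer=tokenizer)      # placeholder path
collator = DPODataCollator(tokenizer, max_seq_len=512)
loader = DataLoader(dataset, batch_size=4, collate_fn=collator, shuffle=True)

batch = next(iter(loader))
# Both tensors have shape (2 * batch_size, L - 1), where L is the longest
# (truncated) sequence in the batch: the first batch_size rows are prompt+chosen,
# the remaining rows prompt+rejected, with prompt/pad label positions set to 0.
print(batch["input_ids"].shape, batch["labels"].shape)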




Binary file added train_llm_from_scratch/dpo.png
134 changes: 134 additions & 0 deletions train_llm_from_scratch/dpo_train.py
@@ -0,0 +1,134 @@
from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import torch.nn.functional as F
from dataset import DPODataset, DPODataCollator
from train import LLM, Config


def logits_to_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
# probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, dim=2)
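    # gather selects, at each position, the log-probability assigned to the label token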
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return probs

def mask_logits(logits, labels):
    # logits here: per-token log-probs from logits_to_probs, shape (batch_size, seq_len)
    # labels shape: (batch_size, seq_len); positions equal to 0 are prompt/padding
    # returns a list of one-element tensors holding each sample's summed response log-prob
new_logits = []
for logit, label in zip(logits, labels):
new_logits.append(logit[label != 0].sum().unsqueeze(0))

return new_logits


def dpo_loss(ref_probs, probs, beta):
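    # probs / ref_probs are lists of per-sample summed log-probs, with the first half
    # of each list coming from chosen responses and the second half from rejected ones.
    # DPO objective:
    #   L = -log sigmoid( beta * [ (log pi(y_w|x) - log pi(y_l|x))
    #                              - (log pi_ref(y_w|x) - log pi_ref(y_l|x)) ] )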
def split_probs(probs):
len_chosen = int(len(probs) // 2)
chosen_data = probs[:len_chosen]
reject_data = probs[len_chosen:]
return torch.cat(chosen_data), torch.cat(reject_data)

ref_chosen_probs, ref_reject_probs = split_probs(ref_probs)
chosen_probs, reject_probs = split_probs(probs)
pi_logratios = chosen_probs - reject_probs
ref_logratios = ref_chosen_probs - ref_reject_probs
logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(beta*logits)
return loss.mean()



class DPOTrainer(Trainer):

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
input_ids = inputs['input_ids']
labels = inputs['labels']
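        # ref_model is the frozen SFT model created in __main__; its log-probs are
        # computed without gradients and serve as the DPO reference policy.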
with torch.no_grad():
ref_logits = ref_model(input_ids=input_ids, labels = labels).logits
ref_probs = logits_to_probs(ref_logits, labels)
ref_probs = mask_logits(ref_probs, labels)
logits = model(input_ids=input_ids, labels = labels).logits
probs = logits_to_probs(logits, labels)
probs = mask_logits(probs, labels)
loss = dpo_loss(ref_probs, probs, 0.1)
return loss

# def training_step(
# self, model, inputs, num_items_in_batch=None
# ) -> torch.Tensor:
# input_ids = inputs['input_ids']
# labels = inputs['labels']
# with torch.no_grad():
# ref_logits = ref_model(input_ids=input_ids, labels = labels).logits
# ref_probs = logits_to_probs(ref_logits, labels)
# ref_probs = mask_logits(ref_probs, labels)
    #         # Since the reference model's cumulative log-probabilities never change, compute them
    #         # once and take several optimization steps on the policy model to avoid redundant work.
# for _ in range(1):

# model.train()
# logits = model(input_ids=input_ids, labels = labels).logits
# probs = logits_to_probs(logits, labels)
# probs = mask_logits(probs, labels)

# if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
# self.optimizer.train()

# with self.compute_loss_context_manager():
# # loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
# loss = dpo_loss(ref_probs, probs, 0.2)

# # del inputs
# if (
# self.args.torch_empty_cache_steps is not None
# and self.state.global_step % self.args.torch_empty_cache_steps == 0
# ):

# torch.cuda.empty_cache()

# kwargs = {}

# if self.args.n_gpu > 1:
# loss = loss.mean() # mean() to average on multi-gpu parallel training

# self.accelerator.backward(loss, retain_graph=True, **kwargs)
# # Finally we need to normalize the loss for reporting
# if num_items_in_batch is None:
# return loss.detach() / self.args.gradient_accumulation_steps
# return loss.detach()


if __name__ == "__main__":
AutoConfig.register("small_model", Config)
AutoModelForCausalLM.register(Config, LLM)
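    # Registering the custom config/model classes lets AutoModelForCausalLM load the
    # from-scratch LLM defined in train.py from the SFT checkpoint directory.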
model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft')

    print(f'Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
ref_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft').eval().to('cuda')

tokenizer = AutoTokenizer.from_pretrained("/home/user/wyf/train_model_from_scratch/tokenizer", use_fast=True)
    data_collator = DPODataCollator(tokenizer, max_seq_len=512)  # the loaded model's rotary position embeddings cover at most 1024 positions, so max_seq_len must not exceed that
args = TrainingArguments(output_dir='./dpo-1-epoch',
                             num_train_epochs=1,  # training for too many epochs seems to make the model produce a lot of repetitive output
do_train=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
# max_steps=15000,
logging_steps=50,
report_to='tensorboard',
save_total_limit=3,
bf16=True,
                             learning_rate=0.00001,  # the learning rate matters a lot; if it is too large, training diverges
lr_scheduler_type='cosine',
dataloader_num_workers=1,
dataloader_pin_memory=True,
save_safetensors=False,
save_steps=100)
dataset = DPODataset('/home/user/wyf/train_model_from_scratch/dataset/dpo_data_512.json', tokenizer=tokenizer)
trainer = DPOTrainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)

    # For a first run, set resume_from_checkpoint=False; to continue from a checkpoint, set it to True.
trainer.train(resume_from_checkpoint=False)
trainer.save_model('./saves/dpo-1-epoch')
trainer.save_state()
