new file: train_moe_from_scratch/1.ipynb
	new file:   train_moe_from_scratch/README.md
	new file:   train_moe_from_scratch/dataset.py
	new file:   train_moe_from_scratch/moe_sft_train.py
	new file:   train_moe_from_scratch/moe_test.py
	new file:   train_moe_from_scratch/moe_train.py
	new file:   train_moe_from_scratch/screenshot-20241207-093824.png
	new file:   train_moe_from_scratch/sft.jsonl
	new file:   train_moe_from_scratch/tokenizer/merges.txt
	new file:   train_moe_from_scratch/tokenizer/tokenizer.json
	new file:   train_moe_from_scratch/tokenizer/tokenizer_config.json
	new file:   train_moe_from_scratch/tokenizer/vocab.json
	new file:   train_moe_from_scratch/train.jsonl
wyf3 committed Jan 4, 2025
1 parent d286d09 commit 832835d
Showing 13 changed files with 39,843 additions and 0 deletions.
1,876 changes: 1,876 additions & 0 deletions train_moe_from_scratch/1.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions train_moe_from_scratch/README.md
@@ -0,0 +1,26 @@
# Usage

## Download the data

https://github.com/jingyaogong/minimind
![image](./screenshot-20241207-093824.png)
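
The loaders in `dataset.py` expect one JSON object per line: `train.jsonl` carries a `text` field for pretraining, and `sft.jsonl` carries `instruction`, `input`, `output`, and `history` fields for SFT. A minimal sketch of both line formats, with made-up example values:

```json
{"text": "Mixture-of-experts layers route each token to a small subset of experts."}
{"instruction": "Translate into English:", "input": "你好", "output": "Hello", "history": []}
```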

## Start training
### Run directly
Pretraining:\
python moe_train.py\
SFT:\
python moe_sft_train.py
### torchrun
Pretraining:\
torchrun --nproc_per_node=2 moe_train.py\
SFT:\
torchrun --nproc_per_node=2 moe_sft_train.py
### deepspeed
Pretraining:\
deepspeed --include 'localhost:0,1' moe_train.py\
SFT:\
deepspeed --include 'localhost:0,1' moe_sft_train.py

## Test
python moe_test.py
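
`moe_test.py` greedily decodes from the pretrained checkpoint in `./saves/moe`. As a rough sketch (not part of this commit), the SFT checkpoint that `moe_sft_train.py` writes to `./saves/sft` could be probed the same way, wrapping the question with the chat template as `SFTDataset` does; this assumes `LLM.generate` keeps the call signature used in `moe_test.py`, with the third positional argument being the number of new tokens:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
from moe_train import LLM, Config

AutoConfig.register("moe_model", Config)
AutoModelForCausalLM.register(Config, LLM)
model = AutoModelForCausalLM.from_pretrained('./saves/sft')  # SFT checkpoint (assumed path)
t = AutoTokenizer.from_pretrained('./tokenizer')

# Build a chat-formatted prompt, mirroring SFTDataset/DPODataset.
messages = [{'role': 'user', 'content': '1+1等于'}]
prompt = t.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = torch.tensor(t.encode(prompt)).unsqueeze(0)

# Greedy decoding; arguments mirror the moe_test.py call.
for token in model.generate({"input_ids": input_ids, "labels": None},
                            t.eos_token_id, 50, stream=False, temperature=0.0, top_k=1):
    print(t.decode(token[0]))
```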
197 changes: 197 additions & 0 deletions train_moe_from_scratch/dataset.py
@@ -0,0 +1,197 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
import pandas as pd

from torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig


class LLMDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        line = self.data[index]
        line = json.loads(line)
        text = '<s>' + line['text'] + '</s>'
        input_ids = self.tokenizer.encode(text)
        text_len = len(input_ids)
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
        else:
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
        input_ids = np.array(input_ids)
        X = np.array(input_ids[:-1]).astype(np.int64)
        Y = np.array(input_ids[1:]).astype(np.int64)
        return {
            'input_ids': torch.from_numpy(X),
            'labels': torch.from_numpy(Y),
        }

class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        line = json.loads(line)
        instruction_text = line['instruction']
        input_text = line['input']
        output_text = line['output']
        history = line['history']
        query = instruction_text + input_text
        answer = output_text + self.tokenizer.eos_token
        messages = []
        if history:
            for i in history:
                messages.append({'role': 'user', 'content': i[0]})
                messages.append({'role': 'assistant', 'content': i[1]})

        messages.append({'role': 'user', 'content': query})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        prompt_input_ids = self.tokenizer.encode(prompt)
        answer_input_ids = self.tokenizer.encode(answer)
        input_ids = prompt_input_ids + answer_input_ids
        labels = [0] * len(prompt_input_ids) + answer_input_ids
        text_len = len(input_ids)
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            labels = labels[:self.max_seq_len]
        else:
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
            labels = labels + [0] * (self.max_seq_len - text_len)

        input_ids = input_ids[:-1]
        labels = labels[1:]
        return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)}


# If the data does not fit in memory, it can be streamed with the following IterableDataset instead:
# class LLMDataset(IterableDataset):
#     def __init__(self, data_path, tokenizer, max_seq_len):
#         super().__init__()
#         self.data_path = data_path
#         self.tokenizer = tokenizer
#         self.max_seq_len = max_seq_len

#     def __iter__(self):
#         return self.data_process()

#     def data_process(self):
#         with open(self.data_path, 'r', encoding='utf-8') as f:
#             for line in f:
#                 line = json.loads(line)
#                 text = '<s>' + line['text'] + '</s>'
#                 input_ids = self.tokenizer.encode(text)
#                 text_len = len(input_ids)
#                 if text_len > self.max_seq_len:
#                     input_ids = input_ids[:self.max_seq_len]
#                 else:
#                     input_ids = input_ids + [0] * (self.max_seq_len - text_len)
#                 input_ids = np.array(input_ids)
#                 X = np.array(input_ids[:-1]).astype(np.int64)
#                 Y = np.array(input_ids[1:]).astype(np.int64)
#                 yield {
#                     'input_ids': torch.from_numpy(X),
#                     'labels': torch.from_numpy(Y),
#                 }

class DPODataset(Dataset):
    def __init__(self, data_path, tokenizer):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer

        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.datas = json.load(f)

    def __getitem__(self, index):
        sample = self.datas[index]
        prompt = sample['prompt']
        chosen = sample['chosen']
        rejected = sample['rejected']
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        prompt_inputs = self.tokenizer(text=text)['input_ids']
        rejected_inputs = self.tokenizer(text=rejected)['input_ids'] + [self.tokenizer.eos_token_id]
        chosen_inputs = self.tokenizer(text=chosen)['input_ids'] + [self.tokenizer.eos_token_id]
        return [prompt_inputs, chosen_inputs, rejected_inputs]

    def __len__(self):
        return len(self.datas)


class DPODataCollator:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, features):
        inputs_ids = []
        labels = []

        for feature in features:
            inputs_ids.append(feature[0] + feature[1])
            labels.append([0] * len(feature[0]) + feature[1])
        for feature in features:
            inputs_ids.append(feature[0] + feature[2])
            labels.append([0] * len(feature[0]) + feature[2])

        def process(inputs_ids, labels):
            inputs_ids = [input_ids[:self.max_seq_len] for input_ids in inputs_ids]
            labels = [label[:self.max_seq_len] for label in labels]
            max_len = max([len(input_ids) for input_ids in inputs_ids])
            batch_input_ids = []
            batch_labels = []

            for input_ids, label in zip(inputs_ids, labels):
                if len(input_ids) <= max_len:
                    input_ids = input_ids + [0] * (max_len - len(input_ids))
                    label = label + [0] * (max_len - len(label))
                batch_input_ids.append(input_ids[:-1])
                batch_labels.append(label[1:])
            return batch_input_ids, batch_labels

        inputs_ids, labels = process(inputs_ids, labels)

        return {
            "input_ids": torch.tensor(inputs_ids),
            "labels": torch.tensor(labels)
        }




50 changes: 50 additions & 0 deletions train_moe_from_scratch/moe_sft_train.py
@@ -0,0 +1,50 @@
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
import pandas as pd

from torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig
from dataset import SFTDataset, LLMDataset
from moe_train import LLM, Config

if __name__ == '__main__':
    AutoConfig.register("moe_model", Config)
    AutoModelForCausalLM.register(Config, LLM)
    model = AutoModelForCausalLM.from_pretrained('./saves/moe')
    print(f'Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    data_collator = DefaultDataCollator()
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
    args = TrainingArguments(output_dir='./sft',
                             num_train_epochs=5,
                             do_train=True,
                             per_device_train_batch_size=2,
                             gradient_accumulation_steps=1,
                             # max_steps=15000,
                             logging_steps=1,
                             report_to='tensorboard',
                             save_total_limit=5,
                             bf16=True,
                             learning_rate=2e-4,
                             lr_scheduler_type='cosine',
                             dataloader_num_workers=1,
                             dataloader_pin_memory=True,
                             save_safetensors=False)
    dataset = SFTDataset('./sft.jsonl', tokenizer=tokenizer, max_seq_len=1024)
    trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)
    # Set resume_from_checkpoint=False for the first training run; set it to True to resume from an existing checkpoint.
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model('./saves/sft')
    trainer.save_state()
13 changes: 13 additions & 0 deletions train_moe_from_scratch/moe_test.py
@@ -0,0 +1,13 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
from moe_train import LLM, Config
t = AutoTokenizer.from_pretrained('./saves/moe')
AutoConfig.register("moe_model", Config)
AutoModelForCausalLM.register(Config, LLM)
model = AutoModelForCausalLM.from_pretrained('./saves/moe')

input_data = [t.bos_token_id] + t.encode('1+1等于')
print(input_data)

for token in model.generate({"input_ids": torch.tensor(input_data).unsqueeze(0), "labels": None}, t.eos_token_id, 20, stream=False, temperature=0.0, top_k=1):
    print(t.decode(token[0]))