Commit
New files:
- train_moe_from_scratch/1.ipynb
- train_moe_from_scratch/README.md
- train_moe_from_scratch/dataset.py
- train_moe_from_scratch/moe_sft_train.py
- train_moe_from_scratch/moe_test.py
- train_moe_from_scratch/moe_train.py
- train_moe_from_scratch/screenshot-20241207-093824.png
- train_moe_from_scratch/sft.jsonl
- train_moe_from_scratch/tokenizer/merges.txt
- train_moe_from_scratch/tokenizer/tokenizer.json
- train_moe_from_scratch/tokenizer/tokenizer_config.json
- train_moe_from_scratch/tokenizer/vocab.json
- train_moe_from_scratch/train.jsonl
Showing 13 changed files with 39,843 additions and 0 deletions.
train_moe_from_scratch/README.md
# Usage

## Download the data

Download the dataset from https://github.com/jingyaogong/minimind:

![image](./screenshot-20241207-093824.png)
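The dataset classes in dataset.py (below) each expect one JSON object per line; the field names come from the code, while the values here are illustrative only:

{"text": "one raw pretraining document"}
{"instruction": "...", "input": "...", "output": "...", "history": [["earlier user turn", "earlier assistant reply"]]}

The first form is what train.jsonl feeds LLMDataset; the second is what sft.jsonl feeds SFTDataset ('history' may be an empty list).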
## Training

### Run directly
Pretraining:\
python moe_train.py\
SFT:\
python moe_sft_train.py

### torchrun
Pretraining:\
torchrun --nproc_per_node=2 moe_train.py\
SFT:\
torchrun --nproc_per_node=2 moe_sft_train.py

### deepspeed
Pretraining:\
deepspeed --include 'localhost:0,1' moe_train.py\
SFT:\
deepspeed --include 'localhost:0,1' moe_sft_train.py

## Testing
python moe_test.py
train_moe_from_scratch/dataset.py
import json

import numpy as np
import torch
from torch.utils.data import IterableDataset, Dataset


class LLMDataset(Dataset):
    """Pretraining dataset: one JSON object per line with a 'text' field."""

    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        line = json.loads(self.data[index])
        text = '<s>' + line['text'] + '</s>'
        input_ids = self.tokenizer.encode(text)
        text_len = len(input_ids)
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
        else:
            # Right-pad to max_seq_len with token id 0.
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
        input_ids = np.array(input_ids)
        # Shift by one position for next-token prediction.
        X = np.array(input_ids[:-1]).astype(np.int64)
        Y = np.array(input_ids[1:]).astype(np.int64)
        return {
            'input_ids': torch.from_numpy(X),
            'labels': torch.from_numpy(Y),
        }


class SFTDataset(Dataset):
    """SFT dataset: one JSON object per line with 'instruction', 'input',
    'output' and an optional multi-turn 'history'."""

    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = json.loads(self.data[index])
        instruction_text = line['instruction']
        input_text = line['input']
        output_text = line['output']
        history = line['history']
        query = instruction_text + input_text
        answer = output_text + self.tokenizer.eos_token
        messages = []
        # Replay any previous turns before appending the current query.
        if history:
            for i in history:
                messages.append({'role': 'user', 'content': i[0]})
                messages.append({'role': 'assistant', 'content': i[1]})

        messages.append({'role': 'user', 'content': query})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        prompt_input_ids = self.tokenizer.encode(prompt)
        answer_input_ids = self.tokenizer.encode(answer)
        input_ids = prompt_input_ids + answer_input_ids
        # Mask the prompt positions with 0 so only the answer is trained on.
        labels = [0] * len(prompt_input_ids) + answer_input_ids
        text_len = len(input_ids)
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            labels = labels[:self.max_seq_len]
        else:
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
            labels = labels + [0] * (self.max_seq_len - text_len)

        # Shift by one position for next-token prediction.
        input_ids = input_ids[:-1]
        labels = labels[1:]
        return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)}


# If memory is tight, the data can be streamed with an IterableDataset instead:
# class LLMDataset(IterableDataset):
#     def __init__(self, data_path, tokenizer, max_seq_len):
#         super().__init__()
#         self.data_path = data_path
#         self.tokenizer = tokenizer
#         self.max_seq_len = max_seq_len

#     def __iter__(self):
#         return self.data_process()

#     def data_process(self):
#         with open(self.data_path, 'r', encoding='utf-8') as f:
#             for line in f:
#                 line = json.loads(line)
#                 text = '<s>' + line['text'] + '</s>'
#                 input_ids = self.tokenizer.encode(text)
#                 text_len = len(input_ids)
#                 if text_len > self.max_seq_len:
#                     input_ids = input_ids[:self.max_seq_len]
#                 else:
#                     input_ids = input_ids + [0] * (self.max_seq_len - text_len)
#                 input_ids = np.array(input_ids)
#                 X = np.array(input_ids[:-1]).astype(np.int64)
#                 Y = np.array(input_ids[1:]).astype(np.int64)
#                 yield {
#                     'input_ids': torch.from_numpy(X),
#                     'labels': torch.from_numpy(Y),
#                 }


class DPODataset(Dataset):
    """DPO dataset: a JSON file holding a list of
    {'prompt', 'chosen', 'rejected'} samples."""

    def __init__(self, data_path, tokenizer):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer

        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.datas = json.load(f)

    def __getitem__(self, index):
        sample = self.datas[index]
        prompt = sample['prompt']
        chosen = sample['chosen']
        rejected = sample['rejected']
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        prompt_inputs = self.tokenizer(text=text)['input_ids']
        rejected_inputs = self.tokenizer(text=rejected)['input_ids'] + [self.tokenizer.eos_token_id]
        chosen_inputs = self.tokenizer(text=chosen)['input_ids'] + [self.tokenizer.eos_token_id]
        return [prompt_inputs, chosen_inputs, rejected_inputs]

    def __len__(self):
        return len(self.datas)


class DPODataCollator:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, features):
        inputs_ids = []
        labels = []

        # Pack all chosen sequences first, then all rejected ones, so the two
        # halves of the batch stay aligned for the DPO loss.
        for feature in features:
            inputs_ids.append(feature[0] + feature[1])
            labels.append([0] * len(feature[0]) + feature[1])
        for feature in features:
            inputs_ids.append(feature[0] + feature[2])
            labels.append([0] * len(feature[0]) + feature[2])

        def process(inputs_ids, labels):
            # Truncate to max_seq_len, then right-pad everything to the
            # longest sequence in the batch with token id 0.
            inputs_ids = [input_ids[:self.max_seq_len] for input_ids in inputs_ids]
            labels = [label[:self.max_seq_len] for label in labels]
            max_len = max(len(input_ids) for input_ids in inputs_ids)
            batch_input_ids = []
            batch_labels = []

            for input_ids, label in zip(inputs_ids, labels):
                input_ids = input_ids + [0] * (max_len - len(input_ids))
                label = label + [0] * (max_len - len(label))
                # Shift by one position for next-token prediction.
                batch_input_ids.append(input_ids[:-1])
                batch_labels.append(label[1:])
            return batch_input_ids, batch_labels

        inputs_ids, labels = process(inputs_ids, labels)

        return {
            "input_ids": torch.tensor(inputs_ids),
            "labels": torch.tensor(labels)
        }
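A note on the convention shared by all three dataset paths above: inputs and labels are the same sequence shifted by one position, so the model at position t is trained to predict token t+1, and id 0 doubles as both the padding id and the "ignore" label for prompt tokens. A toy illustration, with made-up token ids:

# Hypothetical ids: <s> ... </s> followed by right-padding with 0.
ids = [1, 42, 7, 99, 2, 0, 0]
X, Y = ids[:-1], ids[1:]
# Each position of X lines up with the next original token in Y.
assert all(Y[t] == ids[t + 1] for t in range(len(X)))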
train_moe_from_scratch/moe_sft_train.py
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          DefaultDataCollator, Trainer, TrainingArguments)

from dataset import SFTDataset
from moe_train import LLM, Config

if __name__ == '__main__':
    # Register the custom MoE architecture so the Auto* classes can load it.
    AutoConfig.register("moe_model", Config)
    AutoModelForCausalLM.register(Config, LLM)
    model = AutoModelForCausalLM.from_pretrained('./saves/moe')
    print(f'Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    data_collator = DefaultDataCollator()
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
    args = TrainingArguments(output_dir='./sft',
                             num_train_epochs=5,
                             do_train=True,
                             per_device_train_batch_size=2,
                             gradient_accumulation_steps=1,
                             # max_steps=15000,
                             logging_steps=1,
                             report_to='tensorboard',
                             save_total_limit=5,
                             bf16=True,
                             learning_rate=2e-4,
                             lr_scheduler_type='cosine',
                             dataloader_num_workers=1,
                             dataloader_pin_memory=True,
                             save_safetensors=False)
    dataset = SFTDataset('./sft.jsonl', tokenizer=tokenizer, max_seq_len=1024)
    trainer = Trainer(model=model, args=args, train_dataset=dataset,
                      tokenizer=tokenizer, data_collator=data_collator)
    # Use resume_from_checkpoint=False for a fresh run; set it to True to
    # continue from the latest checkpoint.
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model('./saves/sft')
    trainer.save_state()
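As a quick sanity check (a minimal sketch, assuming the tokenizer directory and sft.jsonl from this commit sit in the working directory), the prompt masking done by SFTDataset can be inspected before launching a full run:

from transformers import AutoTokenizer
from dataset import SFTDataset

tok = AutoTokenizer.from_pretrained('./tokenizer', use_fast=True)
ds = SFTDataset('./sft.jsonl', tokenizer=tok, max_seq_len=1024)
item = ds[0]
# Prompt and padding positions carry label 0; only answer tokens are trained on.
print(item['input_ids'].shape, item['labels'].shape)
print(int((item['labels'] == 0).sum()), 'masked positions out of', item['labels'].numel())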
train_moe_from_scratch/moe_test.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from moe_train import LLM, Config

t = AutoTokenizer.from_pretrained('./saves/moe')
# Register the custom MoE architecture so from_pretrained can resolve it.
AutoConfig.register("moe_model", Config)
AutoModelForCausalLM.register(Config, LLM)
model = AutoModelForCausalLM.from_pretrained('./saves/moe')

input_data = [t.bos_token_id] + t.encode('1+1等于')
print(input_data)

# Greedy decoding (temperature=0.0, top_k=1) for at most 20 new tokens,
# using the custom generate() defined in moe_train.py.
for token in model.generate({"input_ids": torch.tensor(input_data).unsqueeze(0), "labels": None},
                            t.eos_token_id, 20, stream=False, temperature=0.0, top_k=1):
    print(t.decode(token[0]))