Commit
new file: train_llm_from_scratch/README.md
new file: train_llm_from_scratch/dataset.py
new file: train_llm_from_scratch/screenshot-20241207-093824.png
new file: train_llm_from_scratch/sft_train.py
new file: train_llm_from_scratch/test_llm.ipynb
new file: train_llm_from_scratch/tokenizer/merges.txt
new file: train_llm_from_scratch/tokenizer/tokenizer.json
new file: train_llm_from_scratch/tokenizer/tokenizer_config.json
new file: train_llm_from_scratch/tokenizer/vocab.json
new file: train_llm_from_scratch/train.ipynb
new file: train_llm_from_scratch/train.py
new file: train_llm_from_scratch/train_tokenizer.ipynb
new file: train_llm_from_scratch/trainer_state_pretrain.json
new file: train_llm_from_scratch/trainer_state_sft.json
Showing 14 changed files with 44,310 additions and 0 deletions.
train_llm_from_scratch/README.md
# Usage

## Download the data

https://github.com/jingyaogong/minimind
![image](./screenshot-20241207-093824.png)

## Start training
### Run directly
Pretraining:\
python train.py\
SFT:\
python sft_train.py
### torchrun
Pretraining:\
torchrun --nproc_per_node=2 train.py\
SFT:\
torchrun --nproc_per_node=2 sft_train.py
### deepspeed
Pretraining:\
deepspeed --include 'localhost:0,1' train.py\
SFT:\
deepspeed --include 'localhost:0,1' sft_train.py

## Testing
test_llm.ipynb
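
To try a trained checkpoint outside the notebook, the custom classes must be registered before loading, exactly as sft_train.py and test_llm.ipynb do; a minimal sketch (the save path is an example):

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from train import LLM, Config

# Register the custom architecture so AutoModelForCausalLM can resolve "small_model".
AutoConfig.register("small_model", Config)
AutoModelForCausalLM.register(Config, LLM)

tokenizer = AutoTokenizer.from_pretrained('./tokenizer', use_fast=True)
model = AutoModelForCausalLM.from_pretrained('./saves/sft')  # example path
```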
train_llm_from_scratch/dataset.py
import json

import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset


class LLMDataset(Dataset):
    """Pretraining dataset: each JSONL line holds a 'text' field."""

    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        line = json.loads(self.data[index])
        text = '<s>' + line['text'] + '</s>'
        input_ids = self.tokenizer.encode(text)
        text_len = len(input_ids)
        # Truncate to max_seq_len, or pad with token id 0.
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
        else:
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
        input_ids = np.array(input_ids)
        # Shift by one token: the model predicts the next token at each position.
        X = np.array(input_ids[:-1]).astype(np.int64)
        Y = np.array(input_ids[1:]).astype(np.int64)
        return {
            'input_ids': torch.from_numpy(X),
            'labels': torch.from_numpy(Y),
        }


class SFTDataset(Dataset):
    """SFT dataset: JSONL lines with 'instruction', 'input', 'output', 'history'."""

    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = json.loads(self.data[index])
        instruction_text = line['instruction']
        input_text = line['input']
        output_text = line['output']
        history = line['history']
        query = instruction_text + input_text
        answer = output_text + self.tokenizer.eos_token
        messages = []
        if history:
            for i in history:
                messages.append({'role': 'user', 'content': i[0]})
                messages.append({'role': 'assistant', 'content': i[1]})
        messages.append({'role': 'user', 'content': query})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        prompt_input_ids = self.tokenizer.encode(prompt)
        answer_input_ids = self.tokenizer.encode(answer)
        input_ids = prompt_input_ids + answer_input_ids
        # Mask prompt positions with 0 so only answer tokens contribute to the loss.
        labels = [0] * len(prompt_input_ids) + answer_input_ids
        text_len = len(input_ids)
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            labels = labels[:self.max_seq_len]
        else:
            input_ids = input_ids + [0] * (self.max_seq_len - text_len)
            labels = labels + [0] * (self.max_seq_len - text_len)
        # Next-token shift.
        input_ids = input_ids[:-1]
        labels = labels[1:]
        return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)}


# If the data does not fit in memory, it can be streamed instead:
# class LLMDataset(IterableDataset):
#     def __init__(self, data_path, tokenizer, max_seq_len):
#         super().__init__()
#         self.data_path = data_path
#         self.tokenizer = tokenizer
#         self.max_seq_len = max_seq_len
#
#     def __iter__(self):
#         return self.data_process()
#
#     def data_process(self):
#         with open(self.data_path, 'r', encoding='utf-8') as f:
#             for line in f:
#                 line = json.loads(line)
#                 text = '<s>' + line['text'] + '</s>'
#                 input_ids = self.tokenizer.encode(text)
#                 text_len = len(input_ids)
#                 if text_len > self.max_seq_len:
#                     input_ids = input_ids[:self.max_seq_len]
#                 else:
#                     input_ids = input_ids + [0] * (self.max_seq_len - text_len)
#                 input_ids = np.array(input_ids)
#                 X = np.array(input_ids[:-1]).astype(np.int64)
#                 Y = np.array(input_ids[1:]).astype(np.int64)
#                 yield {
#                     'input_ids': torch.from_numpy(X),
#                     'labels': torch.from_numpy(Y),
#                 }
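
A minimal usage sketch of LLMDataset, assuming a JSONL file with a 'text' field; the data file name and sequence length here are hypothetical examples, not paths from this commit:

```python
from transformers import AutoTokenizer

from dataset import LLMDataset

tokenizer = AutoTokenizer.from_pretrained('./tokenizer', use_fast=True)
ds = LLMDataset('./pretrain_data.jsonl', tokenizer=tokenizer, max_seq_len=512)  # hypothetical path

sample = ds[0]
# input_ids and labels are the same sequence shifted by one position.
assert (sample['input_ids'][1:] == sample['labels'][:-1]).all()
```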
train_llm_from_scratch/screenshot-20241207-093824.png (binary image, not displayed)
train_llm_from_scratch/sft_train.py
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          DefaultDataCollator, Trainer, TrainingArguments)

from dataset import SFTDataset
from train import LLM, Config

if __name__ == '__main__':
    # Register the custom architecture so from_pretrained can resolve it.
    AutoConfig.register("small_model", Config)
    AutoModelForCausalLM.register(Config, LLM)
    model = AutoModelForCausalLM.from_pretrained('./saves/model')
    print(f'Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    data_collator = DefaultDataCollator()
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
    args = TrainingArguments(output_dir='./sft',
                             num_train_epochs=5,
                             do_train=True,
                             per_device_train_batch_size=64,
                             gradient_accumulation_steps=8,
                             # max_steps=15000,
                             logging_steps=100,
                             report_to='tensorboard',
                             save_total_limit=5,
                             bf16=True,
                             learning_rate=2e-4,
                             lr_scheduler_type='cosine',
                             dataloader_num_workers=1,
                             dataloader_pin_memory=True,
                             save_safetensors=False)
    dataset = SFTDataset('./sft_data_zh.jsonl', tokenizer=tokenizer, max_seq_len=1024)
    trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)
    # Set resume_from_checkpoint=False for a first run; set it to True to continue from a checkpoint.
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model('./saves/sft')
    trainer.save_state()
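
Because SFTDataset labels prompt and padding positions with 0, only answer tokens count toward the loss if the model's loss skips index 0. A sketch of such a loss follows; train.py is not shown in this commit page, so whether LLM actually uses ignore_index=0 is an assumption, and masked_next_token_loss is a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def masked_next_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    # Positions labelled 0 (prompt and padding) are excluded from the loss
    # (assumption: the model's loss uses ignore_index=0).
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           labels.reshape(-1),
                           ignore_index=0)
```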
train_llm_from_scratch/test_llm.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n",
    "import torch\n",
    "from train import LLM, Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = AutoTokenizer.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/pretrain')\n",
    "AutoConfig.register(\"small_model\", Config)\n",
    "AutoModelForCausalLM.register(Config, LLM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/pretrain')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 20, 14, 20, 6239]\n"
     ]
    }
   ],
   "source": [
    "input_data = [t.bos_token_id] + t.encode('1+1等于')\n",
    "print(input_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.5,即1.5,即1.5,\n"
     ]
    }
   ],
   "source": [
    "for token in model.generate({\"input_ids\": torch.tensor(input_data).unsqueeze(0), \"labels\": None}, t.eos_token_id, 20, stream=False, temperature=0.0, top_k=8):\n",
    "    print(t.decode(token[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 321, 276, 202, 20, 14, 20, 6239, 3, 202, 2, 1079, 539, 502, 202]\n"
     ]
    }
   ],
   "source": [
    "input_data = t.apply_chat_template([{'role': 'user', 'content': '1+1等于'}])\n",
    "print(input_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1+1等于2。\n"
     ]
    }
   ],
   "source": [
    "for token in model.generate({\"input_ids\": torch.tensor(input_data).unsqueeze(0), \"labels\": None}, t.eos_token_id, 200, stream=False, temperature=0.0, top_k=8):\n",
    "    print(t.decode(token[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wyf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}