new file: train_multimodal_from_scratch/README.md
new file: train_multimodal_from_scratch/gradio_vlm.py
new file: train_multimodal_from_scratch/sft_train.py
new file: train_multimodal_from_scratch/test.ipynb
new file: train_multimodal_from_scratch/test.py
new file: train_multimodal_from_scratch/train.py
new file: train_multimodal_from_scratch/trainer.ipynb
Showing 7 changed files with 1,857 additions and 0 deletions.
train_multimodal_from_scratch/README.md
@@ -0,0 +1,44 @@
# Usage

## Download the models and data
### Download qwen2.5-0.5b and siglip
qwen2.5-0.5b:\
https://hf-mirror.com/Qwen/Qwen2.5-0.5B-Instruct\
siglip:\
This project uses the following version of siglip (a small model: quality may be somewhat weaker, but training is faster and the GPU memory requirement is lower):\
https://hf-mirror.com/google/siglip-base-patch16-384

A stronger version can also be used, at the cost of a larger model. Note that with this version you may need to change the image_pad_num parameter: the model outputs image features of shape (b, 729, dim), which are reshaped to (b, 729/9, dim*9) during image compression:\
https://hf-mirror.com/google/siglip-so400m-patch14-384
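
For reference, a minimal sketch of that compression step (the hidden dim of 1152 for siglip-so400m is an assumption; read the actual value from the vision config):

```python
import torch

# Assumed feature shape from siglip-so400m-patch14-384: (batch, 729, dim).
b, n, dim = 2, 729, 1152  # dim = 1152 is an assumption; check the model config
features = torch.randn(b, n, dim)

# Fold 9 neighbouring patch features into the channel dimension:
# (b, 729, dim) -> (b, 729/9, dim*9), so image_pad_num would be 81 here.
compressed = features.reshape(b, n // 9, dim * 9)
print(compressed.shape)  # torch.Size([2, 81, 10368])
```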

### Download the datasets
1. Pretraining data:\
Image data:\
https://hf-mirror.com/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K\
Chinese text data:\
https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions
2. SFT data:\
Image data:\
https://hf-mirror.com/datasets/jingyaogong/minimind-v_dataset\
Chinese text data:\
https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions

## Start training
### Run directly
Pretraining:\
python train.py\
SFT:\
python sft_train.py
### torchrun
Pretraining:\
torchrun --nproc_per_node=2 train.py\
SFT:\
torchrun --nproc_per_node=2 sft_train.py
### deepspeed
Pretraining:\
deepspeed --include 'localhost:0,1' train.py\
SFT:\
deepspeed --include 'localhost:0,1' sft_train.py

## Testing
python test.py
train_multimodal_from_scratch/gradio_vlm.py
@@ -0,0 +1,77 @@
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig
from PIL import Image
from train import VLMConfig, VLM
import torch
from torch.nn import functional as F

device = "cuda:1"
processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')

# Register the custom VLM so AutoModelForCausalLM.from_pretrained can load it.
AutoConfig.register("vlm_model", VLMConfig)
AutoModelForCausalLM.register(VLMConfig, VLM)

pretrain_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain')
pretrain_model.to(device)

sft_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft')
sft_model.to(device)

pretrain_model.eval()
sft_model.eval()

def generate(mode, image_input, text_input, max_new_tokens=100, temperature=0.0, top_k=None):
    # Build the chat prompt and expand the <image> placeholder to 49 image-pad tokens.
    q_text = tokenizer.apply_chat_template([{"role": "system", "content": 'You are a helpful assistant.'},
                                            {"role": "user", "content": f'{text_input}\n<image>'}],
                                           tokenize=False,
                                           add_generation_prompt=True).replace('<image>', '<|image_pad|>' * 49)
    input_ids = tokenizer(q_text, return_tensors='pt')['input_ids']
    input_ids = input_ids.to(device)
    pixel_values = processor(text=None, images=image_input).pixel_values
    pixel_values = pixel_values.to(device)
    eos = tokenizer.eos_token_id
    s = input_ids.shape[1]
    model = pretrain_model if mode == 'pretrain' else sft_model
    # Token-by-token decoding loop.
    while input_ids.shape[1] < s + max_new_tokens - 1:
        inference_res = model(input_ids, None, pixel_values)
        logits = inference_res.logits
        logits = logits[:, -1, :]

        # Repetition penalty on already-generated tokens (a divisor of 1.0 disables it).
        for token in set(input_ids.tolist()[0]):
            logits[:, token] /= 1.0

        if temperature == 0.0:
            # Greedy decoding.
            _, idx_next = torch.topk(logits, k=1, dim=-1)
        else:
            # Temperature scaling with optional top-k sampling.
            logits = logits / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1, generator=None)

        if idx_next == eos:
            break

        input_ids = torch.cat((input_ids, idx_next), dim=1)
    return tokenizer.decode(input_ids[:, s:][0])

with gr.Blocks() as demo:
    with gr.Row():
        # Image upload
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Select an image")
        with gr.Column(scale=1):
            mode = gr.Radio(["pretrain", "sft"], label="Select a model")
            text_input = gr.Textbox(label="Input text")
            text_output = gr.Textbox(label="Output text")
            generate_button = gr.Button("Generate")
            generate_button.click(generate, inputs=[mode, image_input, text_input], outputs=text_output)

if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=7891)
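
# A quick smoke test of generate() outside the UI might look like this
# (the image path and prompt are illustrative assumptions):
#   img = Image.open('example.jpg').convert('RGB')
#   print(generate('sft', img, 'Describe this image',
#                  max_new_tokens=50, temperature=0.8, top_k=40))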
train_multimodal_from_scratch/sft_train.py
@@ -0,0 +1,161 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig
from transformers import Trainer, TrainingArguments
from PIL import Image
import torch
import os
import json
from torch.utils.data import Dataset
from typing import List, Dict, Any
from train import VLMConfig, VLM


def find_assistant_tokens(tokenizer, target):
    # Scan the token ids for each 'assistant' marker and return the
    # (start, end) span of every assistant reply, up to and including
    # the closing <|im_end|> token. Only these spans keep real labels.
    result = []
    start_index = 0
    end_index = 0
    while start_index <= len(target) - 1:
        if target[start_index] != tokenizer('assistant')['input_ids'][0]:
            start_index += 1
            end_index += 1
        else:
            end_index += 1
            if target[end_index] == tokenizer('<|im_end|>')['input_ids'][0]:
                result.append((start_index + 1, end_index + 1))
                start_index = end_index + 1
    return result
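
# For intuition, a toy walk-through with a stand-in tokenizer (token ids are
# made up, not real Qwen2.5 ids; illustration only, not part of training):
#
#   class _FakeTok:
#       def __call__(self, text):
#           return {'input_ids': [{'assistant': 77, '<|im_end|>': 99}[text]]}
#
#   find_assistant_tokens(_FakeTok(), [10, 77, 21, 22, 23, 99, 10, 77, 31, 32, 99])
#   -> [(2, 6), (8, 11)]   # each span covers one reply plus its <|im_end|>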

class SFTDataset(Dataset):
    def __init__(self, images_path, data_path, tokenizer, processor, config):
        super().__init__()
        self.data_path = data_path
        self.images_path = images_path
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.datas = json.load(f)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index):
        sample = self.datas[index]
        try:
            image_name = 'COCO_train2014_' + str(sample['image'])
            conversations = sample['conversations']
            messages = [{"role": "system", "content": 'You are a helpful assistant.'}]
            for conversation in conversations:
                if conversation['from'] == 'human':
                    messages.append({"role": "user", "content": conversation['value']})
                else:
                    messages.append({"role": "assistant", "content": conversation['value']})
            text = self.tokenizer.apply_chat_template(messages,
                                                      tokenize=False,
                                                      ).replace('<image>', '<|image_pad|>' * self.config.image_pad_num)
            input_ids = self.tokenizer(text)['input_ids']
            indexs = find_assistant_tokens(self.tokenizer, input_ids)
            # Mask everything except the assistant replies with pad_token_id,
            # then shift by one position for next-token prediction.
            labels = len(input_ids) * [self.tokenizer.pad_token_id]
            for index in indexs:
                labels[index[0]:index[1]] = input_ids[index[0]:index[1]]
            input_ids = input_ids[:-1]
            labels = labels[1:]

            image = Image.open(os.path.join(self.images_path, image_name)).convert('RGB')
            pixel_values = self.processor(text=None, images=image)['pixel_values']
        except Exception:
            # Fall back to a blank image with a fixed Chinese Q/A pair
            # ("What is in this image?" / "The image is empty.") when a sample is broken.
            default_image = Image.new('RGB', (224, 224), color='white')
            pixel_values = self.processor(text=None, images=default_image)['pixel_values']
            q_text = self.tokenizer.apply_chat_template([{"role": "system", "content": 'You are a helpful assistant.'},
                                                         {"role": "user", "content": "图片内容是什么\n<image>"}],
                                                        tokenize=False,
                                                        add_generation_prompt=True).replace('<image>', '<|image_pad|>' * self.config.image_pad_num)
            a_text = '图片内容为空' + self.tokenizer.eos_token
            q_input_ids = self.tokenizer(q_text)['input_ids']
            a_input_ids = self.tokenizer(a_text)['input_ids']
            input_ids = q_input_ids + a_input_ids
            labels = [self.tokenizer.pad_token_id] * len(q_input_ids) + a_input_ids
            input_ids = input_ids[:-1]
            labels = labels[1:]

        return {
            'input_ids': input_ids,
            'labels': labels,
            'pixel_values': pixel_values
        }
class MyDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Right-pad every sequence in the batch to the length of the longest one.
        max_len = max(len(feature['input_ids']) for feature in features)
        input_ids = []
        labels = []
        pixel_values = []
        for feature in features:
            input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['input_ids'])))
            labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['labels'])))
            pixel_values.append(feature['pixel_values'])

        return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'labels': torch.tensor(labels, dtype=torch.long),
                'pixel_values': torch.cat(pixel_values, dim=0)}
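
# Example of the padding behaviour (shapes are illustrative assumptions):
#   batch = MyDataCollator(tokenizer)([
#       {'input_ids': [1, 2, 3], 'labels': [1, 2, 3], 'pixel_values': torch.zeros(1, 3, 384, 384)},
#       {'input_ids': [4, 5],    'labels': [4, 5],    'pixel_values': torch.zeros(1, 3, 384, 384)},
#   ])
#   batch['input_ids'].shape    -> (2, 3)             # shorter row right-padded
#   batch['pixel_values'].shape -> (2, 3, 384, 384)   # concatenated along dim 0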

if __name__ == '__main__':
    config = VLMConfig()
    processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224")
    tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')
    AutoConfig.register("vlm_model", VLMConfig)
    AutoModelForCausalLM.register(VLMConfig, VLM)
    model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain')

    # Freeze the projector and vision tower; fine-tune only the LLM.
    for name, param in model.named_parameters():
        if 'linear' in name or 'vision_model' in name:
            param.requires_grad = False
        if 'llm_model' in name:
            param.requires_grad = True
    print(f'Total parameters: {sum(p.numel() for p in model.parameters())}')
    print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    images_path = './sft_images'
    data_path = './dataset/llava_instruct_230k.json'
    output_dir = 'save/sft'
    args = TrainingArguments(
        output_dir=output_dir,
        do_train=True,
        per_device_train_batch_size=2,
        learning_rate=1e-4,
        num_train_epochs=5,
        save_steps=500,
        save_total_limit=2,
        fp16=True,
        gradient_accumulation_steps=8,
        logging_steps=100,
        report_to='tensorboard',
        dataloader_pin_memory=True,
        dataloader_num_workers=1
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=SFTDataset(images_path, data_path, tokenizer, processor, config),
        data_collator=MyDataCollator(tokenizer)
    )

    # resume_from_checkpoint=True requires an existing checkpoint in output_dir;
    # pass False (or drop the argument) on a first run.
    trainer.train(resume_from_checkpoint=True)
    trainer.save_model('save/sft')
    trainer.save_state()