Commit dc59907

new file: train_multimodal_from_scratch/README.md
new file: train_multimodal_from_scratch/gradio_vlm.py
new file: train_multimodal_from_scratch/sft_train.py
new file: train_multimodal_from_scratch/test.ipynb
new file: train_multimodal_from_scratch/test.py
new file: train_multimodal_from_scratch/train.py
new file: train_multimodal_from_scratch/trainer.ipynb

wyf3 committed Nov 30, 2024
1 parent 9c052de commit dc59907
Showing 7 changed files with 1,857 additions and 0 deletions.

train_multimodal_from_scratch/README.md (44 additions)
# Usage

## Download the model and data
### Download qwen2.5-0.5b and siglip
qwen2.5-0.5b:\
https://hf-mirror.com/Qwen/Qwen2.5-0.5B-Instruct\
siglip:\
This project uses the following version of siglip (a small model: quality may be somewhat lower, but training is faster and needs less GPU memory):\
https://hf-mirror.com/google/siglip-base-patch16-384

You can also use a stronger version, but the model is larger. Note that with this version you may need to change the image_pad_num parameter: this model outputs image features of shape (b, 729, dim), which are reshaped to (b, 729/9, dim*9) during image compression:\
https://hf-mirror.com/google/siglip-so400m-patch14-384
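
A minimal sketch of that reshape-based compression (illustrative; the function name is an assumption, and the actual implementation lives in train.py):

```python
import torch

def compress_image_features(features: torch.Tensor, factor: int = 9) -> torch.Tensor:
    """Fold `factor` adjacent patch tokens into the channel dim:
    (b, n, dim) -> (b, n // factor, dim * factor)."""
    b, n, dim = features.shape
    assert n % factor == 0, "token count must be divisible by the fold factor"
    return features.reshape(b, n // factor, dim * factor)

# siglip-so400m-patch14-384 produces 729 patch tokens per image, so they
# fold into 729/9 = 81 tokens -- the value image_pad_num must then match.
feats = torch.randn(2, 729, 1152)  # 1152 is the assumed hidden dim of so400m
print(compress_image_features(feats).shape)  # torch.Size([2, 81, 10368])
```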

### Download the datasets
1. Pretraining data:\
image data:\
https://hf-mirror.com/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K\
Chinese text data:\
https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions
2. SFT data:\
image data:\
https://hf-mirror.com/datasets/jingyaogong/minimind-v_dataset\
Chinese text data:\
https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions
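
The SFT script in this commit reads the data from local paths: by default sft_train.py expects the conversation file at ./dataset/llava_instruct_230k.json and the images (named COCO_train2014_*) under ./sft_images. Adjust images_path and data_path there to wherever you extract the downloads.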

## Start training
### Run directly
Pretraining:\
python train.py\
SFT:\
python sft_train.py
### torchrun
Pretraining:\
torchrun --nproc_per_node=2 train.py\
SFT:\
torchrun --nproc_per_node=2 sft_train.py
### deepspeed
Pretraining:\
deepspeed --include 'localhost:0,1' train.py\
SFT:\
deepspeed --include 'localhost:0,1' sft_train.py

## Test
python test.py
train_multimodal_from_scratch/gradio_vlm.py (77 additions)
import gradio as gr
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig
from train import VLMConfig, VLM

device = "cuda:1"
processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')

# Register the custom config/model so AutoModelForCausalLM can load the checkpoints.
AutoConfig.register("vlm_model", VLMConfig)
AutoModelForCausalLM.register(VLMConfig, VLM)

pretrain_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain')
pretrain_model.to(device)

sft_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft')
sft_model.to(device)

pretrain_model.eval()
sft_model.eval()
@torch.no_grad()
def generate(mode, image_input, text_input, max_new_tokens=100, temperature=0.0, top_k=None):
    # Build the chat prompt and expand the <image> placeholder into 49
    # <|image_pad|> tokens (this must match the model's image_pad_num).
    q_text = tokenizer.apply_chat_template(
        [{"role": "system", "content": 'You are a helpful assistant.'},
         {"role": "user", "content": f'{text_input}\n<image>'}],
        tokenize=False,
        add_generation_prompt=True).replace('<image>', '<|image_pad|>'*49)
    input_ids = tokenizer(q_text, return_tensors='pt')['input_ids']
    input_ids = input_ids.to(device)
    pixel_values = processor(text=None, images=image_input).pixel_values
    pixel_values = pixel_values.to(device)
    eos = tokenizer.eos_token_id
    s = input_ids.shape[1]
    # Pick the model once, outside the decoding loop.
    model = pretrain_model if mode == 'pretrain' else sft_model
    while input_ids.shape[1] < s + max_new_tokens - 1:
        inference_res = model(input_ids, None, pixel_values)
        logits = inference_res.logits[:, -1, :]  # next-token logits only

        # Repetition penalty over tokens already in the context
        # (the divisor is 1.0 here, so the penalty is effectively disabled).
        for token in set(input_ids.tolist()[0]):
            logits[:, token] /= 1.0

        if temperature == 0.0:
            # Greedy decoding.
            _, idx_next = torch.topk(logits, k=1, dim=-1)
        else:
            # Temperature sampling with optional top-k filtering.
            logits = logits / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        if idx_next == eos:
            break
        input_ids = torch.cat((input_ids, idx_next), dim=1)
    return tokenizer.decode(input_ids[:, s:][0])
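
# Example of calling generate() directly, outside the Gradio UI
# (illustrative; assumes a local image file named test.jpg exists):
#   from PIL import Image
#   img = Image.open('test.jpg').convert('RGB')
#   print(generate('sft', img, 'Describe this image'))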

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: image upload.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image")
        # Right column: model choice, prompt, and output.
        with gr.Column(scale=1):
            mode = gr.Radio(["pretrain", "sft"], label="Model")
            text_input = gr.Textbox(label="Input text")
            text_output = gr.Textbox(label="Output text")
            generate_button = gr.Button("Generate")
    generate_button.click(generate, inputs=[mode, image_input, text_input], outputs=text_output)


if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=7891)

train_multimodal_from_scratch/sft_train.py (161 additions)
import json
import os
from typing import List, Dict, Any

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
                          AutoTokenizer, Trainer, TrainingArguments)

from train import VLMConfig, VLM


def find_assistant_tokens(tokenizer, target):
    """Locate the (start, end) index pairs of the assistant answer spans
    in a tokenized chat-templated sequence."""
    result = []
    start_index = 0
    end_index = 0
    while start_index <= len(target) - 1:
        if target[start_index] != tokenizer('assistant')['input_ids'][0]:
            start_index += 1
            end_index += 1
        else:
            end_index += 1
            if target[end_index] == tokenizer('<|im_end|>')['input_ids'][0]:
                result.append((start_index + 1, end_index + 1))
                start_index = end_index + 1
    return result
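
# Illustrative example of what find_assistant_tokens returns: for input_ids
# covering "...<|im_start|>assistant\nANSWER<|im_end|>...", it yields one
# (start, end) pair per assistant turn, spanning the answer tokens through
# <|im_end|>. SFTDataset below uses these spans to unmask only the assistant
# answers in `labels`, so the loss is computed on responses, not prompts.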

class SFTDataset(Dataset):
    def __init__(self, images_path, data_path, tokenizer, processor, config):
        super().__init__()
        self.data_path = data_path
        self.images_path = images_path
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.datas = json.load(f)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index):
        sample = self.datas[index]
        try:
            image_name = 'COCO_train2014_' + str(sample['image'])
            conversations = sample['conversations']
            messages = [{"role": "system", "content": 'You are a helpful assistant.'}]
            for conversation in conversations:
                if conversation['from'] == 'human':
                    messages.append({"role": "user", "content": conversation['value']})
                else:
                    messages.append({"role": "assistant", "content": conversation['value']})
            text = self.tokenizer.apply_chat_template(messages, tokenize=False) \
                .replace('<image>', '<|image_pad|>' * self.config.image_pad_num)
            input_ids = self.tokenizer(text)['input_ids']
            # Mask everything except the assistant answer spans.
            spans = find_assistant_tokens(self.tokenizer, input_ids)
            labels = len(input_ids) * [self.tokenizer.pad_token_id]
            for start, end in spans:
                labels[start:end] = input_ids[start:end]
            # Shift by one position for next-token prediction.
            input_ids = input_ids[:-1]
            labels = labels[1:]

            image = Image.open(os.path.join(self.images_path, image_name)).convert('RGB')
            pixel_values = self.processor(text=None, images=image)['pixel_values']
        except Exception:
            # Fall back to a blank image and a minimal Q/A pair if the sample
            # or its image cannot be loaded.
            default_image = Image.new('RGB', (224, 224), color='white')
            pixel_values = self.processor(text=None, images=default_image)['pixel_values']
            q_text = self.tokenizer.apply_chat_template(
                [{"role": "system", "content": 'You are a helpful assistant.'},
                 {"role": "user", "content": "图片内容是什么\n<image>"}],  # "What is in the image?"
                tokenize=False,
                add_generation_prompt=True).replace('<image>', '<|image_pad|>' * self.config.image_pad_num)
            a_text = '图片内容为空' + self.tokenizer.eos_token  # "The image is empty"
            q_input_ids = self.tokenizer(q_text)['input_ids']
            a_input_ids = self.tokenizer(a_text)['input_ids']
            input_ids = q_input_ids + a_input_ids
            # Mask the question; train only on the answer.
            labels = [self.tokenizer.pad_token_id] * len(q_input_ids) + a_input_ids
            input_ids = input_ids[:-1]
            labels = labels[1:]

        return {
            'input_ids': input_ids,
            'labels': labels,
            'pixel_values': pixel_values
        }
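
# Note: SFTDataset already applies the one-position shift (input_ids[:-1]
# aligned against labels[1:]), and masked positions use pad_token_id. This
# assumes the VLM's loss in train.py ignores pad_token_id targets; if it
# follows the common -100 ignore_index convention instead, use -100 for the
# masked positions above.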

class MyDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Dynamically pad every sequence to the longest one in the batch.
        max_len = max(len(feature['input_ids']) for feature in features)
        input_ids = []
        labels = []
        pixel_values = []
        for feature in features:
            input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['input_ids'])))
            labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['labels'])))
            pixel_values.append(feature['pixel_values'])

        return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'labels': torch.tensor(labels, dtype=torch.long),
                'pixel_values': torch.cat(pixel_values, dim=0)}


if __name__ == '__main__':
    config = VLMConfig()
    processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224")
    tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')
    AutoConfig.register("vlm_model", VLMConfig)
    AutoModelForCausalLM.register(VLMConfig, VLM)
    model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain')

    # Freeze the vision tower and the projection layer; fine-tune only the LLM.
    for name, param in model.named_parameters():
        if 'linear' in name or 'vision_model' in name:
            param.requires_grad = False
        if 'llm_model' in name:
            param.requires_grad = True
    print(f'Total parameters: {sum(p.numel() for p in model.parameters())}')
    print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    images_path = './sft_images'
    data_path = './dataset/llava_instruct_230k.json'
    output_dir = 'save/sft'
    args = TrainingArguments(
        output_dir=output_dir,
        do_train=True,
        per_device_train_batch_size=2,
        learning_rate=1e-4,
        num_train_epochs=5,
        save_steps=500,
        save_total_limit=2,
        fp16=True,
        gradient_accumulation_steps=8,
        logging_steps=100,
        report_to='tensorboard',
        dataloader_pin_memory=True,
        dataloader_num_workers=1
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=SFTDataset(images_path, data_path, tokenizer, processor, config),
        data_collator=MyDataCollator(tokenizer)
    )

    # resume_from_checkpoint=True requires an existing checkpoint in
    # output_dir; pass False (or drop the argument) for a fresh run.
    trainer.train(resume_from_checkpoint=True)
    trainer.save_model('save/sft')
    trainer.save_state()
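
# Launch examples (from the README):
#   python sft_train.py
#   torchrun --nproc_per_node=2 sft_train.py
#   deepspeed --include 'localhost:0,1' sft_train.py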