Skip to content

Commit

Permalink
new file: train_siglip_from_scratch/clip.png
Browse files Browse the repository at this point in the history
	new file:   train_siglip_from_scratch/data_process.ipynb
	new file:   train_siglip_from_scratch/dataset.py
	new file:   train_siglip_from_scratch/model.py
	new file:   train_siglip_from_scratch/siglip.png
	new file:   train_siglip_from_scratch/test.ipynb
	new file:   train_siglip_from_scratch/train.py
	new file:   train_siglip_from_scratch/伪代码.png
  • Loading branch information
wyf3 committed Dec 20, 2024
1 parent 06c1d79 commit 8731f00
Show file tree
Hide file tree
Showing 8 changed files with 1,788 additions and 0 deletions.
Binary file added train_siglip_from_scratch/clip.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,488 changes: 1,488 additions & 0 deletions train_siglip_from_scratch/data_process.ipynb

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions train_siglip_from_scratch/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from torch.utils.data import Dataset
import json
from PIL import Image
import os
import torch
import pandas as pd
from io import BytesIO
import base64
from transformers import AutoTokenizer, AutoProcessor
import random

class SiglipDataset(Dataset):
    """Map-style dataset yielding tokenized text and preprocessed image pairs.

    Every non-blank line of the JSONL text file is parsed as an object with a
    ``text`` entry and an ``image_ids`` list; one sample is created for each
    (text, image_id) combination. Images come from a header-less,
    tab-separated file whose first column is the image id and whose second
    column is the base64-encoded image.

    Args:
        text_data_path: Path to the JSONL file with ``text`` / ``image_ids``.
        image_data_path: Path to the ``id<TAB>base64`` image file.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True)``; its result must provide
            ``input_ids`` and ``attention_mask``.
        processor: Callable invoked as ``processor(images=image,
            return_tensors='pt')``; its result must provide ``pixel_values``.
        max_seq_length: Length every text is padded / truncated to.
    """

    def __init__(self, text_data_path,
                 image_data_path,
                 tokenizer,
                 processor,
                 max_seq_length=64,
                 ):
        super().__init__()
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_seq_length = max_seq_length

        self.datas = []
        with open(self.text_data_path, 'r', encoding='utf-8') as f:
            # Stream the file instead of materialising every line first.
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                record = json.loads(line)
                text = record['text']
                self.datas.extend(
                    {'image_id': image_id, 'text': text}
                    for image_id in record['image_ids']
                )

        random.shuffle(self.datas)

        self.images = pd.read_csv(self.image_data_path, sep='\t', header=None)

        # Build an id -> base64 lookup once so that __getitem__ does not have
        # to scan the whole DataFrame for every sample. ``setdefault`` keeps
        # the first row when an id appears more than once, matching the
        # previous "first match" behaviour.
        self._image_index = {}
        for image_id, image_base64 in zip(self.images[0], self.images[1]):
            self._image_index.setdefault(image_id, image_base64)

    def _image_base64(self, image_id):
        """Return the base64 payload stored for *image_id*.

        Raises:
            KeyError: If the id is absent from the image file. A ``KeyError``
                is used rather than ``IndexError`` because the latter is
                interpreted as "end of sequence" when an object is iterated
                through ``__getitem__``, which would silently truncate the
                data.
        """
        try:
            return self._image_index[image_id]
        except KeyError:
            raise KeyError(
                f"image id {image_id!r} not found in {self.image_data_path}"
            ) from None

    def __getitem__(self, index):
        """Return ``input_ids``, ``attention_mask`` and ``pixel_values`` for one pair."""
        sample = self.datas[index]

        image_id = sample['image_id']
        text = sample['text']
        tok = self.tokenizer(text, max_length=self.max_seq_length, padding='max_length', truncation=True)
        input_ids = tok['input_ids']
        attention_mask = tok['attention_mask']

        image_bytes = base64.b64decode(self._image_base64(image_id))
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors='pt')['pixel_values']

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values
        }

    def __len__(self):
        """Return the number of (text, image_id) pairs."""
        return len(self.datas)

class MyDataCollator:
    """Batch per-sample feature dicts into a single dict of tensors.

    The tokenizer is only stored on the instance; collation itself does not
    use it.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Collate *features* into batched tensors.

        Args:
            features: Sequence of dicts, each holding ``input_ids`` and
                ``attention_mask`` (token id / mask sequences) and
                ``pixel_values`` (a tensor to be concatenated along dim 0).

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` built via
            ``torch.tensor`` and ``pixel_values`` concatenated along dim 0.
        """

        def column(key):
            # Gather one field from every sample, preserving sample order.
            return [sample[key] for sample in features]

        token_rows = column('input_ids')
        mask_rows = column('attention_mask')
        image_tensors = column('pixel_values')

        batch = {}
        batch['input_ids'] = torch.tensor(token_rows)
        batch['attention_mask'] = torch.tensor(mask_rows)
        batch['pixel_values'] = torch.cat(image_tensors, dim=0)
        return batch


if __name__ == '__main__':
    # Manual smoke test: build the dataset from the local MUGE files and
    # print its size together with one decoded sample.
    text_encoder_dir = '/home/user/wyf/chinese-roberta-wwm-ext'
    image_processor_dir = '/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224'
    texts_file = '/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl'
    images_file = '/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv'

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
    processor = AutoProcessor.from_pretrained(image_processor_dir)

    dataset = SiglipDataset(
        text_data_path=texts_file,
        image_data_path=images_file,
        tokenizer=tokenizer,
        processor=processor,
        max_seq_length=64,
    )

    sample_count = len(dataset)
    print(sample_count)
    third_sample = dataset[2]
    print(third_sample)
74 changes: 74 additions & 0 deletions train_siglip_from_scratch/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoTokenizer, AutoProcessor
from transformers import ViTImageProcessor, ViTForImageClassification

import torch.nn as nn
from transformers.utils import ModelOutput
import torch
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class SiglipOutput(ModelOutput):
    """Values returned by ``SiglipModel.forward``.

    All fields default to ``None``; ``SiglipModel.forward`` fills every one
    of them.
    """
    # Scalar training loss: negative log-sigmoid summed over each row of the
    # text-to-image logits, then averaged over rows.
    loss: torch.FloatTensor = None
    # Text-to-image similarity logits (text features times transposed image
    # features, scaled by exp(t) and shifted by b).
    logits_per_text: torch.FloatTensor = None
    # Transpose of ``logits_per_text``.
    logits_per_image: torch.FloatTensor = None
    # L2-normalized text features taken from the text encoder output.
    text_embeds: torch.FloatTensor = None
    # L2-normalized image features taken from the vision encoder output.
    image_embeds: torch.FloatTensor = None




class SiglipConfig(PretrainedConfig):
    """Configuration naming the two backbones used by ``SiglipModel``.

    Args:
        vision_model_name_or_path: Name or path passed to
            ``AutoModel.from_pretrained`` / ``AutoProcessor.from_pretrained``
            for the vision side.
        text_model_name_or_path: Name or path passed to
            ``AutoModel.from_pretrained`` / ``AutoTokenizer.from_pretrained``
            for the text side.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """
    model_type = "siglip"

    def __init__(
            self,
            vision_model_name_or_path: str = "vit-base-patch16-224",
            text_model_name_or_path: str = "bert-base-chinese",
            **kwargs):
        super().__init__(**kwargs)
        # Stored as plain attributes so SiglipModel can read them back.
        self.vision_model_name_or_path = vision_model_name_or_path
        self.text_model_name_or_path = text_model_name_or_path



class SiglipModel(PreTrainedModel):
    """Dual-encoder image/text model trained with a pairwise sigmoid loss.

    The vision and text backbones are loaded with ``AutoModel.from_pretrained``
    from the paths stored in the config. Two learnable scalars are added on
    top: ``t`` (log-temperature; logits are multiplied by ``exp(t)``) and
    ``b`` (additive logit bias).
    """
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)
        self.vision_model = AutoModel.from_pretrained(config.vision_model_name_or_path)
        self.process = AutoProcessor.from_pretrained(config.vision_model_name_or_path)
        self.text_model = AutoModel.from_pretrained(config.text_model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
        # Initialise as in the SigLIP paper (t' = log 10, b = -10) instead of
        # random values: at the start of training almost every pair in the
        # batch is a negative, and the strongly negative bias keeps the
        # initial loss and gradients from being dominated by that imbalance.
        # Shapes stay (1,) so existing checkpoints still load.
        self.t = nn.Parameter(torch.log(torch.tensor([10.0])))
        self.b = nn.Parameter(torch.tensor([-10.0]))

    def forward(self, input_ids, attention_mask, pixel_values):
        """Encode a batch of paired texts and images and compute the sigmoid loss.

        Args:
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.
            pixel_values: Image tensor passed to the vision encoder.

        Returns:
            SiglipOutput with the loss, both logit matrices and the
            L2-normalized text / image embeddings. The i-th text and the
            i-th image of the batch are treated as the matching pair.
        """
        # Keyword arguments avoid depending on each backbone's positional
        # parameter order.
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        vision_outputs = self.vision_model(pixel_values=pixel_values)

        vision_features = vision_outputs[1]  # pooler_output
        text_features = text_outputs[1]  # pooler_output

        # L2 normalization; F.normalize clamps the norm with a small epsilon
        # so an all-zero feature vector does not produce NaNs.
        vision_features = F.normalize(vision_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        logits_per_text = torch.matmul(text_features, vision_features.t()) * self.t.exp() + self.b
        logits_per_image = logits_per_text.t()

        batch_size = logits_per_text.shape[0]
        eye = torch.eye(batch_size, device=logits_per_text.device)  # identity matrix
        # +1 on the diagonal (matching text/image pairs), -1 everywhere else
        # (non-matching pairs).
        labels = 2 * eye - torch.ones_like(logits_per_text, device=logits_per_text.device)
        loglik = F.logsigmoid(labels * logits_per_text)
        # Sum over all images for each text, then average over texts.
        nll = -torch.sum(loglik, dim=-1)
        loss = nll.mean()

        return SiglipOutput(loss=loss, logits_per_text=logits_per_text, logits_per_image=logits_per_image, text_embeds=text_features, image_embeds=vision_features)




Binary file added train_siglip_from_scratch/siglip.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
90 changes: 90 additions & 0 deletions train_siglip_from_scratch/test.ipynb

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions train_siglip_from_scratch/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from transformers import TrainingArguments, Trainer, default_data_collator
from model import SiglipModel, SiglipConfig
from dataset import SiglipDataset, MyDataCollator
from transformers import AutoTokenizer, AutoProcessor
from transformers import ViTImageProcessor, ViTForImageClassification

def train():
    """Train ``SiglipModel`` on the MUGE text/image data with the HF ``Trainer``.

    Builds the config, model, tokenizer, image processor and dataset from
    hard-coded local paths, trains, then saves the final model and trainer
    state to ``./outputs``. Training resumes from the newest checkpoint in
    the output directory when one exists and starts from scratch otherwise.
    """
    # Local imports keep this function self-contained; both modules are
    # already dependencies of this script (stdlib / transformers).
    import os
    from transformers.trainer_utils import get_last_checkpoint

    config = SiglipConfig(vision_model_name_or_path='/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224',
                          text_model_name_or_path='/home/user/wyf/chinese-roberta-wwm-ext')

    model = SiglipModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
    processor = AutoProcessor.from_pretrained(config.vision_model_name_or_path)

    args = TrainingArguments(
        output_dir='./outputs',
        do_train=True,
        per_device_train_batch_size=32,
        learning_rate=1e-4,
        num_train_epochs=40,
        save_steps=2000,
        save_total_limit=5,
        fp16=True,
        gradient_accumulation_steps=8,
        logging_steps=100,
        report_to='none',
        dataloader_pin_memory=True,
        dataloader_num_workers=1,
    )
    dataset = SiglipDataset(text_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl',
                            image_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv',
                            tokenizer=tokenizer,
                            processor=processor,
                            max_seq_length=64)

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=MyDataCollator(tokenizer)
    )

    # ``resume_from_checkpoint=True`` raises ValueError when the output
    # directory holds no checkpoint, which made the very first run crash.
    # Look the checkpoint up explicitly and fall back to a fresh start.
    last_checkpoint = None
    if os.path.isdir(args.output_dir):
        last_checkpoint = get_last_checkpoint(args.output_dir)

    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model()
    trainer.save_state()

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()
Binary file added train_siglip_from_scratch/伪代码.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 8731f00

Please sign in to comment.