-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new file: train_siglip_from_scratch/clip.png
new file: train_siglip_from_scratch/data_process.ipynb
new file: train_siglip_from_scratch/dataset.py
new file: train_siglip_from_scratch/model.py
new file: train_siglip_from_scratch/siglip.png
new file: train_siglip_from_scratch/test.ipynb
new file: train_siglip_from_scratch/train.py
new file: train_siglip_from_scratch/伪代码.png
- Loading branch information
Showing
8 changed files
with
1,788 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from torch.utils.data import Dataset | ||
import json | ||
from PIL import Image | ||
import os | ||
import torch | ||
import pandas as pd | ||
from io import BytesIO | ||
import base64 | ||
from transformers import AutoTokenizer, AutoProcessor | ||
import random | ||
|
||
class SiglipDataset(Dataset):
    """Image/text pair dataset backed by a JSONL caption file and a TSV image file.

    Every non-blank line of the JSONL file is parsed as a JSON object with a
    ``text`` entry and an ``image_ids`` list; one sample is produced for each
    ``(image_id, text)`` combination.  The TSV file is read without a header
    row: column 0 holds the image id and column 1 the base64-encoded image.

    Args:
        text_data_path: Path to the JSONL caption file (read as UTF-8).
        image_data_path: Path to the tab-separated image file.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True)``; its result is indexed
            with ``'input_ids'`` and ``'attention_mask'``.
        processor: Callable invoked as ``processor(images=image,
            return_tensors='pt')``; its result is indexed with
            ``'pixel_values'``.
        max_seq_length: Length passed to the tokenizer as ``max_length``.
    """

    def __init__(self, text_data_path,
                 image_data_path,
                 tokenizer,
                 processor,
                 max_seq_length=64,
                 ):
        super().__init__()
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_seq_length = max_seq_length

        self.datas = []
        with open(self.text_data_path, 'r', encoding='utf-8') as f:
            # Stream the file instead of materialising every line up front.
            for raw_line in f:
                if not raw_line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                record = json.loads(raw_line)
                for image_id in record['image_ids']:
                    self.datas.append({'image_id': image_id, 'text': record['text']})

        random.shuffle(self.datas)

        self.images = pd.read_csv(self.image_data_path, sep='\t', header=None)

        # Build the id -> base64 lookup once so that fetching a sample does not
        # scan the whole table.  ``setdefault`` keeps the first row when an id
        # is duplicated, matching a "first match" lookup on the table.
        self._image_by_id = {}
        for image_id, image_base64 in zip(self.images[0].tolist(), self.images[1].tolist()):
            self._image_by_id.setdefault(image_id, image_base64)

    def _lookup_image_bytes(self, image_id):
        """Return the decoded image bytes stored for ``image_id``.

        Raises:
            KeyError: If the image file has no row for ``image_id``.
        """
        try:
            image_base64 = self._image_by_id[image_id]
        except KeyError:
            raise KeyError(
                f"image_id {image_id!r} not found in {self.image_data_path}"
            ) from None
        return base64.b64decode(image_base64)

    def __getitem__(self, index):
        """Return the tokenized caption and processed image for sample ``index``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (as produced by the
            tokenizer) and ``pixel_values`` (as produced by the processor).

        Raises:
            IndexError: If ``index`` is outside the sample list.
            KeyError: If the sample's image id is missing from the image file.
        """
        sample = self.datas[index]

        tok = self.tokenizer(sample['text'], max_length=self.max_seq_length,
                             padding='max_length', truncation=True)
        input_ids = tok['input_ids']
        attention_mask = tok['attention_mask']

        image_bytes = self._lookup_image_bytes(sample['image_id'])
        # ``convert`` returns a new image, so the source handle can be closed
        # as soon as the RGB copy exists.
        with Image.open(BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(images=image, return_tensors='pt')['pixel_values']

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values
        }

    def __len__(self):
        """Return the number of (image, text) samples."""
        return len(self.datas)
|
||
class MyDataCollator:
    """Turn a list of per-sample dicts into one batch of tensors."""

    def __init__(self, tokenizer):
        # Kept on the instance; the collation below does not use it.
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Collate ``features`` into a batch.

        Args:
            features: Per-sample dicts, each with ``input_ids``,
                ``attention_mask`` and ``pixel_values`` entries.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` built via
            ``torch.tensor`` and ``pixel_values`` concatenated along dim 0.
        """
        field_names = ('input_ids', 'attention_mask', 'pixel_values')
        # Gather each field into its own list, one pass per field.
        columns = {
            name: [sample[name] for sample in features]
            for name in field_names
        }

        batch = {}
        batch['input_ids'] = torch.tensor(columns['input_ids'])
        batch['attention_mask'] = torch.tensor(columns['attention_mask'])
        # Each sample's pixel tensor is joined along the first dimension.
        batch['pixel_values'] = torch.cat(columns['pixel_values'], dim=0)
        return batch
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: build the dataset from the local MUGE files, then
    # print its size and one sample.
    text_model_dir = '/home/user/wyf/chinese-roberta-wwm-ext'
    vision_model_dir = '/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224'
    captions_file = '/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl'
    images_file = '/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv'

    tokenizer = AutoTokenizer.from_pretrained(text_model_dir)
    processor = AutoProcessor.from_pretrained(vision_model_dir)

    dataset = SiglipDataset(
        text_data_path=captions_file,
        image_data_path=images_file,
        tokenizer=tokenizer,
        processor=processor,
        max_seq_length=64,
    )

    print(len(dataset))
    print(dataset[2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoTokenizer, AutoProcessor | ||
from transformers import ViTImageProcessor, ViTForImageClassification | ||
|
||
import torch.nn as nn | ||
from transformers.utils import ModelOutput | ||
import torch | ||
import torch.nn.functional as F | ||
from dataclasses import dataclass | ||
|
||
@dataclass
class SiglipOutput(ModelOutput):
    """Result container for the model's forward pass.

    Every field defaults to ``None``.  The field order is part of the
    contract because the container can also be indexed positionally.
    """

    # Training loss.
    loss: torch.FloatTensor = None
    # Text-vs-image similarity logits.
    logits_per_text: torch.FloatTensor = None
    # Image-vs-text similarity logits.
    logits_per_image: torch.FloatTensor = None
    # Text embeddings.
    text_embeds: torch.FloatTensor = None
    # Image embeddings.
    image_embeds: torch.FloatTensor = None
|
||
|
||
|
||
|
||
class SiglipConfig(PretrainedConfig):
    """Configuration naming the two backbone checkpoints of the model.

    Args:
        vision_model_name_or_path: Name or path of the vision checkpoint.
        text_model_name_or_path: Name or path of the text checkpoint.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "siglip"

    def __init__(
        self,
        vision_model_name_or_path: str = "vit-base-patch16-224",
        text_model_name_or_path: str = "bert-base-chinese",
        **kwargs,
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)
        self.text_model_name_or_path = text_model_name_or_path
        self.vision_model_name_or_path = vision_model_name_or_path
|
||
|
||
|
||
class SiglipModel(PreTrainedModel):
    """Dual-encoder image/text model trained with a pairwise sigmoid loss.

    The vision and text backbones, the image processor and the tokenizer are
    all loaded with ``from_pretrained`` from the paths stored in the config.
    ``t`` is a learnable log-scale applied to the similarities (it is
    exponentiated before use) and ``b`` is a learnable additive bias.
    """

    config_class = SiglipConfig

    # Lower bound on the L2 norm so an all-zero embedding cannot produce NaNs.
    _NORM_EPS = 1e-12

    def __init__(self, config: SiglipConfig):
        super().__init__(config)
        self.vision_model = AutoModel.from_pretrained(config.vision_model_name_or_path)
        self.process = AutoProcessor.from_pretrained(config.vision_model_name_or_path)
        self.text_model = AutoModel.from_pretrained(config.text_model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
        # NOTE(review): both parameters start from a standard-normal draw. The
        # SigLIP paper reportedly initialises them to log(10) and -10 instead
        # -- confirm before changing, since it alters training dynamics.
        self.t = nn.Parameter(torch.randn(1))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, input_ids, attention_mask, pixel_values):
        """Encode a batch of paired texts and images and compute the loss.

        Row ``i`` of the text inputs is treated as the match of row ``i`` of
        the image inputs; every other combination is treated as a negative.

        Args:
            input_ids: Token ids passed to the text backbone.
            attention_mask: Attention mask passed to the text backbone.
            pixel_values: Image tensor passed to the vision backbone.

        Returns:
            SiglipOutput holding the loss, both logit matrices and the
            L2-normalised text and image embeddings.
        """
        # Keyword arguments avoid relying on the positional parameter order of
        # whichever backbone ``AutoModel`` resolved to.
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        vision_outputs = self.vision_model(pixel_values=pixel_values)

        vision_features = vision_outputs[1]  # pooler_output
        text_features = text_outputs[1]  # pooler_output

        # L2 normalisation; the norm is clamped to avoid division by zero.
        vision_features = vision_features / vision_features.norm(
            p=2, dim=-1, keepdim=True).clamp_min(self._NORM_EPS)
        text_features = text_features / text_features.norm(
            p=2, dim=-1, keepdim=True).clamp_min(self._NORM_EPS)

        logits_per_text = torch.matmul(text_features, vision_features.t()) * self.t.exp() + self.b
        logits_per_image = logits_per_text.t()

        batch_size = logits_per_text.shape[0]
        eye = torch.eye(batch_size, device=logits_per_text.device)  # identity matrix
        # +1 on the diagonal (matching image/text pairs), -1 everywhere else
        # (non-matching pairs).
        labels = 2 * eye - torch.ones_like(logits_per_text)
        loglik = F.logsigmoid(labels * logits_per_text)
        # Sum over the images for each text, then average over the texts.
        nll = -torch.sum(loglik, dim=-1)
        loss = nll.mean()

        return SiglipOutput(
            loss=loss,
            logits_per_text=logits_per_text,
            logits_per_image=logits_per_image,
            text_embeds=text_features,
            image_embeds=vision_features,
        )
|
||
|
||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from transformers import TrainingArguments, Trainer, default_data_collator | ||
from model import SiglipModel, SiglipConfig | ||
from dataset import SiglipDataset, MyDataCollator | ||
from transformers import AutoTokenizer, AutoProcessor | ||
from transformers import ViTImageProcessor, ViTForImageClassification | ||
|
||
def train():
    """Train the model on the MUGE image/text data with the HF ``Trainer``.

    Builds the config, model, tokenizer, processor and dataset from the
    hard-coded local paths, trains with the arguments below, then saves the
    final model and trainer state into the output directory.

    Training resumes from the most recent checkpoint in the output directory
    when one exists and starts fresh otherwise.
    """
    # Local imports: only needed for the checkpoint discovery below.
    import os
    from transformers.trainer_utils import get_last_checkpoint

    config = SiglipConfig(vision_model_name_or_path='/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224',
                          text_model_name_or_path='/home/user/wyf/chinese-roberta-wwm-ext')

    model = SiglipModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
    processor = AutoProcessor.from_pretrained(config.vision_model_name_or_path)

    args = TrainingArguments(
        output_dir='./outputs',
        do_train=True,
        per_device_train_batch_size=32,
        learning_rate=1e-4,
        num_train_epochs=40,
        save_steps=2000,
        save_total_limit=5,
        fp16=True,
        gradient_accumulation_steps=8,
        logging_steps=100,
        report_to='none',
        dataloader_pin_memory=True,
        dataloader_num_workers=1,
    )
    dataset = SiglipDataset(text_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl',
                            image_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv',
                            tokenizer=tokenizer,
                            processor=processor,
                            max_seq_length=64)

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=MyDataCollator(tokenizer)
    )

    # ``resume_from_checkpoint=True`` raises when the output directory holds
    # no checkpoint, which breaks the very first run.  Look the checkpoint up
    # explicitly and pass ``None`` (fresh start) when there is none.
    last_checkpoint = None
    if os.path.isdir(args.output_dir):
        last_checkpoint = get_last_checkpoint(args.output_dir)

    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model()
    trainer.save_state()


if __name__ == '__main__':
    train()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.