Commit
ImageReward
xujz18 committed Apr 13, 2023
1 parent d39aea5 commit 6d863e1
Showing 25 changed files with 11,532 additions and 1 deletion.
138 changes: 138 additions & 0 deletions ImageReward/ImageReward.py
@@ -0,0 +1,138 @@
'''
@File : ImageReward.py
@Time : 2023/01/28 19:53:00
@Author  : Jiazheng Xu
@Contact : xjz22@mails.tsinghua.edu.cn
@Description: ImageReward Reward model.
* Based on CLIP code base and improved-aesthetic-predictor code base
* https://github.com/openai/CLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import torch
import torch.nn as nn
from PIL import Image
from .models.BLIP.blip_pretrain import BLIP_Pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC


def _convert_image_to_rgb(image):
return image.convert("RGB")


def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])


class MLP(nn.Module):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size

self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
#nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 128),
#nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 64),
#nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(64, 16),
#nn.ReLU(),
nn.Linear(16, 1)
)

        # initialize MLP parameters
for name, param in self.layers.named_parameters():
if 'weight' in name:
nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
if 'bias' in name:
nn.init.constant_(param, val=0)

def forward(self, input):
return self.layers(input)


class ImageReward(nn.Module):
def __init__(self, med_config, device='cpu'):
super().__init__()
self.device = device

self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
self.preprocess = _transform(224)
self.mlp = MLP(768)

self.mean = 0.16717362830052426
self.std = 1.0333394966054072


def score(self, prompt, image_path):

# text encode
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)

# image encode
pil_image = Image.open(image_path)
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
image_embeds = self.blip.visual_encoder(image)

# text encode cross attention with image
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
text_output = self.blip.text_encoder(text_input.input_ids,
attention_mask = text_input.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)

        txt_features = text_output.last_hidden_state[:,0,:].float() # [1, feature_dim]
rewards = self.mlp(txt_features)
rewards = (rewards - self.mean) / self.std

        return rewards.detach().cpu().numpy().item()


def inference_rank(self, prompt, generations_list):

text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)

txt_set = []
for generations in generations_list:
# image encode
img_path = generations
pil_image = Image.open(img_path)
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
image_embeds = self.blip.visual_encoder(image)

# text encode cross attention with image
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
text_output = self.blip.text_encoder(text_input.input_ids,
attention_mask = text_input.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
txt_set.append(text_output.last_hidden_state[:,0,:])

txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
rewards = self.mlp(txt_features) # [image_num, 1]
rewards = (rewards - self.mean) / self.std
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1

        return indices.cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
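
A minimal usage sketch for the class above (illustrative, not part of this commit; the med_config path, the checkpoint file name ImageReward.pt, and the image paths are placeholders, and the package's own loading helpers in utils.py are not used here):

import torch
from ImageReward.ImageReward import ImageReward

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder paths: med_config.json is the BLIP text-encoder config,
# ImageReward.pt a locally downloaded reward-model checkpoint.
model = ImageReward(med_config="med_config.json", device=device).to(device)
model.load_state_dict(torch.load("ImageReward.pt", map_location="cpu"), strict=False)
model.eval()

prompt = "a painting of a red fox in a snowy forest"
with torch.no_grad():
    single_score = model.score(prompt, "fox.png")                      # scalar reward for one image
    ranks, rewards = model.inference_rank(prompt, ["a.png", "b.png"])  # 1-based ranks plus rewards
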
1 change: 1 addition & 0 deletions ImageReward/__init__.py
@@ -0,0 +1 @@
from .utils import *
79 changes: 79 additions & 0 deletions ImageReward/models/AestheticScore.py
@@ -0,0 +1,79 @@
'''
@File : AestheticScore.py
@Time : 2023/02/12 14:54:00
@Author  : Jiazheng Xu
@Contact : xjz22@mails.tsinghua.edu.cn
@Description: AestheticScore.
* Based on improved-aesthetic-predictor code base
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import clip


# if you changed the MLP architecture during training, change it here as well:
class MLP(nn.Module):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
#nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 128),
#nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 64),
#nn.ReLU(),
nn.Dropout(0.1),

nn.Linear(64, 16),
#nn.ReLU(),

nn.Linear(16, 1)
)

def forward(self, x):
return self.layers(x)


class AestheticScore(nn.Module):
def __init__(self, download_root, device='cpu'):
super().__init__()
self.device = device
self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
download_root=download_root)
self.mlp = MLP(768)

if device == "cpu":
self.clip_model.float()
else:
            clip.model.convert_weights(self.clip_model)  # actually unnecessary: CLIP weights are already in float16 by default

# have clip.logit_scale require no grad.
self.clip_model.logit_scale.requires_grad_(False)


def inference_rank(self, prompt, generations_list):

img_set = []
for generations in generations_list:
# image encode
img_path = generations
pil_image = Image.open(img_path)
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
image_features = F.normalize(self.clip_model.encode_image(image))
img_set.append(image_features)

img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
rewards = self.mlp(img_features)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1

        return indices.cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
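
A usage sketch for the aesthetic scorer (illustrative, not part of this commit; the MLP weights path is a placeholder for a checkpoint trained as in improved-aesthetic-predictor). Note that the prompt argument is accepted but unused: the score depends only on CLIP image features.

import os
import torch
from ImageReward.models.AestheticScore import AestheticScore

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AestheticScore(download_root=os.path.expanduser("~/.cache/clip"), device=device).to(device)

# Placeholder path to pretrained aesthetic-MLP weights (not shipped in this commit).
model.mlp.load_state_dict(torch.load("aesthetic_mlp.pth", map_location="cpu"))
model.eval()

with torch.no_grad():
    ranks, scores = model.inference_rank("", ["a.png", "b.png"])
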
Empty file.
70 changes: 70 additions & 0 deletions ImageReward/models/BLIP/blip.py
@@ -0,0 +1,70 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width


def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")

def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')

state_dict = checkpoint['model']

state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
del state_dict[key]

msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg
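
A small sketch of these helpers in isolation (illustrative, not part of this commit; the example prompt is arbitrary):

from ImageReward.models.BLIP.blip import create_vit, init_tokenizer

# ViT-L/16 backbone at 224x224 input; vision_width is 1024 for the 'large' variant.
visual_encoder, vision_width = create_vit('large', image_size=224)

# BERT tokenizer extended with the [DEC] and [ENC] special tokens used by BLIP.
tokenizer = init_tokenizer()
tokens = tokenizer("a photo of a cat", padding='max_length', truncation=True,
                   max_length=35, return_tensors="pt")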

43 changes: 43 additions & 0 deletions ImageReward/models/BLIP/blip_pretrain.py
@@ -0,0 +1,43 @@
'''
* Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
def __init__(self,
med_config = "med_config.json",
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
queue_size = 57600,
momentum = 0.995,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()

self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

self.tokenizer = init_tokenizer()
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

text_width = self.text_encoder.config.hidden_size

self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
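
A sketch tying the two BLIP files together (illustrative, not part of this commit; model_large.pth stands in for a locally downloaded pretrained BLIP checkpoint, and load_checkpoint also accepts an http(s) URL, which it caches via timm before loading):

from ImageReward.models.BLIP.blip import load_checkpoint
from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain

# med_config.json is the BERT "med" configuration file referenced above (placeholder path).
model = BLIP_Pretrain(med_config="med_config.json", image_size=224, vit='large')
model, msg = load_checkpoint(model, "model_large.pth")
print(msg.missing_keys)  # parameters the checkpoint did not provide (loaded with strict=False)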

(Diffs for the remaining changed files are not shown.)
