Skip to content

Commit

Permalink
update pypi image-reward
Browse files Browse the repository at this point in the history
  • Loading branch information
xujz18 committed Apr 16, 2023
1 parent de478e5 commit 0f5fa3e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
11 changes: 8 additions & 3 deletions ImageReward/ImageReward.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(self, med_config, device='cpu'):

def score(self, prompt, image_path):

if (type(image_path).__name__=='list'):
_, rewards = self.inference_rank(prompt, image_path)
return rewards

# text encode
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)

Expand All @@ -103,7 +107,7 @@ def score(self, prompt, image_path):
rewards = self.mlp(txt_features)
rewards = (rewards - self.mean) / self.std

return rewards.cpu().numpy().item()
return rewards.detach().cpu().numpy().item()


def inference_rank(self, prompt, generations_list):
Expand Down Expand Up @@ -131,8 +135,9 @@ def inference_rank(self, prompt, generations_list):
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
rewards = self.mlp(txt_features) # [image_num, 1]
rewards = (rewards - self.mean) / self.std
rewards = torch.squeeze(rewards)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1
indices = indices + 1

return indices.cpu().numpy().tolist(), rewards.cpu().numpy().tolist()
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
5 changes: 3 additions & 2 deletions ImageReward/models/AestheticScore.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def inference_rank(self, prompt, generations_list):

img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
rewards = self.mlp(img_features)
rewards = torch.squeeze(rewards)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1
indices = indices + 1

return indices.cpu().numpy().tolist(), rewards.cpu().numpy().tolist()
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
5 changes: 3 additions & 2 deletions ImageReward/models/BLIPScore.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def inference_rank(self, prompt, generations_list):
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
rewards = torch.squeeze(rewards)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1
indices = indices + 1

return indices.cpu().numpy().tolist(), rewards.cpu().numpy().tolist()
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
5 changes: 3 additions & 2 deletions ImageReward/models/CLIPScore.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def inference_rank(self, prompt, generations_list):
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
rewards = torch.squeeze(rewards)
_, rank = torch.sort(rewards, dim=0, descending=True)
_, indices = torch.sort(rank, dim=0)
indices = torch.squeeze(indices) + 1
indices = indices + 1

return indices.cpu().numpy().tolist(), rewards.cpu().numpy().tolist()
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()

0 comments on commit 0f5fa3e

Please sign in to comment.