forked from AUTOMATIC1111/stable-diffusion-webui
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interrogate.py
167 lines (118 loc) · 6.1 KB
/
interrogate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import contextlib
import os
import sys
import traceback
from collections import namedtuple
import re
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
from modules import devices, paths, lowvram
# Side length (pixels) of the square image fed to BLIP: used for the resize in
# generate_caption and passed as blip_decoder's image_size.
blip_image_eval_size = 384
# Checkpoint handed to blip_decoder(pretrained=...) when building the captioner.
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
# Model name handed to clip.load().
clip_model_name = 'ViT-L/14'

# One ranking category loaded from a content file: the filename, how many of
# the best-matching items to keep, and the candidate strings themselves.
Category = namedtuple("Category", "name topn items")

# Pulls N out of a ".topN." fragment in a category filename (e.g. "x.top3.txt").
re_topn = re.compile(r"\.top(\d+)\.")
class InterrogateModels:
    """Describe an image as a prompt: a BLIP caption plus CLIP-ranked tags.

    ``interrogate`` returns the BLIP caption followed by ", "-separated
    matches: optionally the best built-in artist, then the top-N items of
    each category file loaded from ``content_dir``.
    """

    # Lazily created by load(); None until first use.
    blip_model = None
    clip_model = None
    clip_preprocess = None
    # List of Category tuples built in __init__.
    categories = None
    # Parameter dtype of the CLIP model; inputs to both models are cast to it.
    dtype = None

    def __init__(self, content_dir):
        """Load ranking categories from the files in *content_dir*.

        Each regular file becomes one ``Category`` whose items are its
        non-blank, stripped lines. A ``.topN.`` fragment in the filename sets
        how many items are kept (default 1). A missing directory yields no
        categories.
        """
        self.categories = []

        if not os.path.exists(content_dir):
            return

        # Sorted so the tags are appended to the prompt in a stable order,
        # independent of the filesystem's listing order.
        for filename in sorted(os.listdir(content_dir)):
            path = os.path.join(content_dir, filename)
            # Sub-directories (or other non-files) would make open() raise
            # and abort construction entirely.
            if not os.path.isfile(path):
                continue

            m = re_topn.search(filename)
            topn = 1 if m is None else int(m.group(1))

            with open(path, "r", encoding="utf8") as file:
                # Blank lines would become empty candidates that can be
                # "matched" and add a bare ", " to the prompt.
                lines = [line for line in (x.strip() for x in file) if line]

            self.categories.append(Category(name=filename, topn=topn, items=lines))

    def load_blip_model(self):
        """Build the BLIP caption decoder from ``blip_model_url`` in eval mode."""
        import models.blip

        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        """Load CLIP ``clip_model_name``; return ``(model, preprocess)``."""
        import clip

        model, preprocess = clip.load(clip_model_name)
        model.eval()
        model = model.to(shared.device)

        return model, preprocess

    def load(self):
        """Ensure both models exist and are resident on ``shared.device``.

        Models are created on first use (converted to half precision unless
        ``--no-half``). They are moved to the device on every call because
        unload() may have parked them in RAM since the last run.
        """
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(shared.device)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(shared.device)

        self.dtype = next(self.clip_model.parameters()).dtype

    def send_clip_to_ram(self):
        """Move CLIP to the CPU unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        """Move BLIP to the CPU unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        """Release device memory held by both models (subject to the keep option)."""
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

    def rank(self, image_features, text_array, top_count=1):
        """Return the best *top_count* entries of *text_array* for the image.

        Returns a list of ``(text, score)`` pairs, best first. ``score`` is
        the softmax probability, averaged over the rows of *image_features*,
        scaled to percent. Returns ``[]`` when there is nothing to rank.
        """
        import clip

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        if top_count <= 0:
            # Empty category (or a ".top0." file): don't run CLIP on an empty batch.
            return []

        text_tokens = clip.tokenize(list(text_array), truncate=True).to(shared.device)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(shared.device)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

    def generate_caption(self, pil_image):
        """Return the BLIP caption for *pil_image* (beam-search settings from shared.opts)."""
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
        """Return a prompt-like description of *pil_image*.

        The BLIP caption comes first, followed by ", "-separated CLIP
        matches. On any error the traceback is printed to stderr and
        "<error>" is appended to whatever was produced so far. unload() is
        always called before returning.
        """
        # Start from "" (not None) so the error path can always append.
        res = ""

        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            # BLIP is finished; free its device memory before running CLIP.
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
            with torch.no_grad(), precision_scope("cuda"):
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
                image_features /= image_features.norm(dim=-1, keepdim=True)

                if shared.opts.interrogate_use_builtin_artists:
                    artists = self.rank(image_features, ["by " + entry.name for entry in shared.artist_db.artists])
                    if artists:
                        res += ", " + artists[0][0]

                for category in self.categories:
                    matches = self.rank(image_features, category.items, top_count=category.topn)
                    for match, _score in matches:
                        res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        finally:
            # Always give device memory back, even on BaseException.
            self.unload()

        return res