Skip to content

Commit

Permalink
try multi embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpinkney committed Dec 21, 2022
1 parent a8ac454 commit b5b1882
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions ldm/modules/encoders/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,63 @@ def forward(self, x):
def encode(self, im):
return self(im).unsqueeze(1)

from torchvision import transforms
import random

class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder to produce multiple embeddings per image:
    one embedding of the full (resized) image plus a random number of
    random crops, stacked along a token dimension and zero-padded to a
    fixed length of ``max_crops + 1``.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """
    def __init__(
        self,
        model='ViT-L/14',
        jit=False,
        device='cpu',
        antialias=True,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        # CLIP's own preprocessing statistics (mean/std over RGB channels).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.max_crops = 4
        # Derive geometry/width from the loaded visual tower instead of
        # hard-coding ViT-L/14's values (224 / 768), so other CLIP variants
        # passed via `model` also work. For ViT-L/14 these are unchanged.
        self.image_size = self.model.visual.input_resolution
        self.embed_dim = self.model.visual.output_dim

    def preprocess(self, x):
        """Build a batch of patches from one image tensor of shape (1, 3, H, W).

        Expects inputs in the range [-1, 1]. Returns a (n_crops + 1, 3, S, S)
        tensor normalized with CLIP statistics, where S is the model's input
        resolution and n_crops is sampled uniformly from [0, max_crops].
        """
        randcrop = transforms.RandomCrop(self.image_size)
        # Resize to an exact (S, S) square. Passing a single int to Resize
        # would only fix the *shorter* edge, which breaks torch.cat below
        # for non-square inputs.
        resize = transforms.Resize((self.image_size, self.image_size))
        n = random.randint(0, self.max_crops)
        crops = [randcrop(x) for _ in range(n)]
        patches = [resize(x)]
        patches.extend(crops)
        x = torch.cat(patches, dim=0)
        # Map [-1, 1] -> [0, 1] before applying CLIP normalization.
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        """Encode a batch of images (range [-1, 1]) to (B, max_crops + 1, D).

        Per-sample rows beyond the sampled number of crops are zero-padded.
        A list input (e.g. [""]) denotes condition dropout for unconditional
        guidance and returns all-zero embeddings.
        """
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            # NOTE(review): this always returns batch size 1 regardless of
            # len(x); assumes dropout is applied to the whole batch — confirm
            # against the caller.
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops + 1, self.embed_dim, device=device)
        batch_tokens = []
        for im in x:
            patches = self.preprocess(im.unsqueeze(0))
            tokens = self.model.encode_image(patches).float()
            # Zero-pad along the token (row) dimension so every sample has
            # exactly max_crops + 1 embeddings.
            pad_amount = self.max_crops + 1 - tokens.shape[0]
            batch_tokens.append(torch.nn.functional.pad(tokens, [0, 0, 0, pad_amount]).unsqueeze(0))

        return torch.cat(batch_tokens, dim=0)

    def encode(self, im):
        """Alias for forward(); returns (B, max_crops + 1, D) embeddings."""
        return self(im)

class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
Expand Down

0 comments on commit b5b1882

Please sign in to comment.