Skip to content

Commit

Permalink
remove slip
Browse files Browse the repository at this point in the history
  • Loading branch information
MSFTserver committed Apr 4, 2022
1 parent a1f25c7 commit c509aa1
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions disco.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@
Remove Super Resolution
Remove SLIP Models
'''
)

Expand Down Expand Up @@ -439,7 +441,6 @@ def createPath(filepath):

if is_colab:
gitclone("https://github.com/openai/CLIP")
#gitclone("https://github.com/facebookresearch/SLIP.git")
gitclone("https://github.com/crowsonkb/guided-diffusion")
gitclone("https://github.com/assafshocher/ResizeRight.git")
gitclone("https://github.com/MSFTserver/pytorch3d-lite.git")
Expand Down Expand Up @@ -468,7 +469,6 @@ def createPath(filepath):
import sys
import torch

# sys.path.append('./SLIP')
sys.path.append('./pytorch3d-lite')
sys.path.append('./ResizeRight')
sys.path.append('./MiDaS')
Expand Down Expand Up @@ -496,7 +496,6 @@ def createPath(filepath):
sys.path.append('./guided-diffusion')
import clip
from resize_right import resize
# from models import SLIP_VITB16, SLIP, SLIP_VITL16
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from datetime import datetime
import numpy as np
Expand Down Expand Up @@ -1636,8 +1635,6 @@ def forward(self, input, t):
RN50x4 = False #@param{type:"boolean"}
RN50x16 = False #@param{type:"boolean"}
RN50x64 = False #@param{type:"boolean"}
SLIPB16 = False #@param{type:"boolean"}
SLIPL16 = False #@param{type:"boolean"}

#@markdown If you're having issues with model downloads, check this to compare SHA's:
check_model_SHA = False #@param{type:"boolean"}
Expand Down Expand Up @@ -1771,20 +1768,6 @@ def forward(self, input, t):
if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device))
if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device))

if SLIPB16:
SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):
wget("https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt", model_path)
sd = torch.load(f'{model_path}/slip_base_100ep.pt')
real_sd = {}
for k, v in sd['state_dict'].items():
real_sd['.'.join(k.split('.')[1:])] = v
del sd
SLIPB16model.load_state_dict(real_sd)
SLIPB16model.requires_grad_(False).eval().to(device)

clip_models.append(SLIPL16model)

normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
lpips_model = lpips.LPIPS(net='vgg').to(device)

Expand Down

0 comments on commit c509aa1

Please sign in to comment.