Skip to content

Commit

Permalink
add simpler superres
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpinkney committed Oct 13, 2022
1 parent 85bdb31 commit 9e23562
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 7 deletions.
5 changes: 3 additions & 2 deletions ldm/data/laion.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,10 @@ def test_dataloader(self):
import cv2

class AddLR(object):
def __init__(self, factor, output_size):
def __init__(self, factor, output_size, image_key="jpg"):
self.factor = factor
self.output_size = output_size
self.image_key = image_key

def pt2np(self, x):
x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
Expand All @@ -236,7 +237,7 @@ def np2pt(self, x):

def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.pt2np(sample[self.image_key])
x = degradation_fn_bsr_light(x, sf=self.factor)['image']
x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
x = self.np2pt(x)
Expand Down
50 changes: 46 additions & 4 deletions ldm/data/simple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Dict
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
Expand All @@ -9,18 +11,53 @@
from ldm.util import instantiate_from_config
from datasets import load_dataset

def make_multi_folder_data(paths, **kwargs):
"""Make a concat dataset from multiple folders
Don't suport captions yet
If paths is a list, that's ok, if it's a Dict interpret it as:
k=folder v=n_times to repeat that
"""
list_of_paths = []
if isinstance(paths, (Dict, DictConfig)):
for folder_path, repeats in paths.items():
list_of_paths.extend([folder_path]*repeats)
paths = list_of_paths

datasets = [FolderData(p, **kwargs) for p in paths]
return torch.utils.data.ConcatDataset(datasets)

class FolderData(Dataset):
def __init__(self, root_dir, caption_file=None, image_transforms=[], ext="jpg") -> None:
def __init__(self,
root_dir,
caption_file=None,
image_transforms=[],
ext="jpg",
default_caption="",
postprocess=None,
) -> None:
"""Create a dataset from a folder of images.
If you pass in a root directory it will be searched for images
ending in ext (ext can be a list)
"""
self.root_dir = Path(root_dir)
self.default_caption = ""
self.default_caption = default_caption
if isinstance(postprocess, DictConfig):
postprocess = instantiate_from_config(postprocess)
self.postprocess = postprocess
if caption_file is not None:
with open(caption_file, "rt") as f:
captions = json.load(f)
self.captions = captions
else:
self.captions = None

self.paths = list(self.root_dir.rglob(f"*.{ext}"))
if not isinstance(ext, (tuple, list, ListConfig)):
ext = [ext]

self.paths = []
for e in ext:
self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
image_transforms.extend([transforms.ToTensor(),
transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
Expand All @@ -38,7 +75,7 @@ def __len__(self):
def __getitem__(self, index):
if self.captions is not None:
chosen = list(self.captions.keys())[index]
caption = self.captions[chosen]
caption = self.captions.get(chosen, None)
if caption is None:
caption = self.default_caption
im = Image.open(self.root_dir/chosen)
Expand All @@ -49,6 +86,11 @@ def __getitem__(self, index):
data = {"image": im}
if self.captions is not None:
data["txt"] = caption
else:
data["txt"] = self.default_caption

if self.postprocess is not None:
data = self.postprocess(data)
return data

def process_im(self, im):
Expand Down
95 changes: 95 additions & 0 deletions ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,3 +1802,98 @@ def log_images(self, batch, N=8, *args, **kwargs):
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs


class SimpleUpscaleDiffusion(LatentDiffusion):
def __init__(self, *args, low_scale_key="LR", **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
assert not self.cond_stage_trainable
self.low_scale_key = low_scale_key

@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
if not log_mode:
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float()

encoder_posterior = self.encode_first_stage(x_low)
zx = self.get_first_stage_encoding(encoder_posterior).detach()
all_conds = {"c_concat": [zx], "c_crossattn": [c]}

if log_mode:
# TODO: maybe disable if too expensive
interpretability = False
if interpretability:
zx = zx[:, :, ::2, ::2]
return z, all_conds, x, xrec, xc, x_low
return z, all_conds

@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None

log = dict()
z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
log["x_lr"] = x_low

if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)

if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples

if unconditional_guidance_scale > 1.0:
uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
uc = dict()
for k in c:
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
elif isinstance(c[k], list):
uc[k] = [c[k][i] for i in range(len(c[k]))]
else:
uc[k] = c[k]

with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg


return log
4 changes: 3 additions & 1 deletion ldm/models/diffusion/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def sample(self,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
Expand Down

0 comments on commit 9e23562

Please sign in to comment.