Finetuning (justinpinkney#11)
* simple datasets

* add conversion script

* finish fine tune example

* update readme

* update readme
justinpinkney authored Sep 16, 2022
1 parent 704f564 commit f1293f9
Showing 11 changed files with 942 additions and 15 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,6 +1,5 @@
logs/
dump/
im-examples/
outputs/
flagged/
*.egg-info
28 changes: 24 additions & 4 deletions README.md
@@ -1,13 +1,33 @@
# Experiments with Stable Diffusion

## Image variations
This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.

Currently it adds:

- [Fine tuning](#fine-tuning)
- [Image variations](#image-variations)
- [Conversion to Huggingface Diffusers](scripts/convert_sd_to_diffusers.py)

## Fine tuning

Makes it easy to fine-tune Stable Diffusion on your own dataset, for example generating new Pokemon from text:

[![](assets/img-vars.jpg)](https://twitter.com/Buntworthy/status/1561703483316781057)
![](assets/pokemontage.jpg)

> Girl with a pearl earring, Cute Obama creature, Donald Trump, Boris Johnson, Totoro, Hello Kitty

For a step-by-step guide, see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).

## Image variations

Try it out in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
![](assets/im-vars-thin.jpg)

[![Open Demo](https://img.shields.io/badge/%CE%BB-Open%20Demo-blueviolet)](https://47725.gradio.app/)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
[![Open in Spaces](https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-orange)]()

_TODO describe in more detail_
For more details on the Image Variation model see the [model card](https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned).

- Get access to a Linux machine with a decent NVIDIA GPU (e.g. on [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
- Clone this repo
Binary file added assets/pokemontage.jpg
133 changes: 133 additions & 0 deletions configs/stable-diffusion/pokemon.yaml
@@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset
      params:
        name: lambdalabs/pokemon-blip-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "A pokemon with green eyes, large wings, and a hat"
        - "A cute bunny rabbit"
        - "Yoda"
        - "An epic landscape photo of a mountain"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
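
The `data` block above is ordinary `instantiate_from_config` configuration, so it can be sanity-checked outside of training. A minimal sketch (not part of this commit; it assumes the repo root is on `PYTHONPATH` and that the Hugging Face dataset can be downloaded):

```python
# Build the train/validation datasets straight from pokemon.yaml to confirm
# the targets and transforms resolve before launching a full training run.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("configs/stable-diffusion/pokemon.yaml")
train_ds = instantiate_from_config(cfg.data.params.train)      # ldm.data.simple.hf_dataset
val_ds = instantiate_from_config(cfg.data.params.validation)   # ldm.data.simple.TextOnly

print(len(train_ds), len(val_ds))  # validation has 4 captions x n_gpus entries
sample = train_ds[0]
print(sample["image"].shape, sample["txt"])  # (512, 512, 3) tensor in [-1, 1], caption string
```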
101 changes: 101 additions & 0 deletions ldm/data/simple.py
@@ -0,0 +1,101 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset

class FolderData(Dataset):
    def __init__(self, root_dir, caption_file, image_transforms, ext="jpg") -> None:
        self.root_dir = Path(root_dir)
        with open(caption_file, "rt") as f:
            captions = json.load(f)
        self.captions = captions

        self.paths = list(self.root_dir.rglob(f"*.{ext}"))
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

        # assert all(['full/' + str(x.name) in self.captions for x in self.paths])

    def __len__(self):
        return len(self.captions.keys())

    def __getitem__(self, index):
        chosen = list(self.captions.keys())[index]
        im = Image.open(self.root_dir/chosen)
        im = self.process_im(im)
        caption = self.captions[chosen]
        if caption is None:
            caption = "old book illustration"
        return {"jpg": im, "txt": caption}

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

def hf_dataset(
    name,
    image_transforms=[],
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied
    """
    ds = load_dataset(name, split=split)
    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    tform = transforms.Compose(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds

class TextOnly(Dataset):
    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images"""
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            repeated = [n_gpus*[x] for x in self.captions]
            self.captions = []
            [self.captions.extend(x) for x in repeated]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]
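
For local data, `FolderData` expects a directory of images plus a JSON file mapping image paths to captions. A hedged sketch of how it would be used; the folder name, file names, and captions are purely illustrative:

```python
# Hypothetical layout:
#   my_dataset/captions.json -> {"img_0001.jpg": "a drawing of a bird", ...}
#   my_dataset/img_0001.jpg, my_dataset/img_0002.jpg, ...
from ldm.data.simple import FolderData

ds = FolderData(
    root_dir="my_dataset",
    caption_file="my_dataset/captions.json",
    image_transforms=[
        {"target": "torchvision.transforms.Resize", "params": {"size": 512}},
        {"target": "torchvision.transforms.CenterCrop", "params": {"size": 512}},
    ],
    ext="jpg",
)
item = ds[0]
print(item["jpg"].shape, item["txt"])  # (512, 512, 3) tensor in [-1, 1], caption string
```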
3 changes: 2 additions & 1 deletion ldm/models/diffusion/ddim.py
@@ -159,7 +159,8 @@ def ddim_sampling(self, cond, shape,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
-           if callback: callback(i)
+           if callback:
+               img = callback(i, img, pred_x0)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
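Note that this changes the `callback` contract for DDIM sampling: the callback now receives the step index, the current latent, and `pred_x0`, and its return value replaces the latent. A hedged sketch of a compatible callback (the logging body is illustrative):

```python
# Callback matching the new signature: (step_index, img, pred_x0) -> img.
# Returning img unchanged reproduces the old behaviour; returning a modified
# tensor lets the caller edit the latent mid-sampling.
def ddim_callback(i, img, pred_x0):
    if i % 10 == 0:
        print(f"DDIM step {i}: latent std = {img.std().item():.3f}")
    return img

# e.g. samples, intermediates = sampler.sample(..., callback=ddim_callback)
```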
12 changes: 9 additions & 3 deletions ldm/models/diffusion/ddpm.py
@@ -1343,9 +1343,8 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
log["samples_x0_quantized"] = x_samples

if unconditional_guidance_scale > 1.0:
# uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# FIXME
uc = torch.zeros_like(c)
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# uc = torch.zeros_like(c)
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
@@ -1396,6 +1395,13 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        # FIXME JP
        # params = []
        # from ldm.modules.attention import CrossAttention
        # for n, m in self.model.named_modules():
        #     if isinstance(m, CrossAttention) and n.endswith('attn2'):
        #         params.extend(m.parameters())
        # END FIXME JP
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
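For context on the `log_images` change: `get_unconditional_conditioning` (typically the encoding of the empty prompt from `unconditional_guidance_label: [""]` in the config) is what classifier-free guidance blends against at sampling time. A simplified sketch of that blend; the repo's DDIM sampler batches the two forward passes into one call:

```python
import torch

def classifier_free_guidance(model, x_t, t, c, uc, scale: float = 3.0) -> torch.Tensor:
    """Simplified classifier-free guidance step: blend the noise prediction for
    the real conditioning c with the one for the empty-prompt conditioning uc.
    scale > 1 pushes samples towards the text condition."""
    e_t_uncond = model.apply_model(x_t, t, uc)  # prediction without the prompt
    e_t = model.apply_model(x_t, t, c)          # prediction with the prompt
    return e_t_uncond + scale * (e_t - e_t_uncond)
```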
21 changes: 21 additions & 0 deletions ldm/modules/encoders/modules.py
@@ -172,6 +172,19 @@ def forward(self, text):
    def encode(self, text):
        return self(text)

class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        z = self.embedder(text)
        return self.projection(z)

    def encode(self, text):
        return self(text)

class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
@@ -192,6 +205,14 @@ def __init__(
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

        # I didn't call this originally, but seems like it was frozen anyway
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
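A hedged usage sketch of the new `ProjectedFrozenCLIPEmbedder` (the pokemon config above sticks with the plain `FrozenCLIPEmbedder`): it wraps the CLIP text encoder and adds a 768-to-768 linear projection, so assuming the wrapped encoder freezes itself on construction as in upstream ldm, only the projection would receive gradients if it were plugged in as the cond stage:

```python
# Requires the transformers weights for openai/clip-vit-large-patch14.
from ldm.modules.encoders.modules import ProjectedFrozenCLIPEmbedder

embedder = ProjectedFrozenCLIPEmbedder(device="cpu")
z = embedder.encode(["a pokemon with green eyes and large wings"])
print(z.shape)  # (1, 77, 768): batch x max_length x context_dim

trainable = [n for n, p in embedder.named_parameters() if p.requires_grad]
print(trainable)  # expected: only projection.weight and projection.bias
```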