Skip to content

Commit

Permalink
changes for freezing and upscaling
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpinkney committed Sep 30, 2022
1 parent 0eafc35 commit 31ab82d
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 10 deletions.
181 changes: 181 additions & 0 deletions configs/stable-diffusion/upscaling_256.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 32
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: "hybrid-adm"
monitor: val/loss_simple_ema
scale_factor: 0.18215
low_scale_key: "lr"

low_scale_config:
target: ldm.modules.encoders.modules.LowScaleEncoder
params:
scale_factor: 0.18215
linear_start: 0.00085
linear_end: 0.0120
timesteps: 1000
max_noise_level: 100
output_size: null
model_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity

scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]

unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: 1000
image_size: 16 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False

first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
batch_size: 50
num_workers: 8
multinode: True
train:
shards: '{00000..01209}.tar'
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 256
postprocess:
target: ldm.data.laion.AddLR
params:
factor: 4
output_size: 256

# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards: '{00000..00012}.tar'
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 256
postprocess:
target: ldm.data.laion.AddLR
params:
factor: 1
output_size: 256


lightning:
find_unused_parameters: False

modelcheckpoint:
params:
every_n_train_steps: 5000

callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 20
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 3.0
unconditional_guidance_label: [""]

trainer:
benchmark: True
val_check_interval: 5000000 # really sorry
num_sanity_val_steps: 0
accumulate_grad_batches: 1
5 changes: 4 additions & 1 deletion ldm/data/laion.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,12 @@ def test_dataloader(self):


from ldm.modules.image_degradation import degradation_fn_bsr_light
import cv2

class AddLR(object):
def __init__(self, factor):
def __init__(self, factor, output_size):
self.factor = factor
self.output_size = output_size

def pt2np(self, x):
x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
Expand All @@ -236,6 +238,7 @@ def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = degradation_fn_bsr_light(x, sf=self.factor)['image']
x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
x = self.np2pt(x)
sample['lr'] = x
return sample
Expand Down
20 changes: 11 additions & 9 deletions ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.attention import CrossAttention


__conditioning_keys__ = {'concat': 'c_concat',
Expand Down Expand Up @@ -1396,17 +1397,18 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=

def configure_optimizers(self):
lr = self.learning_rate
if self.unet_trainable:
params = []
if self.unet_trainable == "attn":
print("Training only unet attention layers")
for n, m in self.model.named_modules():
if isinstance(m, CrossAttention) and n.endswith('attn2'):
params.extend(m.parameters())
elif self.unet_trainable is True or self.unet_trainable == "all":
print("Training the full unet")
params = list(self.model.parameters())
else:
params = []

# TODO allow certain parts trainables
# from ldm.modules.attention import CrossAttention
# for n, m in self.model.named_modules():
# if isinstance(m, CrossAttention) and n.endswith('attn2'):
# params.extend(m.parameters())
# END JP
raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}")

if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
Expand Down

0 comments on commit 31ab82d

Please sign in to comment.