Commit

First upload
yhl authored and yhl committed Mar 4, 2023
1 parent f89dccb commit a983b11
Showing 112 changed files with 12,630 additions and 19 deletions.
1 change: 1 addition & 0 deletions DATA/README.MD
@@ -0,0 +1 @@
TODO
74 changes: 74 additions & 0 deletions README.MD
@@ -0,0 +1,74 @@

# GLIGEN: Open-Set Grounded Text-to-Image Generation (CVPR 2023)

[Yuheng Li](https://yuheng-li.github.io/), [Haotian Liu](https://hliu.cc), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Fangzhou Mu](https://pages.cs.wisc.edu/~fmu/), [Jianwei Yang](https://jwyang.github.io/), [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/), [Chunyuan Li*](https://chunyuan.li/), [Yong Jae Lee*](https://pages.cs.wisc.edu/~yongjaelee/) (*Co-senior authors)

[[Project Page](https://gligen.github.io/)] [[Paper](https://arxiv.org/abs/2301.07093)] [[Demo](https://huggingface.co/spaces/gligen/demo)] [[YouTube Video](https://youtu.be/-MCkU7IAGKs)]
![Teaser figure](figures/concept.gif)

[![GLIGEN teaser video](https://github.com/gligen/GLIGEN/blob/master/figures/teaser_v4.png)](https://youtu.be/-MCkU7IAGKs)

- Go beyond text prompts with GLIGEN: it adds new capabilities to frozen text-to-image generation models, grounding generation on various inputs, including boxes, keypoints, and reference images.
- GLIGEN’s zero-shot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.

## Requirements
We provide a [Dockerfile](env_docker/Dockerfile) to set up the environment.


## Download GLIGEN models

We provide five checkpoints for different use cases; all of them are based on SD v1.4. A download sketch follows the list below.
- 1) [text grounding for generation](https://huggingface.co/gligen/sd-gligen-base/blob/main/diffusion_pytorch_model.bin)
- 2) [text and image grounding for generation](https://huggingface.co/gligen/sd-gligen-image/blob/main/diffusion_pytorch_model.bin)
- 3) [keypoint grounding for generation](https://huggingface.co/gligen/sd-gligen-keypoint/blob/main/diffusion_pytorch_model.bin)
- 4) [text grounding for inpainting](https://huggingface.co/gligen/sd-gligen-keypoint/blob/main/diffusion_pytorch_model.bin)
- 5) [text and image grounding for inpainting](https://huggingface.co/gligen/sd-gligen-inpaint-image/blob/main/diffusion_pytorch_model.bin)
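
A minimal download sketch using `huggingface_hub` (an assumption on my part: it requires a recent `huggingface_hub` install; the repo ID and filename are taken from the links above, and the target folder matches the inference instructions below):
```python
from huggingface_hub import hf_hub_download  # pip install huggingface_hub

# Example: fetch the text-grounding generation checkpoint into gligen_checkpoints/
hf_hub_download(
    repo_id="gligen/sd-gligen-base",
    filename="diffusion_pytorch_model.bin",
    local_dir="gligen_checkpoints",
)
```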

## Inference: Generate images with GLIGEN

We provide a script to generate images with the provided checkpoints. First download the models and put them in `gligen_checkpoints`, then run
```bash
python gligen_inference.py
```
Example samples for each checkpoint will be saved in `generation_samples`. See `gligen_inference.py` for more details about the interface.
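
To give a flavor of what a grounded generation request contains, here is a purely illustrative Python dictionary with a caption, grounded phrases, and normalized boxes; the field names and structure are assumptions for illustration only, so check `gligen_inference.py` for the actual interface:
```python
# Illustrative only: a caption, the phrases to ground, and one normalized
# xyxy box per phrase (coordinates in [0, 1] relative to the image size).
grounding_instruction = {
    "prompt": "a teddy bear sitting next to a red bird",
    "phrases": ["a teddy bear", "a red bird"],
    "locations": [[0.05, 0.30, 0.45, 0.95], [0.55, 0.35, 0.90, 0.80]],
}
```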


## Training

### Grounded generation training

One first needs to prepare data for the different grounding-modality conditions. Refer to [data](DATA/README.MD) for the data we used for the different GLIGEN models. Once the data is ready, the following command is used to train GLIGEN (multi-GPU training is supported):

```bash
python main.py --name=your_experiment_name --yaml_file=path_to_your_yaml_config
```
`--yaml_file` is the most important argument. Below we use one example to explain its key components, so that one can become familiar with our code and customize training on new grounding modalities. The other arguments are self-explanatory from their names. The experiment will be saved in `OUTPUT_ROOT/name`.

One can refer to `configs/flickr_text.yaml` as an example. There are six components defining this yaml: **diffusion**, **model**, **autoencoder**, **text_encoder**, **train_dataset_names** and **grounding_tokenizer_input**. Typically, **diffusion**, **autoencoder** and **text_encoder** should not be changed, as they are defined by Stable Diffusion. One should pay attention to the following:

- Within **model** we add a new argument, **grounding_tokenizer**, which defines a network that produces grounding tokens. This network is instantiated inside the model. Refer to `ldm/modules/diffusionmodules/grounding_net_example.py` for details on defining this network (see also the sketch after this list).
- **grounding_tokenizer_input** defines a class that takes a batch from the dataloader and produces the input for the grounding_tokenizer; in other words, it is an intermediate layer between the dataloader and the grounding_tokenizer. Refer to `grounding_input/__init__.py` for details on defining this class.
- **train_dataset_names** lists the names of the training datasets (all listed datasets are concatenated internally, so it is easy to combine datasets for training). Each dataset name must first be registered in `dataset/catalog.py`. We have listed all datasets we used; to train GLIGEN on your own modality and dataset, do not forget to register its name there first.
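
As a rough illustration of the grounding_tokenizer idea, here is a minimal, self-contained sketch of a box-plus-text tokenizer; the class name, argument names, and architecture are illustrative assumptions, not the exact ones in `grounding_net_example.py` or the provided grounding nets:
```python
import torch
import torch.nn as nn


class ToyPositionNet(nn.Module):
    """Illustrative grounding tokenizer: maps per-box text features plus
    normalized box coordinates to grounding tokens for the gated layers."""

    def __init__(self, in_dim=768, out_dim=768):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim + 4, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )
        # learned feature substituted for padded or dropped boxes
        self.null_feature = nn.Parameter(torch.zeros(in_dim + 4))

    def forward(self, boxes, masks, text_embeddings):
        # boxes: (B, N, 4) normalized xyxy; masks: (B, N) float in {0, 1};
        # text_embeddings: (B, N, in_dim) per-phrase text features
        feat = torch.cat([text_embeddings, boxes], dim=-1)  # (B, N, in_dim + 4)
        masks = masks.unsqueeze(-1)                         # (B, N, 1)
        feat = feat * masks + self.null_feature.view(1, 1, -1) * (1 - masks)
        return self.mlp(feat)                               # (B, N, out_dim) grounding tokens
```
A matching grounding_tokenizer_input class would simply pull the `boxes`, `masks`, and `text_embeddings` fields out of the dataloader batch and pass them to this module.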


### Grounded inpainting training

GLIGEN also supports inpainting training. The following command can be used:
```bash
python main.py --name=your_experiment_name --yaml_file=path_to_your_yaml_config --inpaint_mode=True --ckpt=path_to_an_adapted_model
```
Typically, we first train GLIGEN on a generation task (e.g., text-grounded generation); this model has 4 input channels in its first conv layer (the latent space of Stable Diffusion). We then modify the saved checkpoint to 9 input channels, with the additional 5 channels initialized to 0. Continuing training from this adapted checkpoint leads to faster convergence and better results. `path_to_an_adapted_model` refers to this modified checkpoint; `convert_ckpt.py` can be used to modify the checkpoint. **NOTE:** the yaml file is the same for generation and inpainting training; one only needs to set `--inpaint_mode`.
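
After adapting a checkpoint with `convert_ckpt.py`, a quick sanity check is possible (a sketch, assuming the checkpoint layout used by `convert_ckpt.py` further below, where the weights live under the `model` key):
```python
import torch

ckpt = torch.load("path_to_an_adapted_model", map_location="cpu")
w = ckpt["model"]["input_blocks.0.0.weight"]
print(w.shape)                      # expected: torch.Size([320, 9, 3, 3])
print(w[:, 4:].abs().sum().item())  # expected: 0.0 (extra channels zero-initialized)
```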

## Citation
```
@article{li2023gligen,
title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
journal={CVPR},
year={2023}
}
```

## Disclaimer

The original GLIGEN was partly implemented and trained during an internship at Microsoft. This repo re-implements GLIGEN in PyTorch with university GPUs after the internship. Despite the minor implementation differences, this repo aims to reproduce the results and observations in the paper for research purposes.
19 changes: 0 additions & 19 deletions README.md

This file was deleted.

69 changes: 69 additions & 0 deletions configs/coco2017K.yaml
@@ -0,0 +1,69 @@
diffusion:
  target: ldm.models.diffusion.ldm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    timesteps: 1000


model:
  target: ldm.modules.diffusionmodules.openaimodel.UNetModel
  params:
    image_size: 64 # unused in the UNet, but used when creating xT
    in_channels: 4
    out_channels: 4
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_heads: 8
    transformer_depth: 1
    context_dim: 768
    fuser_type: gatedSA # gatedCA or gatedSA
    use_checkpoint: True

    grounding_tokenizer:
      target: ldm.modules.diffusionmodules.keypoint_grounding_net.PositionNet
      params:
        max_persons_per_image: 8 # must be the same as the one in the dataset
        out_dim: 768 # not constrained to this; a linear projection is applied at each gated layer to match the visual dimension


autoencoder:
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    scale_factor: 0.18215
    embed_dim: 4
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
        - 1
        - 2
        - 4
        - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0


text_encoder:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


train_dataset_names:
  COCO2017Keypoint:
    image_size: 512
    prob_real_caption: 1
    max_persons_per_image: 8 # must be the same as the one in the model
    random_flip: True


grounding_tokenizer_input:
  target: grounding_input.keypoint_grounding_tokinzer_input.GroundingNetInput
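
Each block in the config above follows the usual latent-diffusion `target`/`params` convention: a dotted import path plus keyword arguments. Below is a generic sketch of how such a block can be resolved, similar in spirit to the `instantiate_from_config` helper in latent-diffusion codebases; the helper name here is made up for illustration and the repo may use its own:
```python
import importlib

from omegaconf import OmegaConf  # configs are plain YAML, loaded here via OmegaConf


def build_from_config(block):
    """Instantiate 'target' (a dotted class path) with its 'params' as kwargs."""
    module_name, cls_name = block["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**block.get("params", {}))


cfg = OmegaConf.load("configs/coco2017K.yaml")
# Assumes the repo's `ldm` package is importable from the working directory.
autoencoder = build_from_config(cfg.autoencoder)
```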
71 changes: 71 additions & 0 deletions configs/flickr_text.yaml
@@ -0,0 +1,71 @@
diffusion:
  target: ldm.models.diffusion.ldm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    timesteps: 1000


model:
  target: ldm.modules.diffusionmodules.openaimodel.UNetModel
  params:
    image_size: 64 # unused in the UNet, but used when creating xT
    in_channels: 4
    out_channels: 4
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_heads: 8
    transformer_depth: 1
    context_dim: 768
    fuser_type: gatedSA # gatedCA or gatedSA. We have ablated this: self-attention works better than cross-attention, so gatedSA is the usual choice
    use_checkpoint: True

    grounding_tokenizer:
      target: ldm.modules.diffusionmodules.text_grounding_net.PositionNet
      params:
        in_dim: 768 # feature dim of the pre-processed CLIP text-encoder features (penultimate layer)
        out_dim: 768 # not constrained to this; a linear projection is applied at each gated layer to match the visual dimension


autoencoder:
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    scale_factor: 0.18215
    embed_dim: 4
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
        - 1
        - 2
        - 4
        - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0


text_encoder:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


train_dataset_names:
  FlickrGrounding:
    which_layer_text: before
    image_size: 512
    max_boxes_per_data: 30
    prob_use_caption: 0.5
    random_crop: False
    random_flip: True


grounding_tokenizer_input:
  target: grounding_input.text_grounding_tokinzer_input.GroundingNetInput
74 changes: 74 additions & 0 deletions configs/flickr_text_image.yaml
@@ -0,0 +1,74 @@
diffusion:
  target: ldm.models.diffusion.ldm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    timesteps: 1000


model:
  target: ldm.modules.diffusionmodules.openaimodel.UNetModel
  params:
    image_size: 64 # unused in the UNet, but used when creating xT
    in_channels: 4
    out_channels: 4
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_heads: 8
    transformer_depth: 1
    context_dim: 768
    fuser_type: gatedSA # gatedCA or gatedSA
    use_checkpoint: True

    grounding_tokenizer:
      target: ldm.modules.diffusionmodules.text_image_grounding_net.PositionNet
      params:
        in_dim: 768 # feature dim of the pre-processed CLIP text-encoder features (penultimate layer)
        out_dim: 768 # not constrained to this; a linear projection is applied at each gated layer to match the visual dimension


autoencoder:
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    scale_factor: 0.18215
    embed_dim: 4
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
        - 1
        - 2
        - 4
        - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0


text_encoder:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


# This is second-stage training (resume the base model from O365+GoldG).
# Captions are randomly dropped for all datasets (constructed captions are used for O365).
train_dataset_names:
  FlickrGrounding:
    which_layer_text: before
    which_layer_image: after_reproject
    image_size: 512
    max_boxes_per_data: 30
    prob_use_caption: 0.5
    random_drop_embedding: both
    random_crop: False
    random_flip: True


grounding_tokenizer_input:
  target: grounding_input.text_image_grounding_tokinzer_input.GroundingNetInput
25 changes: 25 additions & 0 deletions convert_ckpt.py
@@ -0,0 +1,25 @@
import torch
import argparse


parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, default=None, help="")
parser.add_argument("--new_ckpt_path", type=str, default=None, help="")
args = parser.parse_args()


# The inpainting UNet takes 4 (noisy latent) + 4 (masked-image latent) + 1 (mask) = 9 input channels.
new_conv_weight = torch.zeros(320, 4 + 4 + 1, 3, 3)

ckpt = torch.load(args.ckpt_path, map_location="cpu")

# Copy the original 4-channel first-conv weights; the extra 5 channels stay zero-initialized.
for key, value in ckpt["model"].items():
    if key == "input_blocks.0.0.weight":
        old_conv_weight = value
        new_conv_weight[:, 0:4, :, :] = old_conv_weight
ckpt["model"]["input_blocks.0.0.weight"] = new_conv_weight

save = {"model": ckpt["model"]}
torch.save(save, args.new_ckpt_path)

Empty file added dataset/__init__.py
Binary file added dataset/__pycache__/__init__.cpython-38.pyc
Binary file added dataset/__pycache__/base_dataset.cpython-38.pyc
Binary file added dataset/__pycache__/catalog.cpython-38.pyc
Binary file added dataset/__pycache__/cd_dataset.cpython-38.pyc
Binary file added dataset/__pycache__/concat_dataset.cpython-38.pyc
Binary file added dataset/__pycache__/dataset_kp.cpython-38.pyc
Binary file added dataset/__pycache__/tsv.cpython-38.pyc
Binary file added dataset/__pycache__/tsv_dataset.cpython-38.pyc