An implementation of Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models (E4T) using d🧨ffusers.
My summary tweet is found here.
- Released the current-best pre-trained model, trained on CelebA-HQ + FFHQ. Please see the Model Zoo for more information.
$ git clone https://github.com/mkshing/e4t-diffusion.git
$ cd e4t-diffusion
$ pip install -r requirements.txt
- e4t-diffusion-ffhq-celebahq-v1: a pre-trained model for the face domain, trained on FFHQ + CelebA-HQ. To get better results, I used Stable unCLIP for data augmentation.
Sample images: logs from the pre-training phase, and "a photo of *s in the beach" after domain tuning on a photo of Yann LeCun.
You need a domain-specific E4T pre-trained model that matches your target image. If your target image is a face, you need to pre-train on a large face image dataset such as FFHQ or CelebA-HQ. Or, if your target is an artistic image, you might want to pre-train on WikiArt, like so:
accelerate launch pretrain_e4t.py \
--pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
--clip_model_name_or_path="ViT-H-14::laion2b_s32b_b79k" \
--domain_class_token="art" \
--placeholder_token="*s" \
--prompt_template="art" \
--save_sample_prompt="a photo of the *s,a photo of the *s in monet style" \
--reg_lambda=0.01 \
--domain_embed_scale=0.1 \
--output_dir="pretrained-wikiart" \
--train_image_dataset="Artificio/WikiArt" \
--iterable_dataset \
--resolution=512 \
--train_batch_size=16 \
--learning_rate=1e-6 --scale_lr \
--checkpointing_steps=10000 \
--log_steps=1000 \
--max_train_steps=100000 \
--unfreeze_clip_vision \
--mixed_precision="fp16" \
--enable_xformers_memory_efficient_attention
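The Model Zoo above mentions that Stable unCLIP was used for data augmentation when pre-training the face model. The snippet below is only a minimal sketch of how such image variations could be generated with diffusers' `StableUnCLIPImg2ImgPipeline`; the model id, prompt, and file paths are assumptions and are not part of this repository.

```python
import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image

# Stable unCLIP image-variation pipeline (model id is an assumption).
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
).to("cuda")

# Generate a few variations of one source image to enlarge the pre-training set.
source = load_image("path/to/source_face.png")  # hypothetical input image
variations = pipe(image=source, prompt="a photo of a face", num_images_per_prompt=4).images
for i, img in enumerate(variations):
    img.save(f"augmented_{i:02d}.png")
```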
Once you have a pre-trained model, you are ready for domain tuning! In this step, all parameters, including the UNet itself (and optionally the text encoder), are trained. Unlike DreamBooth, E4T needs fewer than 15 training steps according to the paper.
accelerate launch tuning_e4t.py \
--pretrained_model_name_or_path="e4t pre-trained model path" \
--prompt_template="a photo of {placeholder_token}" \
--reg_lambda=0.1 \
--output_dir="path-to-save-model" \
--train_image_path="image path or url" \
--resolution=512 \
--train_batch_size=16 \
--learning_rate=1e-6 --scale_lr \
--max_train_steps=30 \
--mixed_precision="fp16" \
--enable_xformers_memory_efficient_attention
Once domain tuning is done, you can run inference by including your placeholder token in the prompt.
python inference.py \
--pretrained_model_name_or_path "e4t pre-trained model path" \
--prompt "Times square in the style of *s" \
--num_images_per_prompt 3 \
--scheduler_type "ddim" \
--image_path_or_url "same image path or url as domain tuning" \
--num_inference_steps 50 \
--guidance_scale 7.5
I would like to thank Stability AI for providing the compute resources used to test this code and train the pre-trained models.
@misc{https://doi.org/10.48550/arXiv.2302.12228,
url = {https://arxiv.org/abs/2302.12228},
author = {Rinon Gal and Moab Arar and Yuval Atzmon and Amit H. Bermano and Gal Chechik and Daniel Cohen-Or},
title = {Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models},
publisher = {arXiv},
year = {2023},
copyright = {arXiv.org perpetual, non-exclusive license}
}
- Pre-training
- Domain-tuning
- Inference
- Data augmentation by Stable unCLIP
- Use an off-the-shelf face segmentation network for the human face domain (see the masked-loss sketch after this list).
  > Finally, we find that for the human face domain, it is helpful to use an off-the-shelf face segmentation network [Deng et al. 2019] to mask the diffusion loss at this stage.
- Support ToMe for more efficient training
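Regarding the face segmentation item above, the idea from the paper is to mask the diffusion loss so that only face pixels contribute. Below is a minimal sketch of that masking, assuming a binary face mask obtained from some off-the-shelf segmentation network; the function name and tensor shapes are assumptions, and this is not code from the repository.

```python
import torch
import torch.nn.functional as F

def masked_diffusion_loss(model_pred: torch.Tensor,
                          target: torch.Tensor,
                          face_mask: torch.Tensor) -> torch.Tensor:
    """MSE noise-prediction loss restricted to face pixels.

    model_pred, target: (B, C, H, W) predicted and target noise at latent resolution
    face_mask:          (B, 1, h, w) binary mask from a face segmentation network
    """
    # Resize the mask to the latent resolution and broadcast it over channels.
    mask = F.interpolate(face_mask.float(), size=model_pred.shape[-2:], mode="nearest")
    # Per-element squared error, zeroed outside the face region.
    loss = (model_pred.float() - target.float()) ** 2 * mask
    # Normalize by the number of contributing elements to keep the loss scale stable.
    denom = mask.sum() * model_pred.shape[1]
    return loss.sum() / denom.clamp(min=1.0)
```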