
Re-implementation of ControlNeXt with Shape Masks on SVD

An unofficial re-implementation of ControlNeXt with shape masks based on the SVD foundation model. Please refer to this link for the official implementation of ControlNeXt.

[Paper] [Official Implementation] [Hugging Face]

Table of Contents

  • Overview
  • To-Do List
  • Code Structure
  • Implementation Details
  • Prerequisites
  • Training
  • Sampling
  • Star History

Overview

This is a re-implementation of ControlNeXt trained with shape masks. If you have any suggestions about this repo, please feel free to open a new issue or propose a PR.

<🎯Back to Table of Contents>

To-Do List

  • Update basic documents
  • Update training and inference code
  • Update pre-trained model weights
  • Regular Maintenance

<🎯Back to Table of Contents>

Code Structure

ControlNeXt-svd-shape
├── LICENSE
├── README.md
├── dataset_loaders                  <----- Code of dataset functions
│   └── youtube_vos.py
├── inference_svd.py                 <----- Script to run inference with the ControlNeXt model
├── models                           <----- Code of U-net and ControlNeXt models
├── pipeline                         <----- Code of pipeline functions
├── requirements.txt                 <----- Dependency list
├── runners                          <----- Code of runner functions
│   ├── __init__.py
│   ├── controlnext_inference_runner.py
│   └── controlnext_train_runner.py
├── train_svd.py                     <----- Script to train the ControlNeXt model
└── utils                            <----- Code of toolkit functions

<🎯Back to Table of Contents>

Implementation Details

This re-implementation of ControlNeXt is trained on the YouTube-VOS dataset, with the official segmentation annotations of YouTube-VOS used as the input condition of ControlNeXt. The overall pipeline is trained for 20,000 iterations with a batch size of 4 and bfloat16 precision to achieve its best performance. For optimization, the trainable parameters consist of all to_k and to_v linear layers in the U-net together with the model parameters of ControlNeXt, which differs from the official implementation that unlocks the entire U-net.
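As a rough illustration, the parameter selection described above could look like the following sketch. This is a minimal example under assumptions: the variable names unet and controlnext and the optimizer settings are illustrative, not the repo's exact code.

# Assumes `unet` and `controlnext` are already-loaded torch.nn.Module instances
# (illustrative names, not necessarily the identifiers used in this repo).
def collect_trainable_params(unet, controlnext):
    trainable_params = []
    # Freeze the U-net except for the to_k / to_v attention projections.
    for name, param in unet.named_parameters():
        if "to_k" in name or "to_v" in name:
            param.requires_grad_(True)
            trainable_params.append(param)
        else:
            param.requires_grad_(False)
    # All ControlNeXt parameters stay trainable.
    for param in controlnext.parameters():
        param.requires_grad_(True)
        trainable_params.append(param)
    return trainable_params

# Example usage (the optimizer and learning rate are assumptions):
# optimizer = torch.optim.AdamW(collect_trainable_params(unet, controlnext), lr=1e-5)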

<🎯Back to Table of Contents>

Prerequisites

  1. To install all the dependencies, you can run the one-click installation command:
pip install -r requirements.txt
  2. Download YouTube-VOS from this link to prepare the training data. You can refer to the expected data layout sketched after this list.
  3. Download the pre-trained model weights of Stable Video Diffusion from this link. For our pre-trained ControlNeXt and U-net weights, you can refer to our HuggingFace repo.
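Based on the example training command in the next section, the expected YouTube-VOS layout is roughly as follows (the exact file names, such as meta.json, are an assumption and may differ depending on the release you download):

youtube_vos
├── JPEGImages
│   └── VIDEO_ID
│       ├── 00000.jpg
│       └── ...
├── Annotations
│   └── VIDEO_ID
│       ├── 00000.png
│       └── ...
└── meta.json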

<🎯Back to Table of Contents>

Training

Once the data and pre-trained model weights are ready, you can train the ControlNeXt model with the following command:

python train_svd.py --pretrained_model_name_or_path SVD_CHECKPOINTS_PATH --train_batch_size TRAIN_BATCH_SIZE --video_path YOUTUBE_VOS_FRAMES_PATH --shape_path YOUTUBE_VOS_ANNOTATION_PATH --annotation_path YOUTUBE_VOS_META_PATH --output_dir OUTPUT_PATH --finetune_unet

You can refer to the following example command line:

python train_svd.py --pretrained_model_name_or_path checkpoints/svd_xt_1.1 --train_batch_size 4 --video_path youtube_vos/JPEGImages --shape_path youtube_vos/Annotations --annotation_path youtube_vos/meta.json --output_dir OUTPUT_PATH --finetune_unet

<🎯Back to Table of Contents>

Sampling

Once the ControlNeXt model is trained, you can run inference with the following command:

python inference_svd.py --pretrained_model_name_or_path SVD_CHECKPOINTS_PATH --validation_control_images_folder INPUT_CONDITIONS_PATH --output_dir OUTPUT_PATH --checkpoint_dir CONTROLNEXT_PATH --ref_image_path REFERENCE_IMAGE_PATH

Note that, unlike the official implementation, you do not need to merge the DeepSpeed checkpoint by running an additional script. All you need to do is configure --checkpoint_dir. Normally, a checkpoint saved with the DeepSpeed engine has a structure similar to the following:

checkpoints
├── latest
├── pytorch_model
│   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│   └── mp_rank_00_model_states.pt
├── random_states_0.pkl
├── scheduler.bin
└── zero_to_fp32.py

You need to set --checkpoint_dir to checkpoints/pytorch_model/mp_rank_00_model_states.pt, and the script will automatically convert the checkpoint to the pytorch_model.bin format. You can refer to the following example command:

python inference_svd.py --pretrained_model_name_or_path checkpoints/svd_xt_1.1 --validation_control_images_folder examples/frames/car --output_dir outputs/inference --checkpoint_dir checkpoints/pytorch_model/mp_rank_00_model_states.pt --ref_image_path examples/frames/car/00000.png
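For reference, the automatic conversion performed by the inference script is roughly equivalent to the following sketch. The "module" key and the output path are assumptions based on the usual DeepSpeed checkpoint layout, not this repo's exact code.

import torch

# Load the DeepSpeed model-states file and unwrap the state dict.
ckpt = torch.load(
    "checkpoints/pytorch_model/mp_rank_00_model_states.pt", map_location="cpu"
)
state_dict = ckpt.get("module", ckpt)  # DeepSpeed typically nests weights under "module"

# Save the weights in the plain pytorch_model.bin format.
torch.save(state_dict, "checkpoints/pytorch_model.bin")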

<🎯Back to Table of Contents>

Star History

Star History Chart

<🎯Back to Table of Contents>
