An unofficial re-implementation of ControlNeXt with shape masks based on the SVD foundation model. Please refer to this link for the official implementation of ControlNeXt.
- 1. Overview
- 2. To-Do List
- 3. Code Structure
- 4. Implementation Details
- 5. Prerequisites
- 6. Training
- 7. Sampling
- 8. Results
- 9. Star History
This is a re-implementation of ControlNeXt trained with shape masks. If you have any suggestions about this repo, please feel free to open a new issue or propose a PR.
- Update basic documents
- Update training and inference code
- Update pre-trained model weights
- Regular Maintenance
ControlNeXt-svd-shape
├── LICENSE
├── README.md
├── dataset_loaders <----- Code of dataset functions
│ └── youtube_vos.py
├── inference_svd.py <----- Script to run inference with the ControlNeXt model
├── models <----- Code of U-net and ControlNeXt models
├── pipeline <----- Code of pipeline functions
├── requirements.txt <----- Dependency list
├── runners <----- Code of runner functions
│ ├── __init__.py
│ ├── controlnext_inference_runner.py
│ └── controlnext_train_runner.py
├── train_svd.py <----- Script to train ControlNeXt model
└── utils <----- Code of toolkit functions
This re-implementation of ControlNeXt is trained on the YouTube-VOS dataset.
The official segmentation annotations of YouTube-VOS are used as the input condition for ControlNeXt.
The overall pipeline is trained for 20,000 iterations with a batch size of 4 and bfloat16 precision to achieve its best performance.
For optimization, the trainable parameters comprise all to_k and to_v linear layers in the U-net together with the model parameters of ControlNeXt, which differs from the official implementation that unlocks the entire U-net (see the sketch below).
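The parameter selection could be implemented roughly as follows. This is a minimal sketch assuming a diffusers-style U-net that exposes named_parameters(); the function name and the controlnext argument are illustrative and not taken from this repo.

def select_trainable_parameters(unet, controlnext):
    # Freeze the whole U-net first.
    unet.requires_grad_(False)
    trainable = []
    # Unlock only the to_k and to_v linear layers of the attention blocks.
    for name, param in unet.named_parameters():
        if "to_k" in name or "to_v" in name:
            param.requires_grad = True
            trainable.append(param)
    # The ControlNeXt module is trained in full.
    controlnext.requires_grad_(True)
    trainable.extend(controlnext.parameters())
    return trainable

# Example: pass the selected parameters to the optimizer, e.g.
# optimizer = torch.optim.AdamW(select_trainable_parameters(unet, controlnext))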
- To install all the dependencies, run the following command:
pip install -r requirements.txt
- Download YouTube-VOS from this link to prepare the training data.
- Download the pre-trained Stable Video Diffusion model weights from this link. For our pre-trained ControlNeXt and U-net weights, you can refer to our HuggingFace repo. A possible directory layout is sketched below.
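For reference, the example commands in the following sections assume a layout similar to this one; the folder names (checkpoints/svd_xt_1.1, youtube_vos, ...) simply mirror the example command lines in this README and can be adjusted to your setup:

checkpoints
└── svd_xt_1.1 <----- Stable Video Diffusion weights
youtube_vos
├── JPEGImages <----- video frames
├── Annotations <----- official shape-mask annotations
└── meta.json <----- dataset metadata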
Once the data and pre-trained model weights are ready, you can train the ControlNeXt model with the following command:
python train_svd.py --pretrained_model_name_or_path SVD_CHECKPOINTS_PATH --train_batch_size TRAIN_BATCH_SIZE --video_path YOUTUBE_VOS_FRAMES_PATH --shape_path YOUTUBE_VOS_ANNOTATION_PATH --annotation_path YOUTUBE_VOS_META_PATH --output_dir OUTPUT_PATH --finetune_unet
You can refer to the following example command line:
python train_svd.py --pretrained_model_name_or_path checkpoints/svd_xt_1.1 --train_batch_size 4 --video_path youtube_vos/JPEGImages --shape_path youtube_vos/Annotations --annotation_path youtube_vos/meta.json --output_dir OUTPUT_PATH --finetune_unet
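If you train on multiple GPUs, the checkpoint layout shown in the Sampling section suggests the run is launched through Hugging Face Accelerate with a DeepSpeed config. This is an assumption, and both the launch method and the config file name (deepspeed_config.yaml is a placeholder) may differ in your setup:

accelerate launch --config_file deepspeed_config.yaml train_svd.py --pretrained_model_name_or_path checkpoints/svd_xt_1.1 --train_batch_size 4 --video_path youtube_vos/JPEGImages --shape_path youtube_vos/Annotations --annotation_path youtube_vos/meta.json --output_dir OUTPUT_PATH --finetune_unet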
Once the ControlNeXt model is trained, you can run inference with the following command:
python inference_svd.py --pretrained_model_name_or_path SVD_CHECKPOINTS_PATH --validation_control_images_folder INPUT_CONDITIONS_PATH --output_dir OUTPUT_PATH --checkpoint_dir CONTROLNEXT_PATH --ref_image_path REFERENCE_IMAGE_PATH
Note that, unlike the official implementation, you do not need to merge the DeepSpeed checkpoint by running an additional script. All you need to do is configure --checkpoint_dir.
Normally, a checkpoint saved with the DeepSpeed engine has a structure similar to the following:
checkpoints
├── latest
├── pytorch_model
│ ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
│ └── mp_rank_00_model_states.pt
├── random_states_0.pkl
├── scheduler.bin
└── zero_to_fp32.py
You need to set --checkpoint_dir checkpoints/pytorch_model/mp_rank_00_model_states.pt and let the script automatically convert the checkpoint to the pytorch_model.bin format, as sketched below.
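For reference, converting a DeepSpeed mp_rank_00_model_states.pt file into a plain pytorch_model.bin typically amounts to extracting the module state dict; the sketch below illustrates the general idea and is not the exact code used by inference_svd.py:

import torch

def convert_deepspeed_checkpoint(src, dst="pytorch_model.bin"):
    # DeepSpeed model-state files store the weights under the "module" key.
    states = torch.load(src, map_location="cpu")
    state_dict = states.get("module", states)
    # Save a plain state dict that load_state_dict() can consume.
    torch.save(state_dict, dst)

# Example:
# convert_deepspeed_checkpoint("checkpoints/pytorch_model/mp_rank_00_model_states.pt")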
You can refer to the following example command line:
python inference_svd.py --pretrained_model_name_or_path checkpoints/svd_xt_1.1 --validation_control_images_folder examples/frames/car --output_dir outputs/inference --checkpoint_dir checkpoints/pytorch_model/mp_rank_00_model_states.pt --ref_image_path examples/frames/car/00000.png