This codebase supports:
- Training and testing of Stage1 and Stage2 models
- Customized dataset, model, loss
- Simple config-based customization
- Scalable TPU training
- Built-in GAN loss, perceptual loss, and online FID evaluation
- Wandb logging w/ visualization
We strongly suggest using torch_xla >= 2.5.0 to avoid a possible bug when XLA_DISABLE_FUNCTIONALIZATION=1 is used for acceleration.
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
and
pip install -r requirements.txt
If you want to use the PSNR/SSIM metrics under eval/, you additionally need to run:
sudo apt-get install ffmpeg libsm6 libxext6 # required for opencv (cv2)
and pip install opencv-python
You also need to log in to wandb:
export PATH=/home/{USER_NAME}/.local/bin:$PATH && wandb login {KEY}
The wandb login only needs to be done once per pod/VM, but running it multiple times does no harm.
An illustration of the repo structure:
.
├── configs
│   └── xxx/xxx.yaml        # config files for training/testing
├── ckpt_gcs                # (optional) for storing checkpoints on GCS using GCSFuse
├── data                    # often soft-linked to PD
│   ├── imagenet            # (optional) for storing ImageNet data
│   └── xxx                 # (optional) for storing other datasets
├── rqvae                   # main codebase
│   ├── img_datasets        # dataset classes
│   └── models              # implementation of Stage1/Stage2 models
│       └── interfaces.py   # abstract classes for models
├── eval                    # evaluation scripts
│   ├── pytorch_fid/        # FID evaluation code for torch_xla
│   └── psnr_ssims.py       # PSNR/SSIM evaluation code
└── gpu_eval                # evaluation scripts for GPU
    └── fid/                # standard FID evaluation code (ADM suite)
The codebase is designed to be modular and extensible. The main pipeline consists of the following components:
Dataset
All customized dataset classes should inherit from torch.utils.data.Dataset and re-implement __getitem__ to return a LabeledImageData instance (defined in rqvae/img_datasets/interfaces.py). Optionally, you can also implement collate_fn to customize the batch collation behavior; if it is not defined, the default collation function rqvae.img_datasets.LabeledImageDatasetWrapper.default_collate_fn is used.
To actually use the dataset, you need to modify rqvae/img_datasets/__init__.py to include your dataset class (sorry for the stack of if/else statements, I will switch it to config-based soon). Likewise, you also need to modify rqvae/img_datasets/transforms.py to include your customized transforms (again, sorry for the stack of if/else statements).
LabeledImageData is a @dataclass with the following fields:
- img: torch.Tensor (optional)
- condition: Any (optional)
- img_path: str (optional)
- additional_attr: Optional[str] (optional)
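For illustration, a minimal custom dataset might look like the sketch below. The class name, folder layout, and transform are hypothetical; only LabeledImageData and the torch Dataset interface come from the description above.

```python
# Minimal sketch of a custom dataset; MyImageFolder, the root layout and the
# transform are hypothetical -- only LabeledImageData / Dataset come from the repo.
import os
from glob import glob

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from rqvae.img_datasets.interfaces import LabeledImageData

class MyImageFolder(Dataset):
    def __init__(self, root: str, split: str = "train"):
        self.paths = sorted(glob(os.path.join(root, split, "*.png")))
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),  # images in [0, 1]
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int) -> LabeledImageData:
        path = self.paths[idx]
        img = self.transform(Image.open(path).convert("RGB"))
        return LabeledImageData(img=img, condition=None, img_path=path)
```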
Stage1 Model
The main idea of the Stage1 Model is that it encodes the input into a latent representation and decodes it back into a reconstruction.
The Stage1 Model's behavior is defined (in rqvae/models/interfaces.py) as:
- forward: accepts a LabeledImageData as input and returns a Stage1ModelOutput in the forward pass (encoding then decoding)
- encode: accepts a LabeledImageData and returns a Stage1Encodings
- decode: accepts a Stage1Encodings and returns a Stage1ModelOutput
- compute_loss: accepts a Stage1ModelOutput (reconstruction) and a LabeledImageData (input) and returns a dict of losses
Stage1ModelOutput is a @dataclass with the following fields:
- xs_recon: torch.Tensor (reconstruction)
- additional_attr: dict (optional) for storing possible additional attributes
Stage1Encodings is a @dataclass with the following fields:
- zs: torch.Tensor (latent representation)
- additional_attr: dict (optional) for storing possible additional attributes, like mu and logvar in a VAE
The loss is a dict (sorry, there is no abstraction here yet, though there should be one) with three keys:
{
'loss_total': torch.Tensor, # total loss
'loss_recon': torch.Tensor, # reconstruction loss like L1
'loss_latent': torch.Tensor, # latent loss like KL divergence
}
You should always set loss_total to be the sum of loss_recon and loss_latent in compute_loss. The actual total loss used in training will be loss_total.
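As a rough sketch, a toy Stage1 model could wire these pieces together as follows. The encoder/decoder modules are placeholders, and the sketch assumes the Stage1Encodings / Stage1ModelOutput dataclasses live in rqvae/models/interfaces.py; in the repo you would inherit from the Stage1 interface there rather than from nn.Module.

```python
# Illustrative sketch only: a toy AE-style Stage1 model with placeholder modules.
import torch
from torch import nn

from rqvae.img_datasets.interfaces import LabeledImageData
from rqvae.models.interfaces import Stage1Encodings, Stage1ModelOutput

class ToyStage1(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 4, stride=4)            # placeholder encoder
        self.decoder = nn.ConvTranspose2d(8, 3, 4, stride=4)   # placeholder decoder

    def encode(self, data: LabeledImageData) -> Stage1Encodings:
        return Stage1Encodings(zs=self.encoder(data.img), additional_attr={})

    def decode(self, encodings: Stage1Encodings) -> Stage1ModelOutput:
        xs_recon = torch.sigmoid(self.decoder(encodings.zs))   # keep recon in [0, 1]
        return Stage1ModelOutput(xs_recon=xs_recon, additional_attr={})

    def forward(self, data: LabeledImageData) -> Stage1ModelOutput:
        return self.decode(self.encode(data))

    def compute_loss(self, output: Stage1ModelOutput, data: LabeledImageData) -> dict:
        loss_recon = torch.abs(output.xs_recon - data.img).mean()  # L1
        loss_latent = torch.zeros_like(loss_recon)                 # e.g. a KL term for a VAE
        return {
            "loss_total": loss_recon + loss_latent,  # always the sum of the two
            "loss_recon": loss_recon,
            "loss_latent": loss_latent,
        }
```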
Connector
The connector is a simple class that connects the Stage1 and Stage2 models. For example, you may want to reshape a 1D Stage1 latent to 2D for Stage2 and reshape it back for decoding. The connector is defined (in rqvae.models.interfaces.base_connector) as:
- forward: accepts a Union[Stage1Encodings, Stage2ModelOutput] and returns a Stage1Encodings. This function is called after the Stage1 Model's encode and before the Stage2 Model's forward.
- reverse: accepts a Union[Stage1Encodings, Stage2ModelOutput] and returns a Stage1Encodings. This function is called after the Stage2 Model's forward and before the Stage1 Model's decode.
By default we provide a base_connector that does nothing but return its input. Note that if the input is a Stage2ModelOutput, both forward and reverse will return a Stage2ModelOutput, which is technically invalid. So if you use base_connector, make sure your Stage1 Model can accept a Stage2ModelOutput as input.
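For example, a connector that flattens a 2D latent into a token sequence for Stage2 and restores the spatial shape before decoding might look roughly like the sketch below. The base_connector, Stage1Encodings, and Stage2ModelOutput names follow the descriptions above; the constructor arguments are assumptions.

```python
# Sketch of a reshape connector: 2D latent (B, C, H, W) <-> token sequence (B, H*W, C).
from typing import Union

from rqvae.models.interfaces import (Stage1Encodings, Stage2ModelOutput,
                                     base_connector)

class ReshapeConnector(base_connector):
    def __init__(self, height: int, width: int):
        super().__init__()
        self.height, self.width = height, width

    def forward(self, enc: Union[Stage1Encodings, Stage2ModelOutput]) -> Stage1Encodings:
        zs = enc.zs if isinstance(enc, Stage1Encodings) else enc.zs_pred
        zs = zs.flatten(2).transpose(1, 2)  # (B, H*W, C) for the Stage2 model
        return Stage1Encodings(zs=zs, additional_attr={})

    def reverse(self, enc: Union[Stage1Encodings, Stage2ModelOutput]) -> Stage1Encodings:
        zs = enc.zs if isinstance(enc, Stage1Encodings) else enc.zs_pred
        b, n, c = zs.shape
        zs = zs.transpose(1, 2).reshape(b, c, self.height, self.width)  # back to 2D for decode
        return Stage1Encodings(zs=zs, additional_attr={})
```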
Stage2 Model
The main idea of the Stage2 Model is that it takes the latent from Stage1 together with the input data, learns from them, and then outputs a latent to be decoded during generation.
The Stage2 Model's behavior is defined (in rqvae/models/interfaces.py) as:
- forward: accepts a Stage1Encodings and a LabeledImageData as input and returns a Stage2ModelOutput in the forward pass
- compute_loss: accepts a Stage1Encodings, a Stage2ModelOutput, and a LabeledImageData and returns a dict of losses. Only loss_total is required here.
- infer (generation): accepts a LabeledImageData (usually not needed, as we are doing noise-to-image generation) and returns a Stage2ModelOutput for decoding.
Stage2ModelOutput is a @dataclass with the following fields:
- zs_pred: torch.Tensor (generated latent)
- zs_degraded: torch.Tensor (degraded latent, e.g. noised), often not needed
- additional_attr: dict (optional) for storing possible additional attributes
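A toy denoising-style Stage2 model following this interface could look like the sketch below; the network, noise schedule, and latent shape are placeholders, not the repo's implementation.

```python
# Toy sketch of a Stage2 model; only the interface follows the description above.
import torch
from torch import nn

from rqvae.img_datasets.interfaces import LabeledImageData
from rqvae.models.interfaces import Stage1Encodings, Stage2ModelOutput

class ToyStage2(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.dim = dim
        self.net = nn.Conv2d(dim, dim, 3, padding=1)  # placeholder predictor

    def forward(self, enc: Stage1Encodings, data: LabeledImageData) -> Stage2ModelOutput:
        zs_degraded = enc.zs + torch.randn_like(enc.zs)  # toy "degradation" (noising)
        zs_pred = self.net(zs_degraded)                  # predict the clean latent
        return Stage2ModelOutput(zs_pred=zs_pred, zs_degraded=zs_degraded, additional_attr={})

    def compute_loss(self, enc: Stage1Encodings, out: Stage2ModelOutput,
                     data: LabeledImageData) -> dict:
        # only loss_total is required for Stage2
        return {"loss_total": torch.nn.functional.mse_loss(out.zs_pred, enc.zs)}

    @torch.no_grad()
    def infer(self, data: LabeledImageData) -> Stage2ModelOutput:
        zs = torch.randn(data.img.shape[0], self.dim, 16, 16)  # noise-to-latent generation
        return Stage2ModelOutput(zs_pred=self.net(zs), zs_degraded=zs, additional_attr={})
```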
Stage1 Training
The basic idea of Stage1 training is an autoencoder-style recipe: feed data into the Stage1 Model, get a loss from the model's loss function, and additionally compute an LPIPS loss and a GAN loss. In each iteration we update the model's optimizer once and the discriminator's optimizer once. The training loop is defined in rqvae/trainers/stage1_trainer.py.
Pipeline:
- Call the dataloader to get a batch of data: LabeledImageData
- Call the Stage1 Model's forward to get a Stage1ModelOutput
- Call the Stage1 Model's compute_loss with the Stage1ModelOutput and LabeledImageData to get the losses (loss_total, loss_recon, loss_latent)
- Extract xs_recon from the Stage1ModelOutput and img from the LabeledImageData for the LPIPS loss and GAN loss
- Use loss_total + loss_lpips * lpips_weight + loss_gan * gan_weight as the total loss for the Stage1 Model
- Update the optimizer and scheduler for the Stage1 Model
- Do another forward pass to get a Stage1ModelOutput (for the GAN)
- Calculate the GAN loss for the discriminator
- Update the optimizer for the discriminator
Note that we assume xs_recon and img are in the range [0, 1] so that the GAN and LPIPS losses are computed correctly; you should do the normalization and de-normalization inside the Stage1 model.
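The pipeline above can be summarized in the hedged pseudo-sketch below. The names (lpips_fn, disc, gan_g_loss, gan_d_loss, the weights) are hypothetical stand-ins; the actual implementation, including XLA-specific optimizer handling, lives in rqvae/trainers/stage1_trainer.py.

```python
# Pseudo-sketch of one Stage1 training iteration (illustrative, not the repo's code).
def stage1_step(model, disc, lpips_fn, gan_g_loss, gan_d_loss,
                opt_model, opt_disc, batch, lpips_weight, gan_weight):
    # autoencoder / generator update
    out = model(batch)                                    # Stage1ModelOutput
    losses = model.compute_loss(out, batch)               # loss_total / loss_recon / loss_latent
    loss_lpips = lpips_fn(out.xs_recon, batch.img).mean()
    loss_gan = gan_g_loss(disc(out.xs_recon))
    total = losses["loss_total"] + lpips_weight * loss_lpips + gan_weight * loss_gan
    opt_model.zero_grad(); total.backward(); opt_model.step()

    # discriminator update on a fresh forward pass
    out = model(batch)
    loss_d = gan_d_loss(disc(batch.img), disc(out.xs_recon.detach()))
    opt_disc.zero_grad(); loss_d.backward(); opt_disc.step()
    return total.detach(), loss_d.detach()
```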
Stage2 Training
We basically follow a diffusion-style training recipe for the Stage2 Model. In every epoch we run an encoding pass to get the latents, then feed the Stage2 Model with the latent and the input data to get a prediction. We then calculate the loss and update the optimizer. The training loop is defined in rqvae/trainers/stage2_trainer.py.
Pipeline:
- Call the dataloader to get a batch of data: LabeledImageData
- Call the Stage1 Model's encode to get a Stage1Encodings
- Call the Connector's forward to get a Stage1Encodings for the Stage2 Model
- Call the Stage2 Model's forward with the Stage1Encodings and LabeledImageData to get a Stage2ModelOutput
- Call the Stage2 Model's compute_loss with the Stage1Encodings, Stage2ModelOutput, and LabeledImageData to get the loss (loss_total)
- Update the optimizer and scheduler for the Stage2 Model
- (Call inference every 0.5 epoch for visualization in wandb)
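A hedged pseudo-sketch of one Stage2 training iteration, mirroring the pipeline above (whether Stage1 is frozen under no_grad here is an assumption; rqvae/trainers/stage2_trainer.py is the actual reference):

```python
# Pseudo-sketch of one Stage2 training iteration (illustrative, not the repo's code).
import torch

def stage2_step(stage1, connector, stage2, opt, sched, batch):
    with torch.no_grad():                   # assumption: Stage1 is frozen during Stage2 training
        enc = stage1.encode(batch)          # Stage1Encodings
    enc = connector.forward(enc)            # reshape the latent for the Stage2 model
    out = stage2(enc, batch)                # Stage2ModelOutput
    losses = stage2.compute_loss(enc, out, batch)
    opt.zero_grad()
    losses["loss_total"].backward()
    opt.step()
    sched.step()
    return losses["loss_total"].detach()
```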
Online FID Eval
We break down the FID eval into:
- InceptionV3 feature extraction, which is done in parallel on the pod
- FID calculation, which is done on the master node after an all_reduce of the features
To use the online FID eval, you first need to provide an npz file (via --fid_gt_act_path in the training script) containing the InceptionV3 features of the ground-truth data (how to? TBD). Then set --do_online_eval to True in the training script. The FID will be logged to wandb at every eval step, and the extracted features will be saved.
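For reference, FID is the Frechet distance between two Gaussians fitted to the InceptionV3 activations. The sketch below (not the repo's implementation) shows the standard computation from two activation arrays, e.g. the stored ground-truth features and features gathered from generated images:

```python
# Minimal sketch: Frechet distance between two sets of (N, 2048) InceptionV3 activations.
import numpy as np
from scipy import linalg

def frechet_distance(act1: np.ndarray, act2: np.ndarray) -> float:
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):   # numerical noise can introduce tiny imaginary parts
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```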
Stage1 Inference
We use the test set of the dataset as the input for Stage1 inference and call forward to get the reconstruction. We then call get_recon_imgs on the images and reconstructions to get the final visualization (usually you just set it to return the input, or simply do a clipping).
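A hypothetical get_recon_imgs doing nothing more than the clipping mentioned above might look like this (assuming the images and reconstructions are already in [0, 1]):

```python
# Hypothetical get_recon_imgs: clamp the reconstruction into the valid image range.
import torch

def get_recon_imgs(xs: torch.Tensor, xs_recon: torch.Tensor):
    return xs, xs_recon.clamp(0.0, 1.0)
```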
Stage2 Inference
TBD
Configurations
The config file is a yaml file in configs/ (or any other directory you want). The config file should contain the following fields:
TBD
Currently, please see configs/imagenet256/stage1/klvae/f_16.yaml for a Stage1 example and configs/imagenet256/stage2/pixel/SID.yaml for a Stage2 example.
For training we need all the fields, but for inference we only need the definitions of the dataset and model; additional fields are simply ignored.
Define your Stage1, Stage2 models and put them in rqvae/models/
. Define your connectors and put them in rqvae/models/connectors.py
. Then define a config for your training recipe. You're then all set!
NOTE:
- You should change the XLACACHE_PATH in the script to a local path you like (on the pod/VM). This is temporary storage for XLA cache files and should not be uploaded to GCS or PD. Using the cache will significantly speed up the XLA warm-up time after the first run.
- export XLA_DISABLE_FUNCTIONALIZATION=1 is set by default for a 30-70% speedup in training (due to some optimizer bugs). If you want to disable it, simply comment out the line in the script. You MUST use torch_xla >= 2.5.0 to avoid a possible bug when XLA_DISABLE_FUNCTIONALIZATION=1 is used.
To train your Stage1 Model:
./train_stage1.sh [wand_project_name] [save_directory] [config_path] [world_size] [wandb_run_name] [resume_ckpt_path(optional)] [resume_wandb_id(optional)]
This will save checkpoints to save_directory/wand_project_name/wandb_run_name_{TIME}/ and log to wandb under the name wandb_run_name_{TIME} in the project wand_project_name. By default, --do_online_eval is set to True for Stage1 training, so run ./train_stage1woeval.sh with the same parameters if you don't want to do online FID eval.
To run inference with the Stage1 Model (the input will be the test set):
./recon_stage1.sh [save_directory] [ckpt_path] [config] [world_size]
All the reconstructions will be saved directly under save_directory/.
To train your Stage2 Model:
TBD
To check if the Stage1 model is properly defined, you can run:
python sanity_check/stage1_connect.py [config_path] [image_shape]
image_shape is a tuple giving the input shape. If not set, it defaults to (3, 256, 256). If the input is 2D (H, W), it will be reshaped to 3D as (3, H, W).
This will call the Stage1 Model with a predefined image (resized to image_shape
) and check every function in the model. If it doesn't raise any error, then your model is properly defined.
Likewise, call:
sanity_check/stage2_forward.py [config_path] [image_shape]
to check the Stage2 Model.