Skip to content

Commit

Permalink
Updated readme
Browse files Browse the repository at this point in the history
  • Loading branch information
kpandey008 committed Aug 8, 2022
1 parent a4c3898 commit fdd83bf
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 118 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Please refer to the scripts provided in the table corresponding to some inferenc
| Generate reconstructions from DiffuseVAE | `scripts/test_recons_ddpm.sh` |
| Interpolate in the VAE/DDPM latent space using DiffuseVAE | `scripts/interpolate.sh` |

For computing the evaluation metrics (FID, IS etc.), we use the [torch-fidelity](https://github.com/toshas/torch-fidelity) package. See `scripts/fid.sh` for some sample usage examples.


## Pretrained checkpoints
All pretrained checkpoints have been organized by dataset and can be accessed [here](https://drive.google.com/drive/folders/1GzIh75NnpgPa4A1hSb_viPowuaSHnL7R?usp=sharing).
Expand Down
2 changes: 1 addition & 1 deletion scripts/expde.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
# python main/expde.py fit-gmm ~/afhq256_dog_latents/latents_afhq.npy --save-path '/data1/kushagrap20/afhq256_dog_latents/gmm_z/' --n-components 100

# Fit GMM CelebA-HQ-256
python main/expde.py fit-gmm ~/celebahq_latents/latents_celebahq.npy --save-path '/data1/kushagrap20/celebahq_latents/gmm_z/' --n-components 150
# python main/expde.py fit-gmm ~/celebahq_latents/latents_celebahq.npy --save-path '/data1/kushagrap20/celebahq_latents/gmm_z/' --n-components 150
46 changes: 23 additions & 23 deletions scripts/interpolate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,26 @@
# dataset.vae.evaluation.expde_model_path=\'/data1/kushagrap20/cmhq128_latents/gmm_z/gmm_100.joblib\' \
# dataset.ddpm.data.ddpm_latent_path=\'/data1/kushagrap20/ddpm_latents.pt\' \

python main/eval/ddpm/interpolate_ddpm.py +dataset=celebamaskhq128/test \
dataset.ddpm.data.norm=True \
dataset.ddpm.model.attn_resolutions=\'16,\' \
dataset.ddpm.model.dropout=0.1 \
dataset.ddpm.model.n_residual=2 \
dataset.ddpm.model.dim_mults=\'1,2,2,3,4\' \
dataset.ddpm.model.n_heads=8 \
dataset.ddpm.evaluation.guidance_weight=0.0 \
dataset.ddpm.evaluation.seed=1 \
dataset.ddpm.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cmhq/ddpmv2-cmhq128_rework_form1_18thJune_sota_nheads=8_dropout=0.1-epoch=999-loss=0.0066.ckpt\' \
dataset.ddpm.evaluation.type='form1' \
dataset.ddpm.evaluation.resample_strategy='truncated' \
dataset.ddpm.evaluation.skip_strategy='quad' \
dataset.ddpm.evaluation.sample_method='ddpm' \
dataset.ddpm.evaluation.sample_from='target' \
dataset.ddpm.evaluation.temp=1.0 \
dataset.ddpm.evaluation.save_path=\'/data1/kushagrap20/diffusevae_rework_experiments/linear_interpolate_ddpm/form1_fixed/\' \
dataset.ddpm.evaluation.z_cond=False \
dataset.ddpm.evaluation.n_steps=1000 \
dataset.ddpm.evaluation.save_vae=True \
dataset.ddpm.evaluation.workers=1 \
dataset.vae.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cmhq/vae-cmhq128_alpha=1.0-epoch=499-train_loss=0.0000.ckpt\' \
dataset.vae.evaluation.expde_model_path=\'/data1/kushagrap20/cmhq128_latents/gmm_z/gmm_100.joblib\' \
# python main/eval/ddpm/interpolate_ddpm.py +dataset=celebamaskhq128/test \
# dataset.ddpm.data.norm=True \
# dataset.ddpm.model.attn_resolutions=\'16,\' \
# dataset.ddpm.model.dropout=0.1 \
# dataset.ddpm.model.n_residual=2 \
# dataset.ddpm.model.dim_mults=\'1,2,2,3,4\' \
# dataset.ddpm.model.n_heads=8 \
# dataset.ddpm.evaluation.guidance_weight=0.0 \
# dataset.ddpm.evaluation.seed=1 \
# dataset.ddpm.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cmhq/ddpmv2-cmhq128_rework_form1_18thJune_sota_nheads=8_dropout=0.1-epoch=999-loss=0.0066.ckpt\' \
# dataset.ddpm.evaluation.type='form1' \
# dataset.ddpm.evaluation.resample_strategy='truncated' \
# dataset.ddpm.evaluation.skip_strategy='quad' \
# dataset.ddpm.evaluation.sample_method='ddpm' \
# dataset.ddpm.evaluation.sample_from='target' \
# dataset.ddpm.evaluation.temp=1.0 \
# dataset.ddpm.evaluation.save_path=\'/data1/kushagrap20/diffusevae_rework_experiments/linear_interpolate_ddpm/form1_fixed/\' \
# dataset.ddpm.evaluation.z_cond=False \
# dataset.ddpm.evaluation.n_steps=1000 \
# dataset.ddpm.evaluation.save_vae=True \
# dataset.ddpm.evaluation.workers=1 \
# dataset.vae.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cmhq/vae-cmhq128_alpha=1.0-epoch=499-train_loss=0.0000.ckpt\' \
# dataset.vae.evaluation.expde_model_path=\'/data1/kushagrap20/cmhq128_latents/gmm_z/gmm_100.joblib\' \
32 changes: 16 additions & 16 deletions scripts/test_ae.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
# ~/vae_celeba64_alpha\=1.0/checkpoints/vae-celeba64_alpha\=1.0-epoch\=245-train_loss\=0.0000.ckpt \
# ~/datasets/img_align_celeba/

python main/test.py reconstruct --device gpu:0 \
--dataset ffhq \
--image-size 128 \
--num-samples 64 \
--save-path ~/vae_samples_ffhq128_deletem_recons/ \
--write-mode image \
/data1/kushagrap20/vae_ffhq128_11thJune_alpha\=1.0/checkpoints/vae-ffhq128_11thJune_alpha\=1.0-epoch\=496-train_loss\=0.0000.ckpt \
~/datasets/ffhq/
# python main/test.py reconstruct --device gpu:0 \
# --dataset ffhq \
# --image-size 128 \
# --num-samples 64 \
# --save-path ~/vae_samples_ffhq128_deletem_recons/ \
# --write-mode image \
# /data1/kushagrap20/vae_ffhq128_11thJune_alpha\=1.0/checkpoints/vae-ffhq128_11thJune_alpha\=1.0-epoch\=496-train_loss\=0.0000.ckpt \
# ~/datasets/ffhq/

# python main/test.py sample --device gpu:0 \
# --image-size 32 \
Expand All @@ -24,14 +24,14 @@ python main/test.py reconstruct --device gpu:0 \
# 512 \
# /data1/kushagrap20/checkpoints/cifar10/vae-cifar10-epoch=500-train_loss=0.00.ckpt \

python main/test.py sample --device gpu:0 \
--image-size 128 \
--seed 0 \
--num-samples 64 \
--save-path ~/vae_samples_ffhq128_deletem/ \
--write-mode image \
1024 \
/data1/kushagrap20/vae_ffhq128_11thJune_alpha\=1.0/checkpoints/vae-ffhq128_11thJune_alpha\=1.0-epoch\=496-train_loss\=0.0000.ckpt \
# python main/test.py sample --device gpu:0 \
# --image-size 128 \
# --seed 0 \
# --num-samples 64 \
# --save-path ~/vae_samples_ffhq128_deletem/ \
# --write-mode image \
# 1024 \
# /data1/kushagrap20/vae_ffhq128_11thJune_alpha\=1.0/checkpoints/vae-ffhq128_11thJune_alpha\=1.0-epoch\=496-train_loss\=0.0000.ckpt \


# python main/test.py reconstruct --device gpu:0 \
Expand Down
60 changes: 30 additions & 30 deletions scripts/test_recons_ddpm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,33 @@
# dataset.ddpm.evaluation.workers=1 \
# dataset.vae.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/celebahq128/vae-cmhq128_alpha=1.0-epoch=499-train_loss=0.0000.ckpt\'

python main/eval/ddpm/generate_recons.py +dataset=cifar10/test \
dataset.ddpm.data.root='/data1/kushagrap20/datasets/' \
dataset.ddpm.data.name='cifar10' \
dataset.ddpm.data.norm=True \
dataset.ddpm.data.hflip=False \
dataset.ddpm.model.attn_resolutions=\'16,\' \
dataset.ddpm.model.dropout=0.3 \
dataset.ddpm.model.n_residual=2 \
dataset.ddpm.model.dim_mults=\'1,2,2,2\' \
dataset.ddpm.model.n_heads=8 \
dataset.ddpm.evaluation.guidance_weight=0.0 \
dataset.ddpm.evaluation.seed=0 \
dataset.ddpm.evaluation.sample_prefix='gpu_0' \
dataset.ddpm.evaluation.device=\'gpu:0\' \
dataset.ddpm.evaluation.save_mode='image' \
dataset.ddpm.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cifar10/diffusevae_cifar10_rework_form1__7thJune_sota_nheads=8_dropout=0.3/checkpoints/ddpmv2-cifar10_rework_form1__7thJune_sota_nheads=8_dropout=0.3-epoch=2850-loss=0.0299.ckpt\' \
dataset.ddpm.evaluation.type='form1' \
dataset.ddpm.evaluation.resample_strategy='truncated' \
dataset.ddpm.evaluation.skip_strategy='quad' \
dataset.ddpm.evaluation.sample_method='ddpm' \
dataset.ddpm.evaluation.sample_from='target' \
dataset.ddpm.evaluation.temp=1.0 \
dataset.ddpm.evaluation.batch_size=16 \
dataset.ddpm.evaluation.save_path=\'/data1/kushagrap20/ddpm_cifar10_recons_superres/\' \
dataset.ddpm.evaluation.z_cond=False \
dataset.ddpm.evaluation.n_samples=1000 \
dataset.ddpm.evaluation.n_steps=1000 \
dataset.ddpm.evaluation.save_vae=True \
dataset.ddpm.evaluation.workers=1 \
dataset.vae.evaluation.chkpt_path=\'/data1/kushagrap20/checkpoints/cifar10/vae-cifar10-epoch=500-train_loss=0.00.ckpt\'
# python main/eval/ddpm/generate_recons.py +dataset=cifar10/test \
# dataset.ddpm.data.root='/data1/kushagrap20/datasets/' \
# dataset.ddpm.data.name='cifar10' \
# dataset.ddpm.data.norm=True \
# dataset.ddpm.data.hflip=False \
# dataset.ddpm.model.attn_resolutions=\'16,\' \
# dataset.ddpm.model.dropout=0.3 \
# dataset.ddpm.model.n_residual=2 \
# dataset.ddpm.model.dim_mults=\'1,2,2,2\' \
# dataset.ddpm.model.n_heads=8 \
# dataset.ddpm.evaluation.guidance_weight=0.0 \
# dataset.ddpm.evaluation.seed=0 \
# dataset.ddpm.evaluation.sample_prefix='gpu_0' \
# dataset.ddpm.evaluation.device=\'gpu:0\' \
# dataset.ddpm.evaluation.save_mode='image' \
# dataset.ddpm.evaluation.chkpt_path=\'/data1/kushagrap20/diffusevae_rework/cifar10/diffusevae_cifar10_rework_form1__7thJune_sota_nheads=8_dropout=0.3/checkpoints/ddpmv2-cifar10_rework_form1__7thJune_sota_nheads=8_dropout=0.3-epoch=2850-loss=0.0299.ckpt\' \
# dataset.ddpm.evaluation.type='form1' \
# dataset.ddpm.evaluation.resample_strategy='truncated' \
# dataset.ddpm.evaluation.skip_strategy='quad' \
# dataset.ddpm.evaluation.sample_method='ddpm' \
# dataset.ddpm.evaluation.sample_from='target' \
# dataset.ddpm.evaluation.temp=1.0 \
# dataset.ddpm.evaluation.batch_size=16 \
# dataset.ddpm.evaluation.save_path=\'/data1/kushagrap20/ddpm_cifar10_recons_superres/\' \
# dataset.ddpm.evaluation.z_cond=False \
# dataset.ddpm.evaluation.n_samples=1000 \
# dataset.ddpm.evaluation.n_steps=1000 \
# dataset.ddpm.evaluation.save_vae=True \
# dataset.ddpm.evaluation.workers=1 \
# dataset.vae.evaluation.chkpt_path=\'/data1/kushagrap20/checkpoints/cifar10/vae-cifar10-epoch=500-train_loss=0.00.ckpt\'
96 changes: 48 additions & 48 deletions scripts/train_ae.sh
Original file line number Diff line number Diff line change
@@ -1,52 +1,52 @@
# CelebAMaskHQ training
python main/train_ae.py +dataset=celebamaskhq128/train \
dataset.vae.data.root='/data1/kushagrap20/datasets/CelebAMask-HQ/' \
dataset.vae.data.name='celebamaskhq' \
dataset.vae.data.hflip=True \
dataset.vae.training.batch_size=42 \
dataset.vae.training.log_step=50 \
dataset.vae.training.epochs=500 \
dataset.vae.training.device=\'gpu:0,1,3\' \
dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_cmhq128_alpha=1.0/\' \
dataset.vae.training.workers=2 \
dataset.vae.training.chkpt_prefix=\'cmhq128_alpha=1.0\' \
dataset.vae.training.alpha=1.0
# # CelebAMaskHQ training
# python main/train_ae.py +dataset=celebamaskhq128/train \
# dataset.vae.data.root='/data1/kushagrap20/datasets/CelebAMask-HQ/' \
# dataset.vae.data.name='celebamaskhq' \
# dataset.vae.data.hflip=True \
# dataset.vae.training.batch_size=42 \
# dataset.vae.training.log_step=50 \
# dataset.vae.training.epochs=500 \
# dataset.vae.training.device=\'gpu:0,1,3\' \
# dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_cmhq128_alpha=1.0/\' \
# dataset.vae.training.workers=2 \
# dataset.vae.training.chkpt_prefix=\'cmhq128_alpha=1.0\' \
# dataset.vae.training.alpha=1.0

# FFHQ 128 training
python main/train_ae.py +dataset=ffhq/train \
dataset.vae.data.root='/data1/kushagrap20/datasets/ffhq/' \
dataset.vae.data.name='ffhq' \
dataset.vae.data.hflip=True \
dataset.vae.training.batch_size=32 \
dataset.vae.training.log_step=50 \
dataset.vae.training.epochs=1500 \
dataset.vae.training.device=\'gpu:0,1,2,3\' \
dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_ffhq128_11thJune_alpha=1.0/\' \
dataset.vae.training.workers=2 \
dataset.vae.training.chkpt_prefix=\'ffhq128_11thJune_alpha=1.0\' \
dataset.vae.training.alpha=1.0
# # FFHQ 128 training
# python main/train_ae.py +dataset=ffhq/train \
# dataset.vae.data.root='/data1/kushagrap20/datasets/ffhq/' \
# dataset.vae.data.name='ffhq' \
# dataset.vae.data.hflip=True \
# dataset.vae.training.batch_size=32 \
# dataset.vae.training.log_step=50 \
# dataset.vae.training.epochs=1500 \
# dataset.vae.training.device=\'gpu:0,1,2,3\' \
# dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_ffhq128_11thJune_alpha=1.0/\' \
# dataset.vae.training.workers=2 \
# dataset.vae.training.chkpt_prefix=\'ffhq128_11thJune_alpha=1.0\' \
# dataset.vae.training.alpha=1.0

# AFHQv2 training
python main/train_ae.py +dataset=afhq256/train \
dataset.vae.data.root='/data1/kushagrap20/datasets/afhq_v2/' \
dataset.vae.data.name='afhq' \
dataset.vae.training.batch_size=8 \
dataset.vae.training.epochs=500 \
dataset.vae.training.device=\'gpu:0,1,2,3\' \
dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_afhq256_10thJuly_alpha=1.0/\' \
dataset.vae.training.workers=2 \
dataset.vae.training.chkpt_prefix=\'afhq256_10thJuly_alpha=1.0\' \
dataset.vae.training.alpha=1.0
# # AFHQv2 training
# python main/train_ae.py +dataset=afhq256/train \
# dataset.vae.data.root='/data1/kushagrap20/datasets/afhq_v2/' \
# dataset.vae.data.name='afhq' \
# dataset.vae.training.batch_size=8 \
# dataset.vae.training.epochs=500 \
# dataset.vae.training.device=\'gpu:0,1,2,3\' \
# dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_afhq256_10thJuly_alpha=1.0/\' \
# dataset.vae.training.workers=2 \
# dataset.vae.training.chkpt_prefix=\'afhq256_10thJuly_alpha=1.0\' \
# dataset.vae.training.alpha=1.0


# CelebA training
python main/train_ae.py +dataset=celeba64/train \
dataset.vae.data.root='/data1/kushagrap20/datasets/img_align_celeba/' \
dataset.vae.data.name='celeba' \
dataset.vae.training.batch_size=32 \
dataset.vae.training.epochs=1500 \
dataset.vae.training.device=\'gpu:0,1,2,3\' \
dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_celeba64_alpha=1.0/\' \
dataset.vae.training.workers=4 \
dataset.vae.training.chkpt_prefix=\'celeba64_alpha=1.0\' \
dataset.vae.training.alpha=1.0
# # CelebA training
# python main/train_ae.py +dataset=celeba64/train \
# dataset.vae.data.root='/data1/kushagrap20/datasets/img_align_celeba/' \
# dataset.vae.data.name='celeba' \
# dataset.vae.training.batch_size=32 \
# dataset.vae.training.epochs=1500 \
# dataset.vae.training.device=\'gpu:0,1,2,3\' \
# dataset.vae.training.results_dir=\'/data1/kushagrap20/vae_celeba64_alpha=1.0/\' \
# dataset.vae.training.workers=4 \
# dataset.vae.training.chkpt_prefix=\'celeba64_alpha=1.0\' \
# dataset.vae.training.alpha=1.0

0 comments on commit fdd83bf

Please sign in to comment.