
Commit

Added docs
kpandey008 committed Aug 8, 2022
1 parent a252673 commit 290a87b
Showing 2 changed files with 55 additions and 55 deletions.
44 changes: 22 additions & 22 deletions main/configs/dataset/cifar10/test.yaml
@@ -1,15 +1,15 @@
# DDPM config used for DDPM training
ddpm:
data:
data: # Most of these fields have the same meaning as in `cifar10/train.yaml`
root: ???
name: "cifar10"
image_size: 32
hflip: True
n_channels: 3
norm: True
ddpm_latent_path: ""
ddpm_latent_path: "" # If sharing DDPM latents across DiffuseVAE samples, path to a .pt tensor containing the latent codes

model:
model: # Most of these fields have the same meaning as in `cifar10/train.yaml`
dim : 128
attn_resolutions: "16,"
n_residual: 2
@@ -21,27 +21,27 @@ ddpm:
n_timesteps: 1000

evaluation:
chkpt_path: ???
save_path: ???
z_cond: False
z_dim: 512
guidance_weight: 0.0
type: 'form1'
resample_strategy: "spaced"
skip_strategy: "uniform"
sample_method: "ddpm"
sample_from: "target"
seed: 0
device: "gpu:0"
chkpt_path: ??? # DiffuseVAE checkpoint path
save_path: ??? # Path to write samples to (automatically creates directories if needed)
z_cond: False # Whether to condition the UNet on the VAE latent
z_dim: 512 # Dimensionality of the VAE latent
guidance_weight: 0.0 # Guidance weight during sampling when using classifier-free guidance
type: 'form1' # DiffuseVAE type. One of ['form1', 'form2', 'uncond']. `uncond` is the baseline DDPM
resample_strategy: "spaced" # Whether to use spaced or truncated sampling. Use 'truncated' when sampling for the full 1000 steps
skip_strategy: "uniform" # Skipping strategy to use if `resample_strategy=spaced`. Can be ['uniform', 'quad'] as in DDIM
sample_method: "ddpm" # Sampling backend. Can be ['ddim', 'ddpm']
sample_from: "target" # Whether to sample from the EMA or non-EMA model. Can be ['source', 'target']
seed: 0 # Random seed used during sampling
device: "gpu:0" # Device. Uses TPU/CPU if set to `tpu` or `cpu`. For GPU, use gpu:<comma-separated id list>. Ex: gpu:0,1 runs only on GPUs 0 and 1
n_samples: 50000
n_steps: 1000
n_steps: 1000 # Number of reverse process steps to use during sampling. Typically [0-100] for DDIM and T=1000 for DDPM
workers: 2
batch_size: 8
save_vae: False
variance: "fixedlarge"
sample_prefix: ""
temp: 1.0
save_mode: image
batch_size: 8 # Batch size during sampling (per GPU)
save_vae: False # Whether to save VAE samples along with final samples. Useful to visualize the generator-refiner framework in action!
variance: "fixedlarge" # DDPM variance to use when using DDPM. Can be ['fixedsmall', 'fixedlarge']
sample_prefix: "" # Prefix used in naming when saving samples to disk
temp: 1.0 # Temperature scaling factor for the DDPM latents
save_mode: image # Whether to save samples as .png or .npy. One of ['image', 'numpy']

interpolation:
n_steps: 10
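The `???` values in `test.yaml` are OmegaConf-style mandatory-value markers, so fields such as `ddpm.data.root`, `ddpm.evaluation.chkpt_path`, and `ddpm.evaluation.save_path` must be supplied before sampling. Below is a minimal sketch of how such a config could be loaded and overridden, assuming OmegaConf is used to read it; the paths and override values are placeholders, and the repo's actual entry point may wire this differently.

```python
# Sketch only: load the sampling config and fill in the mandatory (???) fields.
# Assumes OmegaConf; paths and values below are placeholders, not the repo's defaults.
from omegaconf import OmegaConf

cfg = OmegaConf.load("main/configs/dataset/cifar10/test.yaml")

overrides = OmegaConf.from_dotlist([
    "ddpm.data.root=/data/cifar10",                              # dataset root (placeholder)
    "ddpm.evaluation.chkpt_path=/ckpts/diffusevae_cifar10.ckpt", # DiffuseVAE checkpoint (placeholder)
    "ddpm.evaluation.save_path=/samples/cifar10",                # where samples are written (placeholder)
    "ddpm.evaluation.sample_method=ddim",                        # one of ['ddim', 'ddpm']
    "ddpm.evaluation.n_steps=100",                               # fewer reverse steps when using DDIM
])
cfg = OmegaConf.merge(cfg, overrides)

# Reading an unset ??? field would raise omegaconf.errors.MissingMandatoryValue.
print(OmegaConf.to_yaml(cfg.ddpm.evaluation))
```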
66 changes: 33 additions & 33 deletions main/configs/dataset/cifar10/train.yaml
@@ -1,14 +1,14 @@
# DDPM config used for DDPM training
ddpm:
data:
root: ???
name: "cifar10"
image_size: 32
hflip: True
n_channels: 3
norm: True
root: ??? # Dataset root
name: "cifar10" # Dataset name (check main/util.py `get_dataset` for a registry)
image_size: 32 # Image resolution
hflip: True # Whether to use horizontal flip
n_channels: 3 # Num input channels
norm: True # Whether to scale data between [-1, 1]

model:
model: # UNet-specific params. Check the DDPM implementation for details on these
dim : 128
attn_resolutions: "16,"
n_residual: 2
@@ -20,29 +20,29 @@ ddpm:
n_timesteps: 1000

training:
seed: 0
fp16: False
use_ema: True
z_cond: False
z_dim: 512
type: 'form1'
ema_decay: 0.9999
batch_size: 32
epochs: 5000
log_step: 1
device: "gpu:0"
chkpt_interval: 1
optimizer: "Adam"
lr: 2e-4
restore_path: ""
vae_chkpt_path: ???
results_dir: ???
workers: 2
grad_clip: 1.0
n_anneal_steps: 5000
loss: "l2"
chkpt_prefix: ""
cfd_rate: 0.0
seed: 0 # Random seed
fp16: False # Whether to use fp16
use_ema: True # Whether to use EMA weights (improves sample quality)
z_cond: False # Whether to condition the UNet on the VAE latent
z_dim: 512 # Dimensionality of the VAE latent
type: 'form1' # DiffuseVAE type. One of ['form1', 'form2', 'uncond']. `uncond` is baseline DDPM
ema_decay: 0.9999 # EMA decay rate
batch_size: 32 # Training batch size (per GPU, per TPU core if using distributed training)
epochs: 5000 # Max number of epochs
log_step: 1 # Logging interval
device: "gpu:0" # Device. Uses TPU/CPU if set to `tpu` or `cpu`. For GPU, use gpu:<comma-separated id list>. Ex: gpu:0,1 runs only on GPUs 0 and 1
chkpt_interval: 1 # Number of epochs between two checkpoints
optimizer: "Adam" # Optimizer
lr: 2e-4 # Learning rate
restore_path: "" # Checkpoint restore path
vae_chkpt_path: ??? # VAE checkpoint path. Useful when using form1 or form2
results_dir: ??? # Directory to store checkpoints in
workers: 2 # Number of dataloader workers
grad_clip: 1.0 # Gradient clipping threshold
n_anneal_steps: 5000 # Number of warmup steps
loss: "l2" # Diffusion loss type. One of ['l2', 'l1']
chkpt_prefix: "" # Prefix appended to the checkpoint name
cfd_rate: 0.0 # Conditioning-signal dropout rate as in classifier-free guidance

# VAE config used for VAE training
vae:
@@ -53,13 +53,13 @@ vae:
n_channels: 3
hflip: False

model:
model: # VAE-specific params. See `main/models/vae.py` for details on these
enc_block_config : "32x7,32d2,32t16,16x4,16d2,16t8,8x4,8d2,8t4,4x3,4d4,4t1,1x3"
enc_channel_config: "32:64,16:128,8:256,4:256,1:512"
dec_block_config: "1x1,1u4,1t4,4x2,4u2,4t8,8x3,8u2,8t16,16x7,16u2,16t32,32x15"
dec_channel_config: "32:64,16:128,8:256,4:256,1:512"

training:
training: # Most of these are the same as explained above, but for VAE training
seed: 0
fp16: False
batch_size: 128
@@ -73,4 +73,4 @@ vae:
results_dir: ???
workers: 2
chkpt_prefix: ""
alpha: 1.0
alpha: 1.0 # The beta value as in beta-VAE. I know the param name might be misleading :(
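The `device` comments above describe a small string format: `cpu`, `tpu`, or `gpu:<comma-separated id list>` (e.g. `gpu:0,1`). Here is a hypothetical helper that parses that string the way the comments describe; it is illustrative only and not the repo's actual implementation.

```python
# Hypothetical parser for the `device` strings described in the comments above
# ("cpu", "tpu", or "gpu:<comma-separated id list>"); not the repo's actual code.
from typing import List, Tuple

def parse_device(device: str) -> Tuple[str, List[int]]:
    """Return (backend, gpu_ids); gpu_ids is empty for cpu/tpu."""
    if device in ("cpu", "tpu"):
        return device, []
    if device.startswith("gpu:"):
        ids = [int(i) for i in device.split(":", 1)[1].split(",") if i]
        return "gpu", ids
    raise ValueError(f"Unrecognized device string: {device!r}")

assert parse_device("gpu:0,1") == ("gpu", [0, 1])
assert parse_device("cpu") == ("cpu", [])
```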
