Skip to content

Commit

Permalink
AFHQv2 updates
Browse files Browse the repository at this point in the history
  • Loading branch information
kpandey008 committed Aug 2, 2022
1 parent bbb0942 commit cb08607
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 18 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ddpm:
data:
root: ???
name: "afhq"
image_size: 128
image_size: 256
hflip: True
n_channels: 3
norm: True
Expand All @@ -13,8 +13,8 @@ ddpm:
attn_resolutions: "16,"
n_residual: 2
dim_mults: "1,1,2,2,4,4"
dropout: 0.0
n_heads: 1
dropout: 0.1
n_heads: 8
beta1: 0.0001
beta2: 0.02
n_timesteps: 1000
Expand All @@ -23,6 +23,8 @@ ddpm:
seed: 0
fp16: False
use_ema: True
z_cond: False
z_dim: 1024
type: 'form1'
ema_decay: 0.9999
batch_size: 8
Expand All @@ -39,21 +41,22 @@ ddpm:
n_anneal_steps: 5000
loss: "l2"
chkpt_prefix: ""
cfd_rate: 0.0

# VAE config used for VAE training
vae:
data:
root: ???
name: "afhq"
image_size: 128
image_size: 256
n_channels: 3
hflip: False
hflip: True

model:
enc_block_config : "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d4,4t1,1x2"
enc_channel_config: "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
dec_block_config: "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,32t64,64x2,64u2,64t128,128x1"
dec_channel_config: "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
enc_block_config : "256x3,256d2,256t128,128x3,128d2,128t64,64x5,64d2,64t32,32x7,32d2,32t16,16x9,16d2,16t8,8x7,8d2,8t4,4x5,4d4,4t1,1x2"
enc_channel_config: "256:64,128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
dec_block_config: "1x2,1u4,1t4,4x4,4u2,4t8,8x6,8u2,8t16,16x8,16u2,16t32,32x5,32u2,32t64,64x4,64u2,64t128,128x2,128u2,128t256,256x2"
dec_channel_config: "256:64,128:64,64:64,32:128,16:128,8:256,4:512,1:1024"

training:
seed: 0
Expand Down
2 changes: 1 addition & 1 deletion main/configs/dataset/celebamaskhq128/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ vae:
name: "celebamaskhq"
image_size: 128
n_channels: 3
hflip: False
hflip: True

model:
enc_block_config : "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d4,4t1,1x2"
Expand Down
2 changes: 1 addition & 1 deletion main/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .celeba_mask import CelebAMaskHQDataset
from .recons import ReconstructionDataset, ReconstructionDatasetv2
from .cifar10 import CIFAR10Dataset
from .afhq import AFHQDataset
from .afhq import AFHQv2Dataset
from .ffhq import FFHQLmdbDataset, FFHQDataset
from .celebahq import CelebAHQDataset
10 changes: 5 additions & 5 deletions main/datasets/afhq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from tqdm import tqdm


class AFHQDataset(Dataset):
class AFHQv2Dataset(Dataset):
def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwargs):
# We only train on the AFHQ train set (around 14630 images)
# We only train on the AFHQ train set
if not os.path.isdir(root):
raise ValueError(f"The specified root: {root} does not exist")
self.root = root
Expand All @@ -18,7 +18,7 @@ def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwarg

self.images = []

subfolder_list = ['dog', 'cat', 'wild']
subfolder_list = ["dog", "cat", "wild"]
base_path = os.path.join(self.root, "train")
for subfolder in subfolder_list:
sub_path = os.path.join(base_path, subfolder)
Expand Down Expand Up @@ -52,6 +52,6 @@ def __len__(self):


if __name__ == "__main__":
root = "/data1/kushagrap20/datasets/afhq"
dataset = AFHQDataset(root, subsample_size=None)
root = "/data1/kushagrap20/datasets/afhq_v2/"
dataset = AFHQv2Dataset(root)
print(len(dataset))
4 changes: 2 additions & 2 deletions main/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image

from datasets import (
AFHQDataset,
AFHQv2Dataset,
CelebADataset,
CelebAMaskHQDataset,
CIFAR10Dataset,
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_dataset(name, root, image_size, norm=True, flip=False, **kwargs):
elif name == "celebahq":
dataset = CelebAHQDataset(root, norm=norm, transform=transform, **kwargs)
elif name == "afhq":
dataset = AFHQDataset(root, norm=norm, transform=transform, **kwargs)
dataset = AFHQv2Dataset(root, norm=norm, transform=transform, **kwargs)
elif name == "ffhq":
dataset = FFHQDataset(root, norm=norm, transform=transform, **kwargs)
elif name == "recons":
Expand Down

0 comments on commit cb08607

Please sign in to comment.