stacked MNIST, OMNIGLOT, LSUN Churches and FID calculation
arash-vahdat committed Oct 12, 2021
1 parent 38eb997 commit 9fc1a28
Showing 12 changed files with 1,206 additions and 28 deletions.
34 changes: 33 additions & 1 deletion README.md
@@ -89,6 +89,14 @@ python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffh
```
</details>

<details><summary>LSUN</summary>

We use the LSUN datasets in our follow-up works. Visit [LSUN](https://www.yf.io/p/lsun) for
instructions on how to download them. Since the LSUN scene datasets are distributed in the
LMDB format, they can be loaded directly with the torchvision data loaders, as sketched below.
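
A minimal loading example for LSUN Churches, assuming the LMDB archives were unpacked under a
local directory such as `data/lsun` (the 64x64 resize/crop below is illustrative and may differ
from what our training scripts use):

```python
import torchvision.datasets as dset
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])

# expects a root directory containing e.g. church_outdoor_train_lmdb/
lsun_churches = dset.LSUN(root='data/lsun', classes=['church_outdoor_train'],
                          transform=transform)
print('number of training images:', len(lsun_churches))
```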

</details>


## Running the main NVAE training and evaluation scripts
We use the following commands for training NVAEs on each dataset for
@@ -275,7 +283,7 @@ Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running

## Post-training sampling, evaluation, and checkpoints

<details><summary>Evaluation</summary>
<details><summary>Evaluating Log-Likelihood</summary>

You can use the following command to load a trained model and evaluate it on the test datasets:

@@ -303,6 +311,30 @@ as described in the paper. If you remove `--readjust_bn`, the sampling will proc

</details>

<details><summary>Computing FID</summary>

You can compute the FID score using 50K samples. To do so, first create a file with the
mean and covariance statistics of the training data using a command like:

```shell script
cd $CODE_DIR
python scripts/precompute_fid_statistics.py --data $DATA_DIR/cifar10 --dataset cifar10 --fid_dir /tmp/fid-stats/
```
The command above computes the reference statistics on the CIFAR-10 dataset and stores them in the `--fid_dir` directory.
Given the reference statistics file, we can run the following command to compute the FID score:

```shell script
cd $CODE_DIR
python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --data $DATA_DIR/cifar10 --eval_mode=evaluate_fid --fid_dir /tmp/fid-stats/ --temp=0.6 --readjust_bn
```
where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics
as described in the paper. If you remove `--readjust_bn`, sampling will proceed with the BN layers in eval mode
(i.e., BN layers will use the running means and variances collected during training).
Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script.
Set `--data` to the same argument that was used when training NVAE (the example above is for CIFAR-10).
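
For reference, the statistics file stores the mean and covariance of Inception features, and the
FID score is the Fréchet distance between the Gaussian fitted to the reference data and the one
fitted to the generated samples. A minimal sketch of that distance (an illustration only, not the
exact implementation used by the FID utilities in this repository):

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # matrix square root of the product of the two covariance matrices
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard small imaginary parts from numerical error
    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```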

</details>

<details><summary>Checkpoints</summary>

We provide checkpoints on MNIST, CIFAR-10, CelebA 64, CelebA HQ 256, FFHQ in
114 changes: 113 additions & 1 deletion datasets.py
@@ -7,15 +7,57 @@

"""Code for getting the data loaders."""

import numpy as np
from PIL import Image
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from scipy.io import loadmat
import os
import utils
import urllib.request
from lmdb_datasets import LMDBDataset
from thirdparty.lsun import LSUN


class StackedMNIST(dset.MNIST):
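    """MNIST digits stacked into the three RGB channels of a single image.

    Each sample combines three randomly chosen MNIST images (one per channel), and the
    label encodes the corresponding three digits as a base-10 number in [0, 999],
    giving 1000 classes.
    """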
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(StackedMNIST, self).__init__(root=root, train=train, transform=transform,
target_transform=target_transform, download=download)

index1 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
index2 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
index3 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
self.num_images = 2 * len(self.data)

self.index = []
for i in range(self.num_images):
self.index.append((index1[i], index2[i], index3[i]))

def __len__(self):
return self.num_images

def __getitem__(self, index):
img = np.zeros((28, 28, 3), dtype=np.uint8)
target = 0
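        # one grayscale digit per RGB channel; label = 100 * d_R + 10 * d_G + d_B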
for i in range(3):
img_, target_ = self.data[self.index[index][i]], int(self.targets[self.index[index][i]])
img[:, :, i] = img_
target += target_ * 10 ** (2 - i)

img = Image.fromarray(img, mode="RGB")

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target



class Binarize(object):
""" This class introduces a binarization transformation
"""
@@ -42,6 +84,47 @@ def get_loaders(args):
"""Get data loaders for required dataset."""
return get_loaders_eval(args.dataset, args)

def download_omniglot(data_dir):
filename = 'chardata.mat'
if not os.path.exists(data_dir):
os.mkdir(data_dir)
url = 'https://raw.github.com/yburda/iwae/master/datasets/OMNIGLOT/chardata.mat'

filepath = os.path.join(data_dir, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(url, filepath)
print('Downloaded', filename)

return


def load_omniglot(data_dir):
download_omniglot(data_dir)

data_path = os.path.join(data_dir, 'chardata.mat')

omni = loadmat(data_path)
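    # 'data' and 'testdata' are flat (784, num_images) arrays; scale from [0, 1] to
    # [0, 255] and reshape into (num_images, 28, 28) uint8 images.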
train_data = 255 * omni['data'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0))
test_data = 255 * omni['testdata'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0))

train_data = train_data.astype('uint8')
test_data = test_data.astype('uint8')

return train_data, test_data


class OMNIGLOT(Dataset):
def __init__(self, data, transform):
self.data = data
self.transform = transform

def __getitem__(self, index):
d = self.data[index]
img = Image.fromarray(d)
return self.transform(img), 0 # return zero as label.

def __len__(self):
return len(self.data)

def get_loaders_eval(dataset, args):
"""Get train and valid loaders for cifar10/tiny imagenet."""
Expand All @@ -60,6 +143,20 @@ def get_loaders_eval(dataset, args):
root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.MNIST(
root=args.data, train=False, download=True, transform=valid_transform)
elif dataset == 'stacked_mnist':
num_classes = 1000
train_transform, valid_transform = _data_transforms_stacked_mnist(args)
train_data = StackedMNIST(
root=args.data, train=True, download=True, transform=train_transform)
valid_data = StackedMNIST(
root=args.data, train=False, download=True, transform=valid_transform)
elif dataset == 'omniglot':
num_classes = 0
download_omniglot(args.data)
train_transform, valid_transform = _data_transforms_mnist(args)
train_data, valid_data = load_omniglot(args.data)
train_data = OMNIGLOT(train_data, train_transform)
valid_data = OMNIGLOT(valid_data, valid_transform)
elif dataset.startswith('celeba'):
if dataset == 'celeba_64':
resize = 64
@@ -162,6 +259,21 @@ def _data_transforms_mnist(args):
return train_transform, valid_transform


def _data_transforms_stacked_mnist(args):
"""Get data transforms for cifar10."""
train_transform = transforms.Compose([
transforms.Pad(padding=2),
transforms.ToTensor()
])

valid_transform = transforms.Compose([
transforms.Pad(padding=2),
transforms.ToTensor()
])

return train_transform, valid_transform


def _data_transforms_generic(size):
train_transform = transforms.Compose([
transforms.Resize(size),
46 changes: 45 additions & 1 deletion distributions.py
@@ -14,7 +14,7 @@

@torch.jit.script
def soft_clamp5(x: torch.Tensor):
return x.div_(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
return x.div(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]


@torch.jit.script
@@ -50,6 +50,28 @@ def kl(self, normal_dist):
return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)


class NormalDecoder:
def __init__(self, param, num_bits=8):
B, C, H, W = param.size()
self.num_c = C // 2
mu = param[:, :self.num_c, :, :] # B, 3, H, W
log_sigma = param[:, self.num_c:, :, :] # B, 3, H, W
self.dist = Normal(mu, log_sigma)

def log_prob(self, samples):
assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0
# convert samples to be in [-1, 1]
samples = 2 * samples - 1.0

return self.dist.log_p(samples)

def sample(self, t=1.):
x, _ = self.dist.sample()
x = torch.clamp(x, -1, 1.)
x = x / 2. + 0.5
return x


class DiscLogistic:
def __init__(self, param):
B, C, H, W = param.size()
@@ -178,3 +200,25 @@ def sample(self, t=1.):
x = x / 2. + 0.5
return x

def mean(self):
sel = torch.softmax(self.logit_probs, dim=1) # B, M, H, W
sel = sel.unsqueeze(1) # B, 1, M, H, W

# select logistic parameters
means = torch.sum(self.means * sel, dim=2) # B, 3, H, W
coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W

        # we don't sample from the logistic components; because of the linear dependencies, we use their means
x = means # B, 3, H, W
x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W
x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W
x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W

x0 = x0.unsqueeze(1)
x1 = x1.unsqueeze(1)
x2 = x2.unsqueeze(1)

x = torch.cat([x0, x1, x2], 1)
x = x / 2. + 0.5
return x

47 changes: 35 additions & 12 deletions evaluate.py
@@ -8,6 +8,7 @@
import argparse
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from time import time

@@ -17,7 +18,7 @@
from model import AutoEncoder
import utils
import datasets
from train import test, init_processes
from train import test, init_processes, test_vae_fid


def set_bn(model, bn_eval_mode, num_samples=1, t=1.0, iter=100):
@@ -55,6 +56,9 @@ def main(eval_args):
logging.info('old model, no num_mixture_dec was found.')
args.num_mixture_dec = 10

if eval_args.batch_size > 0:
args.batch_size = eval_args.batch_size

logging.info('loaded the model at epoch %d', checkpoint['epoch'])
arch_instance = utils.get_arch_cells(args.arch_instance)
model = AutoEncoder(args, None, arch_instance)
@@ -86,14 +90,23 @@
logging.info('final valid neg log p %f', valid_neg_log_p)
logging.info('final valid nelbo in bpd %f', valid_nelbo * bpd_coeff)
logging.info('final valid neg log p in bpd %f', valid_neg_log_p * bpd_coeff)

elif eval_args.eval_mode == 'evaluate_fid':
bn_eval_mode = not eval_args.readjust_bn
set_bn(model, bn_eval_mode, num_samples=2, t=eval_args.temp, iter=500)
args.fid_dir = eval_args.fid_dir
        args.num_process_per_node, args.num_proc_node = eval_args.world_size, 1  # evaluate on a single node
fid = test_vae_fid(model, args, total_fid_samples=50000)
logging.info('fid is %f' % fid)
else:
bn_eval_mode = not eval_args.readjust_bn
num_samples = 16
total_samples = 50000 // eval_args.world_size # num images per gpu
num_samples = 100 # sampling batch size
num_iter = int(np.ceil(total_samples / num_samples)) # num iterations per gpu

with torch.no_grad():
n = int(np.floor(np.sqrt(num_samples)))
set_bn(model, bn_eval_mode, num_samples=36, t=eval_args.temp, iter=500)
for ind in range(10): # sampling is repeated.
set_bn(model, bn_eval_mode, num_samples=16, t=eval_args.temp, iter=500)
for ind in range(num_iter): # sampling is repeated.
torch.cuda.synchronize()
start = time()
with autocast():
@@ -103,14 +116,20 @@
else output.sample()
torch.cuda.synchronize()
end = time()

output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0)
logging.info('sampling time per batch: %0.3f sec', (end - start))
output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8)
output_tiled = np.squeeze(output_tiled)

plt.imshow(output_tiled)
plt.show()
visualize = False
if visualize:
output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0)
output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8)
output_tiled = np.squeeze(output_tiled)

plt.imshow(output_tiled)
plt.show()
else:
file_path = os.path.join(eval_args.save, 'gpu_%d_samples_%d.npz' % (eval_args.local_rank, ind))
np.savez_compressed(file_path, samples=output_img.cpu().numpy())
logging.info('Saved at: {}'.format(file_path))


if __name__ == '__main__':
@@ -120,7 +139,7 @@
help='location of the checkpoint')
parser.add_argument('--save', type=str, default='/tmp/expr',
help='location of the checkpoint')
parser.add_argument('--eval_mode', type=str, default='sample', choices=['sample', 'evaluate'],
parser.add_argument('--eval_mode', type=str, default='sample', choices=['sample', 'evaluate', 'evaluate_fid'],
                    help='evaluation mode: choose between sample, evaluate, or evaluate_fid.')
parser.add_argument('--eval_on_train', action='store_true', default=False,
help='Settings this to true will evaluate the model on training data.')
@@ -132,6 +151,10 @@
help='The temperature used for sampling.')
parser.add_argument('--num_iw_samples', type=int, default=1000,
help='The number of IW samples used in test_ll mode.')
parser.add_argument('--fid_dir', type=str, default='/tmp/fid-stats',
help='path to directory where fid related files are stored')
parser.add_argument('--batch_size', type=int, default=0,
help='Batch size used during evaluation. If set to zero, training batch size is used.')
# DDP.
parser.add_argument('--local_rank', type=int, default=0,
help='rank of process')