-
Notifications
You must be signed in to change notification settings - Fork 2
/
search_dis_arch.py
140 lines (128 loc) · 5.67 KB
/
search_dis_arch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import absolute_import, division, print_function
import search_cfg
import archs
import datasets
# from trainer.trainer_generator import GenTrainer
from trainer.trainer_discriminator import DisTrainer, LinearLrDecay
from utils.utils import set_log_dir, save_checkpoint, create_logger, count_parameters_in_MB
from utils.inception_score import _init_inception
from utils.fid_score import create_inception_graph, check_or_download_inception
from utils.flop_benchmark import print_FLOPs
from archs.super_network import Generator, Discriminator
from algorithms.search_algs import GanAlgorithm
import torch
import os
import numpy as np
import torch.nn as nn
from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
def main():
args = search_cfg.parse_args()
torch.cuda.manual_seed(args.random_seed)
# set visible GPU ids
if len(args.gpu_ids) > 0:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
# the first GPU in visible GPUs is dedicated for evaluation (running Inception model)
str_ids = args.gpu_ids.split(',')
args.gpu_ids = []
for id in range(len(str_ids)):
if id >= 0:
args.gpu_ids.append(id)
if len(args.gpu_ids) > 1:
args.gpu_ids = args.gpu_ids[1:]
else:
args.gpu_ids = args.gpu_ids
gan_alg = GanAlgorithm(args)
# import network from genotype
basemodel_gen = Generator(args)
gen_net = torch.nn.DataParallel(
basemodel_gen, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
basemodel_dis = Discriminator(args)
dis_net = torch.nn.DataParallel(
basemodel_dis, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
# weight init
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
if args.init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif args.init_type == 'orth':
nn.init.orthogonal_(m.weight.data)
elif args.init_type == 'xavier_uniform':
nn.init.xavier_uniform(m.weight.data, 1.)
else:
raise NotImplementedError(
'{} unknown inital type'.format(args.init_type))
elif classname.find('BatchNorm2d') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0.0)
gen_net.apply(weights_init)
dis_net.apply(weights_init)
# set up data_loader
dataset = datasets.ImageDataset(args)
train_loader = dataset.train
# epoch number for dis_net
args.max_epoch_D = args.max_epoch_G * args.n_critic
if args.max_iter_G:
args.max_epoch_D = np.ceil(
args.max_iter_G * args.n_critic / len(train_loader))
max_iter_D = args.max_epoch_D * len(train_loader)
# set TensorFlow environment for evaluation (calculate IS and FID)
_init_inception()
inception_path = check_or_download_inception('./tmp/imagenet/')
create_inception_graph(inception_path)
# fid stat
if args.dataset.lower() == 'cifar10':
fid_stat = './fid_stat/fid_stats_cifar10_train.npz'
elif args.dataset.lower() == 'stl10':
fid_stat = './fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
else:
raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
assert os.path.exists(fid_stat)
gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
args.g_lr, (args.beta1, args.beta2))
dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
args.d_lr, (args.beta1, args.beta2))
gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, max_iter_D)
dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, max_iter_D)
# initial
start_epoch = 0
best_fid = 1e4
# set writer
args.path_helper = set_log_dir('exps', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)
writer_dict = {
'writer': SummaryWriter(args.path_helper['log_path']),
'train_global_steps': start_epoch * len(train_loader),
'valid_global_steps': start_epoch // args.val_freq,
}
# model size
logger.info('Param size of G = %fMB', count_parameters_in_MB(gen_net))
logger.info('Param size of D = %fMB', count_parameters_in_MB(dis_net))
best_genotypes = None
gen_genotype = np.load(os.path.join('exps', 'best_G.npy'))
dis_trainer = DisTrainer(args, gen_net, dis_net, gen_optimizer,
dis_optimizer, train_loader, gan_alg, gen_genotype)
# search discriminator
for epoch in tqdm(range(int(start_epoch), int(args.epoch_discriminator)), desc='search discriminator'):
lr_schedulers = (
gen_scheduler, dis_scheduler) if args.lr_decay else None
dis_trainer.train(epoch, writer_dict, lr_schedulers)
if epoch > args.warmup and (epoch - args.warmup) % args.ga_interval == 0:
best_genotypes = dis_trainer.search_evol_arch(epoch, fid_stat)
#best_genotypes = trainer.select_best(epoch)
for index, best_genotype in enumerate(best_genotypes):
file_path = os.path.join(
args.path_helper['ckpt_path'], "best_dis_{}_{}.npy".format(str(epoch), str(index)))
np.save(file_path, best_genotype)
# save checkpoint
checkpoint_file = os.path.join(
args.path_helper['ckpt_path'], 'dis_checkpoint')
ckpt = {'epoch': epoch, 'weight': dis_net.state_dict()}
torch.save(ckpt, checkpoint_file)
if __name__ == '__main__':
main()