train_net.py
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim as optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
from util import distributed as du
import time
from collections import OrderedDict
from data import create_dataset
from data import shuffle_dataset
from models import create_model
from util.visualizer import Visualizer
from util.evaluation import evaluation
from util import html, util
from util.visualizer import save_images
def train(cfg):
    # init distributed training
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # init dataset
    dataset = create_dataset(cfg)        # create a dataset given cfg.dataset_mode and other options
    dataset_size = len(dataset)          # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    postion_embedding = util.PositionEmbeddingSine(cfg)
    patch_pos = util.PatchPositionEmbeddingSine(cfg)

    model = create_model(cfg)            # create a model given cfg.model and other options
    model.set_position(postion_embedding, patch_pos=patch_pos)
    # model.setup(cfg)                   # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(cfg)         # create a visualizer that displays/saves images and plots
    total_iters = 0                      # the total number of training iterations
    # cur_device = torch.cuda.current_device()
    is_master = du.is_master_proc(cfg.NUM_GPUS)

    for epoch in range(cfg.epoch_count, cfg.niter + cfg.niter_decay + 1):
        # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        if is_master:
            epoch_start_time = time.time()   # timer for the entire epoch
            iter_data_time = time.time()     # timer for data loading per iteration
        epoch_iter = 0                       # the number of training iterations in the current epoch, reset to 0 every epoch
        shuffle_dataset(dataset, epoch)
        for i, data in enumerate(dataset):   # inner loop within one epoch
            if is_master:
                iter_start_time = time.time()    # timer for computation per iteration
                if total_iters % cfg.print_freq == 0:
                    t_data = iter_start_time - iter_data_time
                iter_data_time = time.time()
                visualizer.reset()
            total_iters += cfg.batch_size
            epoch_iter += cfg.batch_size
            if epoch == cfg.epoch_count and i == 0:
                model.data_dependent_initialize(data)
                model.setup(cfg)             # regular setup: load and print networks; create schedulers
            model.set_input(data)            # unpack data from the dataset and apply preprocessing
            model.optimize_parameters()      # calculate loss functions, get gradients, update network weights

            if total_iters % cfg.display_freq == 0 and is_master:
                # display images on visdom and save images to an HTML file
                save_result = total_iters % cfg.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            losses = model.get_current_losses()
            if cfg.NUM_GPUS > 1:
                losses = du.all_reduce(losses)

            if total_iters % cfg.print_freq == 0 and is_master:
                # print training losses and save logging information to the disk
                t_comp = (time.time() - iter_start_time) / cfg.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if cfg.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % cfg.save_latest_freq == 0 and is_master:
                # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if cfg.save_by_iter else 'latest'
                model.save_networks(save_suffix)

        if epoch % cfg.save_epoch_freq == 0 and is_master:
            # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            if cfg.save_iter_model and epoch >= 55:
                model.save_networks(epoch)
        if is_master:
            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, cfg.niter + cfg.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()         # update learning rates at the end of every epoch.
def test(cfg):
    dataset = create_dataset(cfg)        # create a dataset given cfg.dataset_mode and other options
    postion_embedding = util.PositionEmbeddingSine(cfg)
    patch_pos = util.PatchPositionEmbeddingSine(cfg)

    model = create_model(cfg)            # create a model given cfg.model and other options
    model.set_position(postion_embedding, patch_pos=patch_pos)
    model.setup(cfg)                     # regular setup: load and print networks; create schedulers

    # create a website
    web_dir = os.path.join(cfg.results_dir, cfg.name, '%s_%s' % (cfg.phase, cfg.epoch))  # define the website directory
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (cfg.name, cfg.phase, cfg.epoch))

    # Test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment with and without eval() mode.
    # For [CycleGAN]: it should not affect CycleGAN, as CycleGAN uses instancenorm without dropout.
    if cfg.eval:
        model.eval()
    ismaster = du.is_master_proc(cfg.NUM_GPUS)

    fmse_score_list = []
    mse_scores = 0
    fmse_scores = 0
    num_image = 0
    # print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    for i, data in enumerate(dataset):
        # if i >= 100:
        #     print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        #     break
        model.set_input(data)            # unpack data from the data loader
        model.test()                     # run inference
        visuals = model.get_current_visuals()   # get image results
        img_path = model.get_image_paths()      # get image paths  # Added by Mia
        if i % 5 == 0 and ismaster:      # save images to an HTML file
            print('processing (%04d)-th image... %s' % (i, img_path))
        visuals_ones = OrderedDict()
        harmonized = None
        real = None
        for j in range(len(img_path)):
            img_path_one = []
            for label, im_data in visuals.items():
                visuals_ones[label] = im_data[j:j + 1, :, :, :]
            img_path_one.append(img_path[j])
            save_images(webpage, visuals_ones, img_path_one, aspect_ratio=cfg.aspect_ratio, width=cfg.display_winsize)
            num_image += 1
            visuals_ones.clear()
    webpage.save()  # save the HTML
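# ----------------------------------------------------------------------------
# Usage sketch (not part of the original script): train() and test() are
# expected to be driven by an external launcher that builds the `cfg` object.
# The commented snippet below is a minimal, hypothetical illustration of how
# they could be invoked; `TrainOptions` and `cfg.isTrain` are assumptions,
# not names confirmed by this file.
#
#     from options.train_options import TrainOptions   # hypothetical option parser
#
#     if __name__ == '__main__':
#         cfg = TrainOptions().parse()   # build the config/option object
#         if cfg.isTrain:
#             train(cfg)                 # distributed init happens inside train()
#         else:
#             test(cfg)
# ----------------------------------------------------------------------------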