train_w_distill.py
import time, os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import scipy.io as sio
import numpy as np
from random import shuffle

import tensorflow as tf
from tensorflow import ConfigProto

from nets import nets_factory
from dataloader import Dataloader
import op_util

home_path = os.path.dirname(os.path.abspath(__file__))

tf.app.flags.DEFINE_string('train_dir', 'IEP/CIFAR100/WResNet/IEP/',
                           'Directory where checkpoints and event logs are written to.')
tf.app.flags.DEFINE_string('Distillation', 'IEP',
                           'Distillation method : IEP')
tf.app.flags.DEFINE_string('dataset', 'cifar100',
                           'Dataset : cifar100, TinyImageNet, CUB200')
tf.app.flags.DEFINE_string('model_name', 'WResNet',
                           'Model : ResNet, WResNet')
tf.app.flags.DEFINE_string('main_scope', 'Student',
                           "Network's scope")
FLAGS = tf.app.flags.FLAGS
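
# Example invocation (paths and values are illustrative; all flags are defined above):
#   python train_w_distill.py --train_dir=IEP/CIFAR100/WResNet/IEP/ \
#       --Distillation=IEP --dataset=cifar100 --model_name=WResNet --main_scope=Student
# Passing --Distillation=None trains the student without distillation.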
def main(_):
    ### define path and hyper-parameter
    Learning_rate = 1e-1
    batch_size = 128
    val_batch_size = 200
    train_epoch = 100
    init_epoch = 20 if FLAGS.Distillation in {'IEP'} else 0
    total_epoch = init_epoch + train_epoch
    weight_decay = 5e-4
    should_log = 200
    save_summaries_secs = 20
    tf.logging.set_verbosity(tf.logging.INFO)
    gpu_num = '0'

    if FLAGS.Distillation == 'None':
        FLAGS.Distillation = None

    train_images, train_labels, val_images, val_labels, pre_processing, teacher = Dataloader(FLAGS.dataset, home_path, FLAGS.model_name)
    num_label = int(np.max(train_labels) + 1)
    dataset_len, *image_size = train_images.shape
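
    # Dataloader returns the raw train/val arrays, a pre-processing function that is applied
    # inside the graph, and `teacher`, a dict of pre-trained teacher parameters keyed by
    # variable name (used below to seed matching variables when a distillation method is set).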
    with tf.Graph().as_default() as graph:
        # make placeholder for inputs
        image_ph = tf.placeholder(tf.uint8, [None] + image_size)
        label_ph = tf.placeholder(tf.int32, [None])
        is_training_ph = tf.placeholder(tf.bool, [])

        # pre-processing
        image = pre_processing(image_ph, is_training_ph)
        label = tf.contrib.layers.one_hot_encoding(label_ph, num_label, on_value=1.0)

        # make global step
        global_step = tf.train.create_global_step()
        epoch = tf.floor_div(tf.cast(global_step, tf.float32) * batch_size, dataset_len)
        max_number_of_steps = int(dataset_len * total_epoch) // batch_size + 1

        # make learning rate scheduler
        LR = learning_rate_scheduler(Learning_rate, [epoch, init_epoch, train_epoch], [0.3, 0.6, 0.8], 0.2)

        ## load Net
        class_loss, accuracy = MODEL(FLAGS.model_name, FLAGS.main_scope, image, label,
                                     [is_training_ph, epoch < init_epoch], Distillation=FLAGS.Distillation)

        # make training operator
        if FLAGS.Distillation in {'IEP'}:
            train_op, train_op2 = op_util.Optimizer_w_IEP(class_loss, LR, weight_decay, epoch, init_epoch, global_step)
        else:
            train_op = op_util.Optimizer(class_loss, LR, weight_decay, epoch, init_epoch, global_step)
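
        # With IEP, the training loop below runs `train_op2` (the second operator returned by
        # op_util.Optimizer_w_IEP) during the first `init_epoch` epochs and switches to
        # `train_op` for the remaining `train_epoch` epochs.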
        ## collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES)[:1], name='summary_op')

        ## make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
                       tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)]
        val_summary_op = tf.summary.merge(list(val_summary), name='val_summary_op')

        ## start training
        train_writer = tf.summary.FileWriter(FLAGS.train_dir, graph, flush_secs=save_summaries_secs)
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        val_itr = len(val_labels) // val_batch_size
        logs = {'training_acc': [], 'validation_acc': []}
        with tf.Session(config=config) as sess:
            if FLAGS.Distillation is not None:
                # seed every variable that has a matching entry in the teacher dict with the
                # pre-trained teacher weights, so the initializer run below restores them
                global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                n = 0
                for v in global_variables:
                    if teacher.get(v.name[:-2]) is not None:
                        v._initial_value = tf.constant(teacher[v.name[:-2]].reshape(*v.get_shape().as_list()))
                        v._initializer_op = tf.assign(v._variable, v._initial_value, name=v.name[:-2] + '/Assign').op
                        n += 1
                print('%d Teacher params assigned' % n)
            sess.run(tf.global_variables_initializer())

            sum_train_accuracy = []; time_elapsed = []; total_loss = []
            idx = list(range(train_labels.shape[0]))
            shuffle(idx)
            epoch_ = 0
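
            # Training loop: mini-batches are drawn without replacement from the shuffled index
            # pool `idx` (refilled and re-shuffled once it runs low); validation runs once per
            # epoch, and periodic loss/summary logging is controlled by `should_log`.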
            for step in range(max_number_of_steps):
                start_time = time.time()
                ## feed data
                if (step * batch_size) // dataset_len < init_epoch:
                    tl, log, train_acc = sess.run([train_op2, summary_op, accuracy],
                                                  feed_dict={image_ph: train_images[idx[:batch_size]],
                                                             label_ph: np.squeeze(train_labels[idx[:batch_size]]),
                                                             is_training_ph: True})
                else:
                    tl, log, train_acc = sess.run([train_op, summary_op, accuracy],
                                                  feed_dict={image_ph: train_images[idx[:batch_size]],
                                                             label_ph: np.squeeze(train_labels[idx[:batch_size]]),
                                                             is_training_ph: True})
                time_elapsed.append(time.time() - start_time)
                total_loss.append(tl)
                sum_train_accuracy.append(train_acc)

                # drop the consumed indices and refill the pool when it gets too small
                idx[:batch_size] = []
                if len(idx) < batch_size:
                    idx_ = list(range(train_labels.shape[0]))
                    shuffle(idx_)
                    idx += idx_

                # bookkeeping below is 1-indexed; `range` resets `step` on the next iteration
                step += 1
                if (step * batch_size) // dataset_len >= init_epoch + epoch_:
                    ## do validation
                    sum_val_accuracy = []
                    for i in range(val_itr):
                        acc = sess.run(accuracy, feed_dict={image_ph: val_images[i * val_batch_size:(i + 1) * val_batch_size],
                                                            label_ph: np.squeeze(val_labels[i * val_batch_size:(i + 1) * val_batch_size]),
                                                            is_training_ph: False})
                        sum_val_accuracy.append(acc)

                    sum_train_accuracy = np.mean(sum_train_accuracy) * 100 if (step * batch_size) // dataset_len > init_epoch else 1.
                    sum_val_accuracy = np.mean(sum_val_accuracy) * 100
                    tf.logging.info('Epoch %s Step %s - train_Accuracy : %.2f%% val_Accuracy : %.2f%%'
                                    % (str(epoch_).rjust(3, '0'), str(step).rjust(6, '0'),
                                       sum_train_accuracy, sum_val_accuracy))

                    result_log = sess.run(val_summary_op, feed_dict={train_acc_place: sum_train_accuracy,
                                                                     val_acc_place: sum_val_accuracy})
                    logs['training_acc'].append(sum_train_accuracy)
                    logs['validation_acc'].append(sum_val_accuracy)

                    if (step * batch_size) // dataset_len == init_epoch and FLAGS.Distillation in {'FitNet', 'FSP', 'AB'}:
                        # re-initialize Momentum for fair comparison w/ initialization and multi-task learning methods
                        for v in global_variables:
                            if v.name.endswith('Momentum:0'):
                                sess.run(v.assign(np.zeros(v.get_shape().as_list())))

                    if step == max_number_of_steps:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log, epoch_)
                    sum_train_accuracy = []
                    epoch_ += 1
                variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection('BN_collection')
                if step % should_log == 0:
                    tf.logging.info('global step %s: loss = %.4f (%.3f sec/step)',
                                    str(step).rjust(6, '0'), np.mean(total_loss), np.mean(time_elapsed))
                    train_writer.add_summary(log, step)
                    time_elapsed = []
                    total_loss = []
                elif (step * batch_size) % dataset_len == 0:
                    train_writer.add_summary(log, step)

            ## save the trained student parameters and the accuracy log as .mat files
            var = {}
            variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection('BN_collection')
            # variables = tf.get_collection('MHA')
            for v in variables:
                if v.name.split('/')[0] == FLAGS.main_scope:
                    var[v.name[:-2]] = sess.run(v)
            sio.savemat(FLAGS.train_dir + '/train_params.mat', var)
            sio.savemat(FLAGS.train_dir + '/log.mat', logs)

            ## close all
            tf.logging.info('Finished training! Saving model to disk.')
            train_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()

def MODEL(model_name, scope, image, label, is_training, Distillation):
    # build the network from nets_factory and return the classification loss and top-1 accuracy
    network_fn = nets_factory.get_network_fn(model_name)
    end_points = network_fn(image, label, scope, is_training=is_training, Distill=Distillation)
    loss = tf.losses.softmax_cross_entropy(label, end_points['Logits'])
    accuracy = tf.contrib.metrics.accuracy(tf.cast(tf.argmax(end_points['Logits'], 1), tf.int32),
                                           tf.cast(tf.argmax(label, 1), tf.int32))
    return loss, accuracy

def learning_rate_scheduler(Learning_rate, epochs, decay_point, decay_rate):
    with tf.variable_scope('learning_rate_scheduler'):
        e, ie, te = epochs
        for i, dp in enumerate(decay_point):
            Learning_rate = tf.cond(tf.greater_equal(e, ie + int(te * dp)),
                                    lambda: Learning_rate * decay_rate,
                                    lambda: Learning_rate)
        tf.summary.scalar('learning_rate', Learning_rate)
    return Learning_rate
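
# With the values used in main() (Learning_rate=1e-1, decay_point=[0.3, 0.6, 0.8],
# decay_rate=0.2, train_epoch=100, and IEP's init_epoch=20) the schedule works out to:
#   epoch <  50 : 0.1
#   epoch >= 50 : 0.02      (decayed at 20 + int(100*0.3))
#   epoch >= 80 : 0.004
#   epoch >= 100: 0.0008
# where `epoch` is the total epoch counter (hence the `ie +` offset for the init phase).
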
if __name__ == '__main__':
tf.app.run()