Skip to content

Commit

Permalink
Add CE to dygraph Se-Resnext model (PaddlePaddle#2699)
Browse files Browse the repository at this point in the history
Update mnist_dygraph.py
fix bug
* add ce to se_resnext
* delete useless comments and fix unique_name bugs
  • Loading branch information
DDDivano authored and junjun315 committed Jul 4, 2019
1 parent 9d18809 commit 9d690da
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 9 deletions.
9 changes: 9 additions & 0 deletions dygraph/se_resnext/.run_ce.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash

# This file is only used for continuous evaluation.
# dygraph single card
export FLAGS_cudnn_deterministic=True
export CUDA_VISIBLE_DEVICES=5
python -u train.py --ce --epoch 1 | python _ce.py
#python train.py --ce --epoch 1 | python _ce.py

64 changes: 64 additions & 0 deletions dygraph/se_resnext/_ce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
####this file is only used for continuous evaluation test!
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi, DurationKpi, AccKpi

#### NOTE kpi.py should shared in models in some way!!!!

train_acc1 = AccKpi('train_acc1', 0.01, 0, actived=True, desc="train acc1")
train_acc5 = AccKpi('train_acc5', 0.01, 0, actived=True, desc="train acc5")
train_loss = CostKpi('train_loss', 0.01, 0, actived=True, desc="train loss")
test_acc1 = AccKpi('test_acc1', 0.01, 0, actived=True, desc='test acc1')
test_acc5 = AccKpi('test_acc5', 0.01, 0, actived=True, desc='test acc5')
test_loss = CostKpi('test_loss', 0.01, 0, actived=True, desc='test loss')

tracking_kpis = [train_acc1, train_acc5, train_loss,
test_acc1, test_acc5, test_loss]

def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
print("-----%s" % fs)
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value


def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi

for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()


if __name__ == '__main__':
log = sys.stdin.read()
print("*****")
print(log)
print("****")
log_to_ce(log)
37 changes: 28 additions & 9 deletions dygraph/se_resnext/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from paddle.fluid.dygraph.base import to_variable
import sys
import math
import argparse

parser = argparse.ArgumentParser("Training for Se-ResNeXt.")
parser.add_argument("-e", "--epoch", default=200, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
batch_size = 64
train_parameters = {
"input_size": [3, 224, 224],
Expand Down Expand Up @@ -324,12 +329,12 @@ def eval(model, data):
label = to_variable(y_data)
label._stop_gradient = True
out = model(img)
cost,pred = fluid.layers.softmax_with_cross_entropy(out,label,return_softmax=True)
avg_loss = fluid.layers.mean(x=cost)

acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5)

softmax_out = fluid.layers.softmax(out,use_cudnn=False)
loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
dy_out = avg_loss.numpy()

total_loss += dy_out
Expand All @@ -341,19 +346,28 @@ def eval(model, data):
( batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))

if args.ce:
print("kpis\ttest_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttest_acc5\t%0.3f" % (total_acc5 / total_sample))
print("kpis\ttest_loss\t%0.3f" % (total_loss / total_sample))
print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" % \
(total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))

def train():
seed = 90

epoch_num = train_parameters["num_epochs"]

if args.ce:
epoch_num = args.epoch
batch_size = train_parameters["batch_size"]

with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = 90
fluid.default_main_program().random_seed = 90
if args.ce:
print("ce mode")
seed = 90
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed

se_resnext = SeResNeXt("se_resnext")
optimizer = optimizer_setting(train_parameters)
Expand Down Expand Up @@ -404,10 +418,15 @@ def train():
total_acc5 += acc_top5.numpy()
total_sample += 1
if batch_id % 10 == 0:
print(fluid.dygraph.base._print_debug_msg())
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f lr %0.5f" % \
( epoch_id, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample, lr))

if args.ce:
print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample))
print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
(epoch_id, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
Expand Down

0 comments on commit 9d690da

Please sign in to comment.