Fix dygraph model save (PaddlePaddle#3369)
* fix model save

* fix doc
chengduoZH authored Sep 23, 2019
1 parent cc8e0d0 commit 32ae4f2
Showing 4 changed files with 66 additions and 47 deletions.
12 changes: 8 additions & 4 deletions dygraph/mnist/train.py
@@ -31,7 +31,8 @@ def parse_args():
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
@@ -175,7 +176,6 @@ def train_mnist(args):
epoch_num = args.epoch
BATCH_SIZE = 64

trainer_count = fluid.dygraph.parallel.Env().nranks
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
@@ -241,8 +241,12 @@ def train_mnist(args):
print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(
epoch, test_cost, test_acc))

fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
print("checkpoint saved")
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
print("checkpoint saved")

inference_mnist()

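The substance of the fix is the guard around the checkpoint write: in multi-GPU data-parallel training every worker holds an identical copy of the parameters, so only one process should write them to disk. Below is a minimal sketch of the pattern the diff introduces; `model` and `dirname` are placeholders, and since `(not p) or (p and q)` reduces to `(not p) or q`, the guard can be folded into a single condition.

```python
import paddle.fluid as fluid


def save_checkpoint(model, dirname, use_data_parallel):
    # Equivalent to the diff's save_parameters flag: in single-process
    # training there is nothing to coordinate; in data-parallel training
    # only the worker with local_rank == 0 writes, so the processes do
    # not race on the same checkpoint files.
    if (not use_data_parallel) or fluid.dygraph.parallel.Env().local_rank == 0:
        fluid.dygraph.save_persistables(model.state_dict(), dirname)
        print("checkpoint saved to {}".format(dirname))
```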
20 changes: 14 additions & 6 deletions dygraph/resnet/train.py
@@ -38,9 +38,12 @@ def parse_args():
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
parser.add_argument("-e", "--epoch", default=120, type=int, help="set epoch")
parser.add_argument("-b", "--batch_size", default=32, type=int, help="set epoch")
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument(
"-e", "--epoch", default=120, type=int, help="set epoch")
parser.add_argument(
"-b", "--batch_size", default=32, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
return args
@@ -49,6 +52,7 @@ def parse_args():
args = parse_args()
batch_size = args.batch_size


def optimizer_setting():

total_images = IMAGENET1000
@@ -275,7 +279,6 @@ def eval(model, data):

def train_resnet():
epoch = args.epoch
trainer_count = fluid.dygraph.parallel.Env().nranks
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
@@ -353,7 +356,6 @@ def train_resnet():
optimizer.minimize(avg_loss)
resnet.clear_gradients()


total_loss += dy_out
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
@@ -373,7 +375,13 @@ def train_resnet():
total_acc1 / total_sample, total_acc5 / total_sample))
resnet.eval()
eval(resnet, test_reader)
fluid.dygraph.save_persistables(resnet.state_dict(), 'resnet_params')

save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.dygraph.save_persistables(resnet.state_dict(),
'resnet_params')


if __name__ == '__main__':
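For completeness, restoring the persistables written above would use the matching load call from the same Paddle 1.x release. A rough sketch follows, assuming `fluid.dygraph.load_persistables` returns the parameter dict directly and that `ResNet` is the model class defined earlier in this script.

```python
import paddle.fluid as fluid


def load_checkpoint(dirname="resnet_params"):
    # Rebuild the network and overwrite its parameters with the saved ones.
    # ResNet is assumed to be the class defined in dygraph/resnet/train.py.
    with fluid.dygraph.guard():
        resnet = ResNet("resnet")
        resnet.load_dict(fluid.dygraph.load_persistables(dirname))
        resnet.eval()
        return resnet
```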
78 changes: 42 additions & 36 deletions dygraph/se_resnext/train.py
@@ -30,12 +30,13 @@

parser = argparse.ArgumentParser("Training for Se-ResNeXt.")
parser.add_argument("-e", "--epoch", default=200, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use data parallel mode to train the model."
)
args = parser.parse_args()
batch_size = 64
train_parameters = {
@@ -51,27 +52,28 @@
"batch_size": batch_size,
"lr": 0.0125,
"total_images": 6149,
"num_epochs":200
"num_epochs": 200
}

momentum_rate = 0.9
l2_decay = 1.2e-4


def optimizer_setting(params):
ls = params["learning_strategy"]
if "total_images" not in params:
total_images = 6149
else:
total_images = params["total_images"]

batch_size = ls["batch_size"]
step = int(math.ceil(float(total_images) / batch_size))
bd = [step * e for e in ls["epochs"]]
lr = params["lr"]
num_epochs = params["num_epochs"]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.cosine_decay(
learning_rate=lr,step_each_epoch=step,epochs=num_epochs),
learning_rate=lr, step_each_epoch=step, epochs=num_epochs),
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))

@@ -97,7 +99,7 @@ def __init__(self,
groups=groups,
act=None,
bias_attr=False,
param_attr=fluid.ParamAttr(name="weights"))
param_attr=fluid.ParamAttr(name="weights"))

self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act)

@@ -114,20 +116,21 @@ def __init__(self, name_scope, num_channels, reduction_ratio):
super(SqueezeExcitation, self).__init__(name_scope)
self._pool = Pool2D(
self.full_name(), pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0/math.sqrt(num_channels*1.0)
stdv = 1.0 / math.sqrt(num_channels * 1.0)
self._squeeze = FC(
self.full_name(),
size=num_channels // reduction_ratio,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,stdv)),
initializer=fluid.initializer.Uniform(-stdv, stdv)),
act='relu')
stdv = 1.0/math.sqrt(num_channels/16.0*1.0)
stdv = 1.0 / math.sqrt(num_channels / 16.0 * 1.0)
self._excitation = FC(
self.full_name(),
size=num_channels,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,stdv)),
initializer=fluid.initializer.Uniform(-stdv, stdv)),
act='sigmoid')

def forward(self, input):
y = self._pool(input)
y = self._squeeze(y)
@@ -310,15 +313,15 @@ def forward(self, inputs):
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.5,seed=100)
y = fluid.layers.dropout(y, dropout_prob=0.5, seed=100)
y = self.out(y)
return y


def eval(model, data):

model.eval()
batch_size=32
batch_size = 32
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
@@ -336,7 +339,7 @@ def eval(model, data):
label._stop_gradient = True
out = model(img)

softmax_out = fluid.layers.softmax(out,use_cudnn=False)
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
@@ -351,7 +354,7 @@ def eval(model, data):
print("test | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
( batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))

if args.ce:
print("kpis\ttest_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttest_acc5\t%0.3f" % (total_acc5 / total_sample))
@@ -360,8 +363,9 @@ def eval(model, data):
(total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))


def train():

epoch_num = train_parameters["num_epochs"]
if args.ce:
epoch_num = args.epoch
@@ -378,47 +382,48 @@ def train():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
strategy = fluid.dygraph.parallel.prepare_context()
se_resnext = SeResNeXt("se_resnext")
optimizer = optimizer_setting(train_parameters)
if args.use_data_parallel:
se_resnext = fluid.dygraph.parallel.DataParallel(se_resnext, strategy)
se_resnext = fluid.dygraph.parallel.DataParallel(se_resnext,
strategy)
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size,
drop_last=True
)
drop_last=True)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
test_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False), batch_size=32)
paddle.dataset.flowers.test(use_xmap=False), batch_size=32)

for epoch_id in range(epoch_num):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(train_reader()):

dy_x_data = np.array(
[x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(
batch_size, 1)

dy_x_data = np.array([x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)

img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True

out = se_resnext(img)
softmax_out = fluid.layers.softmax(out,use_cudnn=False)
loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
loss = fluid.layers.cross_entropy(
input=softmax_out, label=label)
avg_loss = fluid.layers.mean(x=loss)

acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)

acc_top1 = fluid.layers.accuracy(
input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=softmax_out, label=label, k=5)

dy_out = avg_loss.numpy()
if args.use_data_parallel:
@@ -430,7 +435,7 @@ def train():

optimizer.minimize(avg_loss)
se_resnext.clear_gradients()

lr = optimizer._global_learning_rate().numpy()
total_loss += dy_out
total_acc1 += acc_top1.numpy()
@@ -452,5 +457,6 @@
eval(se_resnext, test_reader)
se_resnext.train()


if __name__ == '__main__':
train()
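The se_resnext changes are mostly reformatting, but the file exercises every piece of the dygraph data-parallel setup, so a condensed skeleton of the whole flow may help. This is a sketch under assumptions: `model_fn`, `optimizer_fn`, and `reader` stand in for the script's own `SeResNeXt`, `optimizer_setting`, and `paddle.batch` reader, and the `scale_loss`/`apply_collective_grads` calls are the `DataParallel` hooks assumed from the same Paddle 1.x API.

```python
import numpy as np
import paddle.fluid as fluid


def train_data_parallel(model_fn, optimizer_fn, reader, use_data_parallel=True):
    # Each process is pinned to the GPU reported by the parallel environment.
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if use_data_parallel else fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place):
        if use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()
        model = model_fn()
        optimizer = optimizer_fn()
        if use_data_parallel:
            model = fluid.dygraph.parallel.DataParallel(model, strategy)
            # Each worker then consumes a disjoint shard of every batch.
            reader = fluid.contrib.reader.distributed_batch_reader(reader)

        for data in reader():
            img = fluid.dygraph.to_variable(
                np.array([x[0] for x in data]).astype('float32'))
            label = fluid.dygraph.to_variable(
                np.array([x[1] for x in data]).astype('int64').reshape(-1, 1))
            softmax_out = fluid.layers.softmax(model(img), use_cudnn=False)
            avg_loss = fluid.layers.mean(
                fluid.layers.cross_entropy(input=softmax_out, label=label))
            if use_data_parallel:
                # Assumed DataParallel hooks: scale the loss by the trainer
                # count, then all-reduce the gradients before the update.
                avg_loss = model.scale_loss(avg_loss)
                avg_loss.backward()
                model.apply_collective_grads()
            else:
                avg_loss.backward()
            optimizer.minimize(avg_loss)
            model.clear_gradients()

        # Only the rank-0 worker writes the checkpoint, as in the diff above.
        if (not use_data_parallel) or fluid.dygraph.parallel.Env().local_rank == 0:
            fluid.dygraph.save_persistables(model.state_dict(), "save_dir")
```

In multi-GPU mode such a script is normally started with one process per device (for example via `paddle.distributed.launch`) and `--use_data_parallel True`; the launcher itself is outside the scope of this commit.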
3 changes: 2 additions & 1 deletion dygraph/transformer/train.py
@@ -28,7 +28,8 @@ def parse_args():
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
help="The flag indicating whether to use data parallel mode to train the model."
)
args = parser.parse_args()
return args

