Skip to content

Commit

Permalink
update build_once for mnist (PaddlePaddle#4103)
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
songyouwei authored and phlrain committed Dec 24, 2019
1 parent 66f6039 commit 894429c
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions dygraph/mnist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable


Expand All @@ -41,7 +41,6 @@ def parse_args():

class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
Expand All @@ -58,10 +57,10 @@ def __init__(self,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
super(SimpleImgConvPool, self).__init__()

self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
Expand All @@ -74,7 +73,6 @@ def __init__(self,
use_cudnn=use_cudnn)

self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
Expand All @@ -89,20 +87,19 @@ def forward(self, inputs):


class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
def __init__(self):
super(MNIST, self).__init__()

self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
1, 20, 5, 2, 2, act="relu")

self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
20, 50, 5, 2, 2, act="relu")

pool_2_shape = 50 * 4 * 4
self.pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(self.pool_2_shape, 10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
Expand All @@ -111,6 +108,7 @@ def __init__(self, name_scope):
def forward(self, inputs, label=None):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
Expand Down Expand Up @@ -148,7 +146,7 @@ def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
mnist_infer = MNIST()
# load checkpoint
model_dict, _ = fluid.load_dygraph("save_temp")
mnist_infer.set_dict(model_dict)
Expand Down Expand Up @@ -188,7 +186,7 @@ def train_mnist(args):

if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
mnist = MNIST("mnist")
mnist = MNIST()
adam = AdamOptimizer(learning_rate=0.001)
if args.use_data_parallel:
mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
Expand Down

0 comments on commit 894429c

Please sign in to comment.