
Commit

Use multi process reader for dygraph (PaddlePaddle#2416)
* add multi process reader

* use distributed_batch_reader
chengduo authored Jun 20, 2019
1 parent 55138a4 commit dbc27b8
Showing 3 changed files with 13 additions and 24 deletions.
12 changes: 4 additions & 8 deletions dygraph/mnist/mnist_dygraph.py
@@ -184,15 +184,11 @@ def train_mnist(args):
         if args.use_data_parallel:
             mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
 
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
         if args.use_data_parallel:
-            train_reader = fluid.contrib.reader.distributed_sampler(
-                paddle.dataset.mnist.train(),
-                batch_size=BATCH_SIZE * trainer_count)
-        else:
-            train_reader = paddle.batch(
-                paddle.dataset.mnist.train(),
-                batch_size=BATCH_SIZE,
-                drop_last=True)
+            train_reader = fluid.contrib.reader.distributed_batch_reader(
+                train_reader)
 
         test_reader = paddle.batch(
             paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)
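
Note (not part of the commit): a minimal sketch of the reader pattern this hunk introduces, assuming a standalone script where BATCH_SIZE = 64 and a use_data_parallel flag stand in for the example's own constants and args.use_data_parallel. The ordinary batched reader is built first and only then wrapped.

import paddle
import paddle.fluid as fluid

BATCH_SIZE = 64              # assumed value; the example defines its own
use_data_parallel = True     # stands in for args.use_data_parallel

# Build the ordinary batched reader first.
train_reader = paddle.batch(
    paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)

if use_data_parallel:
    # Wrap it so each trainer process only yields the batches assigned to it;
    # the underlying reader is unchanged, its output is sharded per trainer.
    train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)

for batch_id, data in enumerate(train_reader()):
    # data is a list of (image, label) samples of length BATCH_SIZE
    if batch_id == 0:
        print(len(data))

The removed distributed_sampler path asked for BATCH_SIZE * trainer_count samples per step; distributed_batch_reader instead shards already-formed batches across trainers, which is why the trainer_count factor disappears from all three files.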
11 changes: 4 additions & 7 deletions dygraph/resnet/train.py
@@ -282,14 +282,11 @@ def train_resnet():
         if args.use_data_parallel:
             resnet = fluid.dygraph.parallel.DataParallel(resnet, strategy)
 
+        train_reader = paddle.batch(
+            paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
         if args.use_data_parallel:
-            train_reader = fluid.contrib.reader.distributed_sampler(
-                paddle.dataset.flowers.train(use_xmap=False),
-                batch_size=batch_size * trainer_count)
-        else:
-            train_reader = paddle.batch(
-                paddle.dataset.flowers.train(use_xmap=False),
-                batch_size=batch_size)
+            train_reader = fluid.contrib.reader.distributed_batch_reader(
+                train_reader)
 
         test_reader = paddle.batch(
             paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
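
For context, a hedged sketch of how such a wrapped reader would feed a dygraph data-parallel training step. The scale_loss / apply_collective_grads helpers, the generic data layout, and the model's interface are assumptions based on the fluid.dygraph.parallel.DataParallel API of this period, not part of this diff.

import numpy as np
import paddle.fluid as fluid

def run_one_epoch(model, optimizer, train_reader, use_data_parallel):
    # model is assumed to return class probabilities and to be wrapped in
    # fluid.dygraph.parallel.DataParallel when use_data_parallel is True,
    # as in the files changed by this commit.
    for batch_id, data in enumerate(train_reader()):
        x = np.array([item[0] for item in data]).astype('float32')
        y = np.array([item[1] for item in data]).astype('int64').reshape(-1, 1)
        img = fluid.dygraph.to_variable(x)
        label = fluid.dygraph.to_variable(y)
        label.stop_gradient = True

        prediction = model(img)
        loss = fluid.layers.cross_entropy(prediction, label)
        avg_loss = fluid.layers.mean(loss)

        if use_data_parallel:
            # Assumed DataParallel helpers of this era: rescale the loss for
            # the number of trainers, then all-reduce gradients after backward.
            avg_loss = model.scale_loss(avg_loss)
            avg_loss.backward()
            model.apply_collective_grads()
        else:
            avg_loss.backward()

        optimizer.minimize(avg_loss)
        model.clear_gradients()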
14 changes: 5 additions & 9 deletions dygraph/transformer/train.py
@@ -1119,16 +1119,12 @@ def train():
             transformer = fluid.dygraph.parallel.DataParallel(transformer,
                                                               strategy)
 
+        reader = paddle.batch(
+            wmt16.train(ModelHyperParams.src_vocab_size,
+                        ModelHyperParams.trg_vocab_size),
+            batch_size=TrainTaskConfig.batch_size)
         if args.use_data_parallel:
-            reader = fluid.contrib.reader.distributed_sampler(
-                wmt16.train(ModelHyperParams.src_vocab_size,
-                            ModelHyperParams.trg_vocab_size),
-                batch_size=TrainTaskConfig.batch_size * trainer_count)
-        else:
-            reader = paddle.batch(
-                wmt16.train(ModelHyperParams.src_vocab_size,
-                            ModelHyperParams.trg_vocab_size),
-                batch_size=TrainTaskConfig.batch_size)
+            reader = fluid.contrib.reader.distributed_batch_reader(reader)
 
         for i in range(200):
             dy_step = 0
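
Finally, a hedged sketch of the per-process setup that pairs with these readers when the script is launched once per device (typically via python -m paddle.distributed.launch; the exact launcher flags vary by Paddle version). Env(), prepare_context(), and the dev_id attribute are assumptions from the fluid.dygraph.parallel API of this period and do not appear in this diff.

import paddle.fluid as fluid

use_data_parallel = True     # stands in for args.use_data_parallel

# Each launched process picks its own GPU from the environment set up by the
# launcher; a single-process run just uses device 0.
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
    if use_data_parallel else fluid.CUDAPlace(0)

with fluid.dygraph.guard(place):
    if use_data_parallel:
        # Initializes the communication context and returns the strategy that
        # DataParallel(model, strategy) consumes in the files above.
        strategy = fluid.dygraph.parallel.prepare_context()
    # Model construction, DataParallel wrapping, and the reader setup from the
    # hunks above would follow here.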