Add beam search for dygraph Transformer. (PaddlePaddle#3555)
* Add beam search for dygraph Transformer.
Re-organize dygraph Transformer.

* Add customized data support for dygraph Transformer.

* Add validation in dygraph Transformer.

* Update notes for multi-gpu for dygraph Transformer.
guoshengCS authored Oct 15, 2019
1 parent 627e60c commit 2bbe37a
Showing 6 changed files with 1,598 additions and 1,078 deletions.
121 changes: 113 additions & 8 deletions dygraph/transformer/README.md
@@ -8,31 +8,39 @@

### Dataset description

We use the dataset of the [translation task](http://www.statmt.org/wmt16/multimodal-task.html#task1) from the [multimodal task](http://www.statmt.org/wmt16/multimodal-task.html) newly added to [WMT-16](http://www.statmt.org/wmt16/) as the example. It is an English-German translation dataset containing 29001 training pairs and 1000 test pairs.


The dataset is built into Paddle and can be used via `paddle.dataset.wmt16`; when the training code in this project is run, the dataset is automatically downloaded to the `~/.cache/paddle/dataset/wmt16/` directory.
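
A quick way to peek at one sample of the built-in dataset (a small sketch; the vocabulary sizes are the defaults used in `config.py`):

```python
import paddle.dataset.wmt16 as wmt16

# wmt16.train(src_dict_size, trg_dict_size) returns a reader creator; the
# first iteration triggers the download to ~/.cache/paddle/dataset/wmt16/.
train_reader = wmt16.train(10000, 10000)
src_ids, trg_ids, trg_ids_next = next(train_reader())
print(len(src_ids), len(trg_ids), len(trg_ids_next))
```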

### Installation

1. Install PaddlePaddle

    This project depends on PaddlePaddle Fluid 1.6.0 or later (1.6.0 will be officially released soon; the develop branch can be used in the meantime). Please refer to the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.

2. Environment dependencies

    Multi-GPU training requires NCCL 2.4.7.

### Run training
To train on a single GPU, launch training as follows:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
```

Here `CUDA_VISIBLE_DEVICES=0` means running on GPU card 0; change it according to your own setup. Other model and training parameters can be modified in `config.py` or passed in as follows:

```sh
python train.py \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
prepostprocess_dropout 0.3
```
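
Each pair above is a configuration name followed by its value; these pairs are merged into the classes defined in `config.py` by `merge_cfg_from_list`. A minimal sketch of that flow (whether `train.py` consumes `sys.argv` in exactly this way is an assumption):

```python
import sys

from config import ModelHyperParams, TrainTaskConfig, merge_cfg_from_list

# e.g. sys.argv[1:] == ["n_head", "16", "d_model", "1024"]
merge_cfg_from_list(sys.argv[1:], [TrainTaskConfig, ModelHyperParams])
print(ModelHyperParams.n_head, ModelHyperParams.d_model)
```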

Paddle dygraph supports multi-process multi-GPU model training; launch it as follows:
```
python -m paddle.distributed.launch --started_port 9999 --selected_gpus=0,1,2,3 --log_dir ./mylog train.py --use_data_parallel 1
```
At this point, the program writes each process's output log to the `./mylog` directory. The training log looks like the following:
```
pass num : 0, batch_id: 110, dy_graph avg loss: [7.4179897]
pass num : 0, batch_id: 120, dy_graph avg loss: [7.318419]
```
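
The `--use_data_parallel 1` switch turns on dygraph data parallelism inside `train.py`. A minimal, self-contained sketch of the typical Paddle 1.6 dygraph data-parallel wiring (this is not this repository's actual training loop; the `Linear` layer and random input are placeholders, and it is meant to be run under `paddle.distributed.launch` as above):

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
    strategy = fluid.dygraph.parallel.prepare_context()
    model = fluid.dygraph.Linear(8, 1)  # placeholder model
    model = fluid.dygraph.parallel.DataParallel(model, strategy)

    x = fluid.dygraph.to_variable(np.random.rand(4, 8).astype("float32"))
    loss = fluid.layers.reduce_mean(model(x))
    loss = model.scale_loss(loss)   # scale the loss for gradient averaging
    loss.backward()
    model.apply_collective_grads()  # all-reduce gradients across processes
```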


### Run prediction

After training completes, run prediction with the following command:

```
env CUDA_VISIBLE_DEVICES=0 python predict.py
```

The prediction results will be written to the `predict.txt` file (this can be changed at run time via `--output_file`); other model and prediction parameters can also be modified in `config.py` or passed in as follows:

```sh
python predict.py \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
prepostprocess_dropout 0.3
```

After prediction, the BLEU metric can be evaluated with third-party tools, for example as follows:

```sh
# extract the reference data
tar -zxvf ~/.cache/paddle/dataset/wmt16/wmt16.tar.gz
awk 'BEGIN {FS="\t"}; {print $2}' wmt16/test > ref.de

# clone the mosesdecoder repository
git clone https://github.com/moses-smt/mosesdecoder.git

# evaluate
perl mosesdecoder/scripts/generic/multi-bleu.perl ref.de < predict.txt
```

A model trained for 20 epochs on a single GPU with the default configuration achieves roughly the following result:
```
BLEU = 32.38, 64.3/39.1/25.9/16.9 (BP=1.000, ratio=1.001, hyp_len=12122, ref_len=12104)
```


## Advanced usage

### Custom data

- Training:

Modify the following code snippet in `train.py`:

```python
reader = paddle.batch(wmt16.train(ModelHyperParams.src_vocab_size,
                                  ModelHyperParams.trg_vocab_size),
                      batch_size=TrainTaskConfig.batch_size)
```


Replace `wmt16.train` in it with a Python generator like the following:

```python
def reader(file_name, src_dict, trg_dict):
    start_id = src_dict[START_MARK]  # BOS
    end_id = src_dict[END_MARK]  # EOS
    unk_id = src_dict[UNK_MARK]  # UNK

    src_col, trg_col = 0, 1

    for line in open(file_name, "r"):
        line_split = line.strip().split("\t")
        if len(line_split) != 2:
            continue
        src_words = line_split[src_col].split()
        src_ids = [start_id] + [
            src_dict.get(w, unk_id) for w in src_words
        ] + [end_id]

        trg_words = line_split[trg_col].split()
        trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]

        trg_ids_next = trg_ids + [end_id]
        trg_ids = [start_id] + trg_ids

        yield src_ids, trg_ids, trg_ids_next
```

Each item produced by this generator is a single sample: a tuple of three integer lists, the source sentence (src_ids), the target sentence (trg_ids), and the label (trg_ids_next); src_ids includes the BOS and EOS ids, trg_ids includes the BOS id, and trg_ids_next includes the EOS id. (See the sketch after this list for how to plug such a generator in.)

- Prediction:

Modify the following code snippet in `predict.py`:

```python
reader = paddle.batch(wmt16.test(ModelHyperParams.src_vocab_size,
                                 ModelHyperParams.trg_vocab_size),
                      batch_size=InferTaskConfig.batch_size)
id2word = wmt16.get_dict("de",
                         ModelHyperParams.trg_vocab_size,
                         reverse=True)
```

Replace `wmt16.test` in it with a Python generator similar to the one used for training; in addition, a Python dict mapping ids to words must be provided as `id2word`.
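
A minimal sketch of wiring a custom generator and vocabularies into both `train.py` and `predict.py` (the corpus files `train.tsv`/`test.tsv`, the vocabulary files `vocab.en`/`vocab.de`, and the `load_dict` helper are hypothetical; `reader` is the generator defined above):

```python
import functools

import paddle

from config import InferTaskConfig, TrainTaskConfig


def load_dict(dict_path):
    # Hypothetical vocabulary format: one token per line, line number = id.
    with open(dict_path, "r") as f:
        return {line.strip(): idx for idx, line in enumerate(f)}


src_dict = load_dict("vocab.en")  # must contain the BOS/EOS/UNK tokens
trg_dict = load_dict("vocab.de")
id2word = {idx: word for word, idx in trg_dict.items()}  # used by predict.py

# Drop-in replacements for the wmt16.train / wmt16.test readers above.
train_reader = paddle.batch(
    functools.partial(reader, "train.tsv", src_dict, trg_dict),
    batch_size=TrainTaskConfig.batch_size)
test_reader = paddle.batch(
    functools.partial(reader, "test.tsv", src_dict, trg_dict),
    batch_size=InferTaskConfig.batch_size)
```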

### Model overview

Transformer is a new network architecture proposed in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) for sequence-to-sequence (Seq2Seq) learning tasks such as machine translation (MT). It also adopts the encoder-decoder framework typical of Seq2Seq tasks, but unlike the previously widely used recurrent neural networks (RNN), it relies entirely on the attention mechanism for sequence-to-sequence modeling; the overall network structure is shown in Figure 1.
206 changes: 206 additions & 0 deletions dygraph/transformer/config.py
@@ -0,0 +1,206 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class TrainTaskConfig(object):
    """
    TrainTaskConfig
    """
    # the epoch number to train.
    pass_num = 20
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the LearningRateScheduler-
    # derived learning rate to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1


class InferTaskConfig(object):
    # the number of examples in one run for sequence generation.
    batch_size = 4
    # the parameters for beam search.
    beam_size = 4
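    # alpha is the exponent of the length penalty used to normalize beam
    # scores; a GNMT-style penalty ((5 + length) / 6) ** alpha is the common
    # choice (whether this implementation uses exactly that form is an
    # assumption).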
    alpha = 0.6
    # max decoded length, should be less than ModelHyperParams.max_length
    max_out_len = 30



class ModelHyperParams(object):
    """
    ModelHyperParams
    """
    # The following five vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2

    # max length of sequences, deciding the size of the position encoding table.
    max_length = 50
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be same for weight sharing.
    weight_sharing = False


# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias":
    [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
    # The actual data shape of trg_word is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_word": [(batch_size, seq_len, 1), "int64",
                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias":
    [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias":
    [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
    # This input is used in independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # This input is used in beam-search decoder.
    "init_score": [(batch_size, 1), "float32", 2],
    # This input is used in beam-search decoder for the first gather
    # (update of cell states).
    "init_idx": [(batch_size, ), "int32"],
}

# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table",
)
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table",
)
# separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
)
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output",
)
label_data_input_fields = (
    "lbl_word",
    "lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    # "init_score",
    # "init_idx",
    "trg_src_attn_bias",
)


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)
                except Exception:  # for file path
                    pass
                setattr(g_cfg, key, value)
                break