fix input shape for new Embedding (PaddlePaddle#4048)
test=develop
songyouwei authored and phlrain committed Dec 10, 2019
1 parent d766869 commit 09123cd
Showing 9 changed files with 18 additions and 19 deletions.
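All nine files change for the same reason: the new dygraph Embedding takes int64 ids of shape [batch_size, seq_len] (or a flat 1-D id tensor) directly and appends the embedding dimension itself, so call sites no longer reshape or expand_dims their inputs to carry a trailing size-1 axis. A minimal sketch of the new contract, with made-up sizes; the constructor shown follows the Paddle 1.6/1.7-era dygraph API as I understand it, and some releases also expect a leading name_scope argument:

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard(fluid.CPUPlace()):
        # Hypothetical sizes: vocabulary of 100 tokens, 16-dim embeddings.
        emb = fluid.dygraph.Embedding(size=[100, 16])

        # New-style input: [batch_size, seq_len] ids, no trailing 1.
        ids = fluid.dygraph.to_variable(
            np.random.randint(0, 100, [4, 10]).astype("int64"))

        out = emb(ids)
        print(out.shape)  # expected: [4, 10, 16]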
@@ -267,11 +267,12 @@ def _create_mask(self, input_mask, append_head=False, auto_regressive=False):
         Create attention mask.
         @param : input_mask
-        @type : Variable(shape: [batch_size, max_seq_len, 1])
+        @type : Variable(shape: [batch_size, max_seq_len])
         @param : auto_regressive
         @type : bool
         """
+        input_mask = fluid.layers.unsqueeze(input=input_mask, axes=[2])
         seq_len = input_mask.shape[1]

         input_mask = layers.cast(input_mask, self._dtype)
8 changes: 4 additions & 4 deletions PaddleNLP/Research/Dialogue-PLATO/plato/modules/embedder.py
@@ -67,10 +67,10 @@ def main():
     place = fluid.CPUPlace()
     with fluid.dygraph.guard(place):
         model = Embedder("Embedder", 10, 20, 20, 20, 20)
-        token_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10, 1]).astype("int64"))
-        pos_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10, 1]).astype("int64"))
-        type_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10, 1]).astype("int64"))
-        turn_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10, 1]).astype("int64"))
+        token_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10]).astype("int64"))
+        pos_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10]).astype("int64"))
+        type_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10]).astype("int64"))
+        turn_inp = fluid.dygraph.to_variable(np.random.randint(0, 19, [10, 10]).astype("int64"))
         out = model(token_inp, pos_inp, type_inp, turn_inp)
         print(out)

1 change: 0 additions & 1 deletion PaddleNLP/Research/Dialogue-PLATO/run.py
@@ -99,7 +99,6 @@ def main():
     test_loader = DataLoader(test_dataset, hparams.Trainer, collate_fn=collate_fn, is_test=hparams.do_infer)

     def to_tensor(array):
-        array = np.expand_dims(array, -1)
         return fluid.dygraph.to_variable(array)

     if hparams.use_data_distributed:
3 changes: 1 addition & 2 deletions PaddleSpeech/DeepVoice3/deepvoice3_paddle/data.py
@@ -273,7 +273,6 @@ def create_batch(batch):

     x_batch = np.array(
         [_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
-    x_batch = np.expand_dims(x_batch, axis=-1)

     mel_batch = np.array(
         [_pad_2d(
@@ -318,7 +317,7 @@ def create_batch(batch):
     done = np.expand_dims(np.expand_dims(done, axis=1), axis=1)

     if multi_speaker:
-        speaker_ids = np.expand_dims(np.array([x[3] for x in batch]), axis=-1)
+        speaker_ids = np.array([x[3] for x in batch])
         return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
                 frame_positions, done, target_lengths, speaker_ids)
     else:
4 changes: 2 additions & 2 deletions PaddleSpeech/DeepVoice3/deepvoice3_paddle/deepvoice3.py
@@ -206,7 +206,7 @@ def forward(self, x, speaker_embed=None):
         Encode text sequence.
         Args:
-            x (Variable): Shape(B, T_enc, 1), dtype: int64. Ihe input text
+            x (Variable): Shape(B, T_enc), dtype: int64. Ihe input text
                 indices. T_enc means the timesteps of decoder input x.
             speaker_embed (Variable, optional): Shape(Batch_size, speaker_dim),
                 dtype: float32. Speaker embeddings. This arg is not None only
@@ -1228,7 +1228,7 @@ def forward(self,
                 valid lengths for each example in text_sequences.
             mel_inputs (Variable): Shape(B, C_mel, 1, T_mel), ground truth
                 mel-spectrogram, which is used as decoder inputs when training.
-            speaker_indices (Variable, optional): Shape(Batch_size, 1),
+            speaker_indices (Variable, optional): Shape(Batch_size),
                 dtype: int64. Speaker index for each example. This arg is not
                 None only when the model is a multispeaker model.
             text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
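The two docstring edits above record the same contract at the DeepVoice3 model boundary: text indices are now a plain 2-D int64 tensor and speaker indices a 1-D one. A numpy-only sketch of the shapes being described, with batch size and lengths invented for illustration:

    import numpy as np

    B, T_enc, n_speakers = 4, 16, 8
    text_sequences = np.random.randint(0, 100, [B, T_enc]).astype("int64")   # was [B, T_enc, 1]
    speaker_indices = np.random.randint(0, n_speakers, [B]).astype("int64")  # was [B, 1]
    print(text_sequences.shape, speaker_indices.shape)  # (4, 16) (4,)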
2 changes: 1 addition & 1 deletion dygraph/ocr_recognition/train.py
@@ -433,7 +433,7 @@ def forward(self, inputs, label_in):

         decoder_boot = self.fc(backward_first)

-        label_in = fluid.layers.reshape(label_in, [-1, 1], inplace=False)
+        label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
         trg_embedding = self.embedding(label_in)

         trg_embedding = fluid.layers.reshape(
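The pattern in this forward pass is flatten, embed, then restore the sequence layout: label ids go in as a flat 1-D tensor and the result is reshaped back to a [batch, seq_len, emb_dim] layout for the decoder. A numpy-only sketch of the shape bookkeeping, with invented sizes and a plain lookup table standing in for the Embedding layer:

    import numpy as np

    batch, seq_len, vocab, emb_dim = 2, 3, 10, 4
    table = np.random.rand(vocab, emb_dim).astype("float32")   # stand-in for the embedding table
    label_in = np.random.randint(0, vocab, [batch, seq_len])

    flat = label_in.reshape(-1)          # [-1] instead of the old [-1, 1]
    flat_emb = table[flat]               # shape: [batch * seq_len, emb_dim]
    trg_embedding = flat_emb.reshape(batch, seq_len, emb_dim)
    print(trg_embedding.shape)           # (2, 3, 4)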
4 changes: 2 additions & 2 deletions dygraph/ptb_lm/ptb_dy.py
@@ -360,7 +360,7 @@ def eval(model, data):
     train_data_iter = reader.get_data_iter(data, batch_size, num_steps)
     for batch_id, batch in enumerate(train_data_iter):
         x_data, y_data = batch
-        x_data = x_data.reshape((-1, num_steps, 1))
+        x_data = x_data.reshape((-1, num_steps))
         y_data = y_data.reshape((-1, 1))
         x = to_variable(x_data)
         y = to_variable(y_data)
@@ -399,7 +399,7 @@ def eval(model, data):
     start_time = time.time()
     for batch_id, batch in enumerate(train_data_iter):
         x_data, y_data = batch
-        x_data = x_data.reshape((-1, num_steps, 1))
+        x_data = x_data.reshape((-1, num_steps))
         y_data = y_data.reshape((-1, 1))
         x = to_variable(x_data)
         y = to_variable(y_data)
4 changes: 2 additions & 2 deletions dygraph/sentiment/main.py
@@ -162,7 +162,7 @@ def train():
                     'constant',
                     constant_values=(args.vocab_size))
                 for x in data
-            ]).astype('int64').reshape(-1, 1))
+            ]).astype('int64').reshape(-1))
             label = to_variable(
                 np.array([x[1] for x in data]).astype('int64').reshape(
                     args.batch_size, 1))
@@ -206,7 +206,7 @@ def train():
                 eval_label = to_variable(
                     np.array([x[1] for x in eval_data]).astype(
                         'int64').reshape(args.batch_size, 1))
-                eval_doc = to_variable(eval_np_doc.reshape(-1, 1))
+                eval_doc = to_variable(eval_np_doc.reshape(-1))
                 eval_avg_cost, eval_prediction, eval_acc = model(
                     eval_doc, eval_label)
                 eval_np_mask = (
8 changes: 4 additions & 4 deletions dygraph/sentiment/nets.py
@@ -114,7 +114,7 @@ def __init__(self, name_scope, dict_dim, batch_size, seq_len):

     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = (inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = (np.expand_dims(inputs.numpy(), -1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
@@ -155,7 +155,7 @@ def __init__(self, name_scope, dict_dim, batch_size, seq_len):

     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = (inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = (np.expand_dims(inputs.numpy(), -1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
@@ -205,7 +205,7 @@ def __init__(self, name_scope, dict_dim, batch_size, seq_len):
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
         o_np_mask = to_variable(
-            inputs.numpy() != self.dict_dim).astype('float32')
+            np.expand_dims(inputs.numpy(), -1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
@@ -258,7 +258,7 @@ def __init__(self, name_scope, dict_dim, batch_size, seq_len):
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
         o_np_mask = to_variable(
-            inputs.numpy() != self.dict_dim).astype('float32')
+            np.expand_dims(inputs.numpy(), -1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
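The np.expand_dims added in these four forward passes follows from the same shape shift: with flat 1-D token ids, inputs.numpy() != self.dict_dim is now 1-D, so it needs an explicit trailing axis before it can be tiled against the [N, hid_dim] embedding output to zero out padding rows. A numpy-only sketch of that masking step, with invented sizes and dict_dim treated as the padding id, as the != comparison suggests:

    import numpy as np

    hid_dim = 4
    pad_id = 7                       # stand-in for self.dict_dim
    ids = np.array([3, 5, 7, 7])     # flattened token ids, shape [N]

    # [N] -> [N, 1]; without this the mask could not be tiled to [N, hid_dim].
    mask = np.expand_dims((ids != pad_id).astype("float32"), -1)

    emb = np.random.rand(len(ids), hid_dim).astype("float32")  # stand-in for embedding output
    masked = emb * np.tile(mask, [1, hid_dim])                  # padding rows zeroed out
    print(masked[2:])                                           # rows for padding ids are all zeros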
