use paddle API instead of numpy API to avoid breakgraph in SOT
SigureMo committed Feb 22, 2024
1 parent ebe2a37 commit f390521
Showing 1 changed file with 48 additions and 58 deletions.
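
The whole diff follows one pattern: tensor constants that were previously built with numpy and wrapped via paddle.to_tensor are now created directly with Paddle ops (paddle.zeros, paddle.full, paddle.arange, or paddle.to_tensor on a plain Python list). Under SOT, Paddle's bytecode-level dynamic-to-static translator, a call into numpy cannot be traced and triggers a graph break, whereas native Paddle creation ops can be recorded in the graph. A minimal sketch of the pattern, with made-up sizes, not taken from the diff itself:

    import numpy as np
    import paddle

    batch_size, hidden_size = 4, 8  # illustrative values only

    # Before: the numpy call runs outside the traced program and breaks the graph under SOT.
    h0 = paddle.to_tensor(
        np.zeros((batch_size, hidden_size), dtype='float32')
    )

    # After: pure Paddle ops that SOT can capture without falling back to eager execution.
    h0 = paddle.zeros(shape=[batch_size, hidden_size], dtype='float32')
    ids = paddle.arange(0, batch_size, 1, dtype="int64")
    pad = paddle.full([batch_size, hidden_size], 0.0, dtype="float32")
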
106 changes: 48 additions & 58 deletions test/dygraph_to_static/seq2seq_dygraph_model.py
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 from seq2seq_utils import Seq2SeqModelHyperParams as args

 import paddle
-from paddle import base
-from paddle.base import ParamAttr
 from paddle.nn import Embedding, Layer

 INF = 1.0 * 1e5
@@ -111,14 +108,16 @@ def __init__(
         self.mode = mode
         self.kinf = 1e9

-        param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
-        bias_attr = ParamAttr(initializer=zero_constant)
+        param_attr = paddle.ParamAttr(
+            initializer=uniform_initializer(self.init_scale)
+        )
+        bias_attr = paddle.ParamAttr(initializer=zero_constant)
         forget_bias = 1.0

         self.src_embeder = Embedding(
             self.src_vocab_size,
             self.hidden_size,
-            weight_attr=base.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 initializer=uniform_initializer(init_scale)
             ),
         )
@@ -127,7 +126,7 @@ def __init__(
             self.tar_vocab_size,
             self.hidden_size,
             sparse=False,
-            weight_attr=base.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 initializer=uniform_initializer(init_scale)
             ),
         )
@@ -136,7 +135,7 @@ def __init__(
         for i in range(num_layers):
             self.enc_units.append(
                 self.add_sublayer(
-                    "enc_units_%d" % i,
+                    f"enc_units_{i}",
                     BasicLSTMUnit(
                         hidden_size=self.hidden_size,
                         input_size=self.hidden_size,
@@ -151,7 +150,7 @@ def __init__(
         for i in range(num_layers):
             self.dec_units.append(
                 self.add_sublayer(
-                    "dec_units_%d" % i,
+                    f"dec_units_{i}",
                     BasicLSTMUnit(
                         hidden_size=self.hidden_size,
                         input_size=self.hidden_size,
@@ -210,11 +209,11 @@ def forward(self, inputs):

         # NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully.
         # Because nested list can't be transformed now.
-        enc_hidden_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_hidden_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
-        enc_cell_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_cell_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
         zero = paddle.zeros(shape=[1], dtype="int64")
         enc_hidden = paddle.tensor.create_array(dtype="float32")
@@ -291,8 +290,8 @@ def forward(self, inputs):

         dec_output = paddle.stack(dec_output)
         dec_output = self.fc(self._transpose_batch_time(dec_output))
-        loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=dec_output, label=label, soft_label=False
+        loss = paddle.nn.functional.cross_entropy(
+            input=dec_output, label=label, soft_label=False, reduction="none"
         )
         loss = paddle.squeeze(loss, axis=[2])
         max_tar_seq_len = paddle.shape(tar)[1]
@@ -311,11 +310,11 @@ def beam_search(self, inputs):
         self.batch_size = src.shape[0]

         src_emb = self.src_embeder(self._transpose_batch_time(src))
-        enc_hidden_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_hidden_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
-        enc_cell_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_cell_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
         zero = paddle.zeros(shape=[1], dtype="int64")
         enc_hidden = paddle.tensor.create_array(dtype="float32")
@@ -366,23 +365,17 @@ def beam_search(self, inputs):

         # beam search
         batch_beam_shape = (self.batch_size, self.beam_size)
-        vocab_size_tensor = paddle.to_tensor(
-            np.full((1), self.tar_vocab_size)
-        ).astype("int64")
-        start_token_tensor = paddle.to_tensor(
-            np.full(batch_beam_shape, self.beam_start_token, dtype='int64')
+        vocab_size_tensor = paddle.full([1], self.tar_vocab_size, dtype="int64")
+        start_token_tensor = paddle.full(
+            batch_beam_shape, self.beam_start_token, dtype="int64"
         )
-        end_token_tensor = paddle.to_tensor(
-            np.full(batch_beam_shape, self.beam_end_token, dtype='int64')
+        end_token_tensor = paddle.full(
+            batch_beam_shape, self.beam_end_token, dtype="int64"
         )
         step_input = self.tar_embeder(start_token_tensor)
-        beam_finished = paddle.to_tensor(
-            np.full(batch_beam_shape, 0, dtype='float32')
-        )
+        beam_finished = paddle.full(batch_beam_shape, 0, dtype="float32")
         beam_state_log_probs = paddle.to_tensor(
-            np.array(
-                [[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
-            )
+            [[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
         )
         beam_state_log_probs = paddle.expand(
             beam_state_log_probs,
@@ -394,10 +387,7 @@

         batch_pos = paddle.expand(
             paddle.unsqueeze(
-                paddle.to_tensor(
-                    np.arange(0, self.batch_size, 1, dtype="int64")
-                ),
-                [1],
+                paddle.arange(0, self.batch_size, 1, dtype="int64"), [1]
             ),
             [-1, self.beam_size],
         )
@@ -438,9 +428,7 @@ def beam_search(self, inputs):
             )
             noend_array = [-self.kinf] * self.tar_vocab_size
             noend_array[self.beam_end_token] = 0
-            noend_mask_tensor = paddle.to_tensor(
-                np.array(noend_array, dtype='float32')
-            )
+            noend_mask_tensor = paddle.to_tensor(noend_array, dtype="float32")

             step_log_probs = paddle.multiply(
                 paddle.expand(
@@ -538,14 +526,16 @@ def __init__(
         self.mode = mode
         self.kinf = 1e9

-        param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
-        bias_attr = ParamAttr(initializer=zero_constant)
+        param_attr = paddle.ParamAttr(
+            initializer=uniform_initializer(self.init_scale)
+        )
+        bias_attr = paddle.ParamAttr(initializer=zero_constant)
         forget_bias = 1.0

         self.src_embeder = Embedding(
             self.src_vocab_size,
             self.hidden_size,
-            weight_attr=base.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 name='source_embedding',
                 initializer=uniform_initializer(init_scale),
             ),
@@ -555,7 +545,7 @@ def __init__(
             self.tar_vocab_size,
             self.hidden_size,
             sparse=False,
-            weight_attr=base.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 name='target_embedding',
                 initializer=uniform_initializer(init_scale),
             ),
@@ -565,7 +555,7 @@ def __init__(
         for i in range(num_layers):
             self.enc_units.append(
                 self.add_sublayer(
-                    "enc_units_%d" % i,
+                    f"enc_units_{i}",
                     BasicLSTMUnit(
                         hidden_size=self.hidden_size,
                         input_size=self.hidden_size,
@@ -581,12 +571,12 @@
             if i == 0:
                 self.dec_units.append(
                     self.add_sublayer(
-                        "dec_units_%d" % i,
+                        f"dec_units_{i}",
                         BasicLSTMUnit(
                             hidden_size=self.hidden_size,
                             input_size=self.hidden_size * 2,
-                            param_attr=ParamAttr(
-                                name="dec_units_%d" % i,
+                            param_attr=paddle.ParamAttr(
+                                name=f"dec_units_{i}",
                                 initializer=uniform_initializer(
                                     self.init_scale
                                 ),
@@ -599,12 +589,12 @@
             else:
                 self.dec_units.append(
                     self.add_sublayer(
-                        "dec_units_%d" % i,
+                        f"dec_units_{i}",
                         BasicLSTMUnit(
                             hidden_size=self.hidden_size,
                             input_size=self.hidden_size,
-                            param_attr=ParamAttr(
-                                name="dec_units_%d" % i,
+                            param_attr=paddle.ParamAttr(
+                                name=f"dec_units_{i}",
                                 initializer=uniform_initializer(
                                     self.init_scale
                                 ),
@@ -727,12 +717,12 @@ def forward(self, inputs):

         # NOTE: modify model code about `enc_hidden` and `enc_cell` to transform dygraph code successfully.
         # Because nested list can't be transformed now.
-        enc_hidden_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_hidden_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
         enc_hidden_0.stop_gradient = True
-        enc_cell_0 = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        enc_cell_0 = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
         enc_hidden_0.stop_gradient = True
         zero = paddle.zeros(shape=[1], dtype="int64")
@@ -790,8 +780,8 @@ def forward(self, inputs):
         enc_outputs = self._transpose_batch_time(enc_outputs)

         # train
-        input_feed = paddle.to_tensor(
-            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
+        input_feed = paddle.zeros(
+            shape=[self.batch_size, self.hidden_size], dtype='float32'
         )
         # NOTE: set stop_gradient here, otherwise grad var is null
         input_feed.stop_gradient = True
@@ -829,8 +819,8 @@ def forward(self, inputs):

         dec_output = paddle.stack(dec_output)
         dec_output = self.fc(self._transpose_batch_time(dec_output))
-        loss = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=dec_output, label=label, soft_label=False
+        loss = paddle.nn.functional.cross_entropy(
+            input=dec_output, label=label, soft_label=False, reduction="none"
         )
         loss = paddle.squeeze(loss, axis=[2])
         max_tar_seq_len = paddle.shape(tar)[1]
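
One change in the diff is not about numpy: the loss switches from paddle.nn.functional.softmax_with_cross_entropy to paddle.nn.functional.cross_entropy. With soft_label=False and reduction="none", cross_entropy applies the softmax internally and returns the unreduced per-position loss, so the downstream squeeze and sequence masking stay unchanged. A small self-contained sketch of the substitution; the shapes are arbitrary and only for illustration:

    import paddle
    import paddle.nn.functional as F

    batch, seq_len, vocab = 2, 5, 7  # illustrative values only
    logits = paddle.randn([batch, seq_len, vocab])
    label = paddle.randint(0, vocab, shape=[batch, seq_len, 1])

    # Unreduced softmax cross entropy: one loss value per (batch, position).
    loss = F.cross_entropy(
        input=logits, label=label, soft_label=False, reduction="none"
    )
    loss = paddle.squeeze(loss, axis=[2])  # [batch, seq_len], as in the model code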