
Commit

[Cleanup][B-2] clean some paddle.base.dygraph.to_variable for test (#61904)


---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
PommesPeter and SigureMo authored Feb 23, 2024
1 parent d7c5cf5 commit d316845
Showing 11 changed files with 88 additions and 122 deletions.
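
The recurring pattern across these files replaces paddle.base.dygraph.to_variable wrapped around a numpy array with the public tensor-creation APIs. A minimal sketch of the substitution, assuming an eager-mode (dygraph) Paddle 2.x session; the sizes and values below are illustrative only:

import numpy as np
import paddle

batch_size, hidden_size = 4, 8

# Old style (removed by this commit): build a numpy array, then wrap it.
# hidden = paddle.base.dygraph.to_variable(
#     np.zeros((batch_size, hidden_size), dtype='float32')
# )

# New style: create the tensor directly with the public API.
hidden = paddle.zeros(shape=[batch_size, hidden_size], dtype='float32')
start_tokens = paddle.full([batch_size, 2], 1, dtype='int64')
log_probs = paddle.to_tensor([[0.0, -1e9]], dtype='float32')
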
107 changes: 49 additions & 58 deletions test/dygraph_to_static/seq2seq_dygraph_model.py
@@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from seq2seq_utils import Seq2SeqModelHyperParams as args

import paddle
from paddle import base
from paddle.base import ParamAttr
from paddle.base.dygraph.base import to_variable
from paddle.nn import Embedding, Layer

INF = 1.0 * 1e5
@@ -112,14 +108,16 @@ def __init__(
self.mode = mode
self.kinf = 1e9

param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
param_attr = paddle.ParamAttr(
initializer=uniform_initializer(self.init_scale)
)
bias_attr = paddle.ParamAttr(initializer=zero_constant)
forget_bias = 1.0

self.src_embeder = Embedding(
self.src_vocab_size,
self.hidden_size,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
initializer=uniform_initializer(init_scale)
),
)
@@ -128,7 +126,7 @@ def __init__(
self.tar_vocab_size,
self.hidden_size,
sparse=False,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
initializer=uniform_initializer(init_scale)
),
)
@@ -137,7 +135,7 @@ def __init__(
for i in range(num_layers):
self.enc_units.append(
self.add_sublayer(
"enc_units_%d" % i,
f"enc_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
@@ -152,7 +150,7 @@ def __init__(
for i in range(num_layers):
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
@@ -211,11 +209,11 @@ def forward(self, inputs):

# NOTE: modify model code about `enc_hidden` and `enc_cell` to transform dygraph code successfully.
# Because nested list can't be transformed now.
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
zero = paddle.zeros(shape=[1], dtype="int64")
enc_hidden = paddle.tensor.create_array(dtype="float32")
@@ -292,8 +290,8 @@ def forward(self, inputs):

dec_output = paddle.stack(dec_output)
dec_output = self.fc(self._transpose_batch_time(dec_output))
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
loss = paddle.nn.functional.cross_entropy(
input=dec_output, label=label, soft_label=False, reduction="none"
)
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
@@ -312,11 +310,11 @@ def beam_search(self, inputs):
self.batch_size = src.shape[0]

src_emb = self.src_embeder(self._transpose_batch_time(src))
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
zero = paddle.zeros(shape=[1], dtype="int64")
enc_hidden = paddle.tensor.create_array(dtype="float32")
@@ -367,23 +365,17 @@ def beam_search(self, inputs):

# beam search
batch_beam_shape = (self.batch_size, self.beam_size)
vocab_size_tensor = to_variable(
np.full((1), self.tar_vocab_size)
).astype("int64")
start_token_tensor = to_variable(
np.full(batch_beam_shape, self.beam_start_token, dtype='int64')
vocab_size_tensor = paddle.full([1], self.tar_vocab_size, dtype="int64")
start_token_tensor = paddle.full(
batch_beam_shape, self.beam_start_token, dtype="int64"
)
end_token_tensor = to_variable(
np.full(batch_beam_shape, self.beam_end_token, dtype='int64')
end_token_tensor = paddle.full(
batch_beam_shape, self.beam_end_token, dtype="int64"
)
step_input = self.tar_embeder(start_token_tensor)
beam_finished = to_variable(
np.full(batch_beam_shape, 0, dtype='float32')
)
beam_state_log_probs = to_variable(
np.array(
[[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
)
beam_finished = paddle.full(batch_beam_shape, 0, dtype="float32")
beam_state_log_probs = paddle.to_tensor(
[[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
)
beam_state_log_probs = paddle.expand(
beam_state_log_probs,
@@ -395,8 +387,7 @@ def beam_search(self, inputs):

batch_pos = paddle.expand(
paddle.unsqueeze(
to_variable(np.arange(0, self.batch_size, 1, dtype="int64")),
[1],
paddle.arange(0, self.batch_size, 1, dtype="int64"), [1]
),
[-1, self.beam_size],
)
@@ -437,9 +428,7 @@ def beam_search(self, inputs):
)
noend_array = [-self.kinf] * self.tar_vocab_size
noend_array[self.beam_end_token] = 0
noend_mask_tensor = to_variable(
np.array(noend_array, dtype='float32')
)
noend_mask_tensor = paddle.to_tensor(noend_array, dtype="float32")

step_log_probs = paddle.multiply(
paddle.expand(
@@ -537,14 +526,16 @@ def __init__(
self.mode = mode
self.kinf = 1e9

param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
param_attr = paddle.ParamAttr(
initializer=uniform_initializer(self.init_scale)
)
bias_attr = paddle.ParamAttr(initializer=zero_constant)
forget_bias = 1.0

self.src_embeder = Embedding(
self.src_vocab_size,
self.hidden_size,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
name='source_embedding',
initializer=uniform_initializer(init_scale),
),
@@ -554,7 +545,7 @@ def __init__(
self.tar_vocab_size,
self.hidden_size,
sparse=False,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
name='target_embedding',
initializer=uniform_initializer(init_scale),
),
@@ -564,7 +555,7 @@ def __init__(
for i in range(num_layers):
self.enc_units.append(
self.add_sublayer(
"enc_units_%d" % i,
f"enc_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
@@ -580,12 +571,12 @@ def __init__(
if i == 0:
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size * 2,
param_attr=ParamAttr(
name="dec_units_%d" % i,
param_attr=paddle.ParamAttr(
name=f"dec_units_{i}",
initializer=uniform_initializer(
self.init_scale
),
@@ -598,12 +589,12 @@ def __init__(
else:
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
param_attr=ParamAttr(
name="dec_units_%d" % i,
param_attr=paddle.ParamAttr(
name=f"dec_units_{i}",
initializer=uniform_initializer(
self.init_scale
),
@@ -726,12 +717,12 @@ def forward(self, inputs):

# NOTE: modify model code about `enc_hidden` and `enc_cell` to transform dygraph code successfully.
# Because nested list can't be transformed now.
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_hidden_0.stop_gradient = True
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_hidden_0.stop_gradient = True
zero = paddle.zeros(shape=[1], dtype="int64")
@@ -789,8 +780,8 @@ def forward(self, inputs):
enc_outputs = self._transpose_batch_time(enc_outputs)

# train
input_feed = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
input_feed = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
# NOTE: set stop_gradient here, otherwise grad var is null
input_feed.stop_gradient = True
@@ -828,8 +819,8 @@ def forward(self, inputs):

dec_output = paddle.stack(dec_output)
dec_output = self.fc(self._transpose_batch_time(dec_output))
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
loss = paddle.nn.functional.cross_entropy(
input=dec_output, label=label, soft_label=False, reduction="none"
)
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
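
Besides the tensor-creation cleanup, seq2seq_dygraph_model.py above also switches ParamAttr to the top-level paddle namespace and replaces the deprecated softmax_with_cross_entropy call with paddle.nn.functional.cross_entropy using reduction="none", which keeps the loss unreduced like the old API did. A rough sketch of the new calls; the vocabulary size, shapes, and the Uniform initializer (a stand-in for the file's uniform_initializer helper) are illustrative assumptions:

import paddle
import paddle.nn.functional as F

# ParamAttr now comes straight from the paddle namespace.
weight_attr = paddle.ParamAttr(
    initializer=paddle.nn.initializer.Uniform(low=-0.1, high=0.1)
)
emb = paddle.nn.Embedding(100, 16, weight_attr=weight_attr)

# Per-element cross entropy; reduction="none" mirrors the unreduced output of
# the old softmax_with_cross_entropy, and the softmax is still applied inside.
logits = paddle.randn([4, 7, 100])               # [batch, seq_len, vocab]
label = paddle.randint(0, 100, shape=[4, 7, 1])  # trailing axis of size 1
loss = F.cross_entropy(
    input=logits, label=label, soft_label=False, reduction="none"
)
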
21 changes: 11 additions & 10 deletions test/dygraph_to_static/simnet_dygraph_model.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -298,15 +298,16 @@ class FC(paddle.nn.Layer):
Examples:
.. code-block:: python
from paddle.base.dygraph.base import to_variable
import paddle.base as base
from paddle.base.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with base.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
conv = fc(data)
import paddle
import paddle.base as base
from paddle.base.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with base.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = paddle.to_tensor(data)
conv = fc(data)
"""

def __init__(
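
The FC docstring example above keeps the legacy base.dygraph.guard context but now feeds the layer a paddle.to_tensor input instead of to_variable. A stand-in sketch of the same pattern with paddle.nn.Linear (not the FC helper defined in this file), which is enough to show the input-conversion change:

import numpy as np
import paddle

data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
fc = paddle.nn.Linear(32, 64)       # stand-in for FC("fc", 64, num_flatten_dims=2)
out = fc(paddle.to_tensor(data))    # previously: fc(to_variable(data))
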
27 changes: 1 addition & 26 deletions test/dygraph_to_static/test_basic_api_transformation.py
@@ -21,30 +21,13 @@
)

import paddle
from paddle import base, to_tensor
from paddle.base import dygraph
from paddle.base.dygraph import to_variable
from paddle import to_tensor
from paddle.jit.api import to_static

SEED = 2020
np.random.seed(SEED)


def dyfunc_to_variable(x):
res = base.dygraph.to_variable(x, name=None, zero_copy=None)
return res


def dyfunc_to_variable_2(x):
res = dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
return res


def dyfunc_to_variable_3(x):
res = to_variable(x, name=None, zero_copy=None)
return res


def dyfunc_to_tensor(x):
res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
res2 = paddle.tensor.to_tensor(data=res1)
@@ -73,15 +56,7 @@ def setUp(self):
dyfunc_bool_to_tensor,
dyfunc_int_to_tensor,
dyfunc_float_to_tensor,
dyfunc_to_variable,
dyfunc_to_variable_2,
dyfunc_to_variable_3,
]
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def get_dygraph_output(self):
res = self.dygraph_func(self.input).numpy()
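
With the dyfunc_to_variable helpers deleted, the remaining helpers in this test create tensors through paddle.to_tensor only. A rough sketch of how one of them is exercised under dynamic-to-static conversion; the input value is illustrative:

import paddle

def dyfunc_to_tensor(x):
    res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
    res2 = paddle.tensor.to_tensor(data=res1)
    return res2

static_fn = paddle.jit.to_static(dyfunc_to_tensor)
out = static_fn(paddle.ones([5], dtype='float32'))
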
15 changes: 6 additions & 9 deletions test/dygraph_to_static/test_declarative.py
@@ -19,11 +19,9 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
)
from test_basic_api_transformation import dyfunc_to_variable

import paddle
from paddle.framework import use_pir_api
@@ -35,6 +33,11 @@
from paddle.static import InputSpec


def call_to_tensor(x):
res = paddle.to_tensor(x)
return res


def create_simple_net():
class SimpleNet(Layer):
def __init__(self):
@@ -370,7 +373,7 @@ class TestDeclarativeAPI(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt_and_pir
def test_error(self):
func = paddle.jit.to_static(dyfunc_to_variable)
func = paddle.jit.to_static(call_to_tensor)

paddle.enable_static()

@@ -379,12 +382,6 @@ def test_error(self):
with self.assertRaises(RuntimeError):
func(np.ones(5).astype("int32"))

with enable_to_static_guard(False):
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use base.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))

paddle.disable_static()


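
test_error now drives the error path through the new call_to_tensor wrapper instead of the deleted dyfunc_to_variable import. A condensed, standalone sketch of what the test checks, assuming that calling a to_static-converted function while static mode is enabled raises RuntimeError:

import numpy as np
import paddle

def call_to_tensor(x):
    return paddle.to_tensor(x)

func = paddle.jit.to_static(call_to_tensor)

paddle.enable_static()
try:
    func(np.ones(5).astype("int32"))  # expected to raise RuntimeError
except RuntimeError:
    pass
finally:
    paddle.disable_static()
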