Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cleanup][B-2] clean some paddle.base.dygraph.to_variable for test #61904

Merged
merged 11 commits into from
Feb 23, 2024
107 changes: 49 additions & 58 deletions test/dygraph_to_static/seq2seq_dygraph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from seq2seq_utils import Seq2SeqModelHyperParams as args

import paddle
from paddle import base
from paddle.base import ParamAttr
from paddle.base.dygraph.base import to_variable
from paddle.nn import Embedding, Layer

INF = 1.0 * 1e5
Expand Down Expand Up @@ -112,14 +108,16 @@ def __init__(
self.mode = mode
self.kinf = 1e9

param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
param_attr = paddle.ParamAttr(
initializer=uniform_initializer(self.init_scale)
)
bias_attr = paddle.ParamAttr(initializer=zero_constant)
forget_bias = 1.0

self.src_embeder = Embedding(
self.src_vocab_size,
self.hidden_size,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
initializer=uniform_initializer(init_scale)
),
)
Expand All @@ -128,7 +126,7 @@ def __init__(
self.tar_vocab_size,
self.hidden_size,
sparse=False,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
initializer=uniform_initializer(init_scale)
),
)
Expand All @@ -137,7 +135,7 @@ def __init__(
for i in range(num_layers):
self.enc_units.append(
self.add_sublayer(
"enc_units_%d" % i,
f"enc_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
Expand All @@ -152,7 +150,7 @@ def __init__(
for i in range(num_layers):
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
Expand Down Expand Up @@ -211,11 +209,11 @@ def forward(self, inputs):

# NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully.
# Because nested list can't be transformed now.
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
zero = paddle.zeros(shape=[1], dtype="int64")
enc_hidden = paddle.tensor.create_array(dtype="float32")
Expand Down Expand Up @@ -292,8 +290,8 @@ def forward(self, inputs):

dec_output = paddle.stack(dec_output)
dec_output = self.fc(self._transpose_batch_time(dec_output))
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
loss = paddle.nn.functional.cross_entropy(
input=dec_output, label=label, soft_label=False, reduction="none"
)
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
Expand All @@ -312,11 +310,11 @@ def beam_search(self, inputs):
self.batch_size = src.shape[0]

src_emb = self.src_embeder(self._transpose_batch_time(src))
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
zero = paddle.zeros(shape=[1], dtype="int64")
enc_hidden = paddle.tensor.create_array(dtype="float32")
Expand Down Expand Up @@ -367,23 +365,17 @@ def beam_search(self, inputs):

# beam search
batch_beam_shape = (self.batch_size, self.beam_size)
vocab_size_tensor = to_variable(
np.full((1), self.tar_vocab_size)
).astype("int64")
start_token_tensor = to_variable(
np.full(batch_beam_shape, self.beam_start_token, dtype='int64')
vocab_size_tensor = paddle.full([1], self.tar_vocab_size, dtype="int64")
start_token_tensor = paddle.full(
batch_beam_shape, self.beam_start_token, dtype="int64"
)
end_token_tensor = to_variable(
np.full(batch_beam_shape, self.beam_end_token, dtype='int64')
end_token_tensor = paddle.full(
batch_beam_shape, self.beam_end_token, dtype="int64"
)
step_input = self.tar_embeder(start_token_tensor)
beam_finished = to_variable(
np.full(batch_beam_shape, 0, dtype='float32')
)
beam_state_log_probs = to_variable(
np.array(
[[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
)
beam_finished = paddle.full(batch_beam_shape, 0, dtype="float32")
beam_state_log_probs = paddle.to_tensor(
[[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
)
beam_state_log_probs = paddle.expand(
beam_state_log_probs,
Expand All @@ -395,8 +387,7 @@ def beam_search(self, inputs):

batch_pos = paddle.expand(
paddle.unsqueeze(
to_variable(np.arange(0, self.batch_size, 1, dtype="int64")),
[1],
paddle.arange(0, self.batch_size, 1, dtype="int64"), [1]
),
[-1, self.beam_size],
)
Expand Down Expand Up @@ -437,9 +428,7 @@ def beam_search(self, inputs):
)
noend_array = [-self.kinf] * self.tar_vocab_size
noend_array[self.beam_end_token] = 0
noend_mask_tensor = to_variable(
np.array(noend_array, dtype='float32')
)
noend_mask_tensor = paddle.to_tensor(noend_array, dtype="float32")

step_log_probs = paddle.multiply(
paddle.expand(
Expand Down Expand Up @@ -537,14 +526,16 @@ def __init__(
self.mode = mode
self.kinf = 1e9

param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
param_attr = paddle.ParamAttr(
initializer=uniform_initializer(self.init_scale)
)
bias_attr = paddle.ParamAttr(initializer=zero_constant)
forget_bias = 1.0

self.src_embeder = Embedding(
self.src_vocab_size,
self.hidden_size,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
name='source_embedding',
initializer=uniform_initializer(init_scale),
),
Expand All @@ -554,7 +545,7 @@ def __init__(
self.tar_vocab_size,
self.hidden_size,
sparse=False,
weight_attr=base.ParamAttr(
weight_attr=paddle.ParamAttr(
name='target_embedding',
initializer=uniform_initializer(init_scale),
),
Expand All @@ -564,7 +555,7 @@ def __init__(
for i in range(num_layers):
self.enc_units.append(
self.add_sublayer(
"enc_units_%d" % i,
f"enc_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
Expand All @@ -580,12 +571,12 @@ def __init__(
if i == 0:
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size * 2,
param_attr=ParamAttr(
name="dec_units_%d" % i,
param_attr=paddle.ParamAttr(
name=f"dec_units_{i}",
initializer=uniform_initializer(
self.init_scale
),
Expand All @@ -598,12 +589,12 @@ def __init__(
else:
self.dec_units.append(
self.add_sublayer(
"dec_units_%d" % i,
f"dec_units_{i}",
BasicLSTMUnit(
hidden_size=self.hidden_size,
input_size=self.hidden_size,
param_attr=ParamAttr(
name="dec_units_%d" % i,
param_attr=paddle.ParamAttr(
name=f"dec_units_{i}",
initializer=uniform_initializer(
self.init_scale
),
Expand Down Expand Up @@ -726,12 +717,12 @@ def forward(self, inputs):

# NOTE: modify model code about `enc_hidden` and `enc_cell` to transform dygraph code successfully.
# Because nested list can't be transformed now.
enc_hidden_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_hidden_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_hidden_0.stop_gradient = True
enc_cell_0 = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
enc_cell_0 = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
enc_hidden_0.stop_gradient = True
zero = paddle.zeros(shape=[1], dtype="int64")
Expand Down Expand Up @@ -789,8 +780,8 @@ def forward(self, inputs):
enc_outputs = self._transpose_batch_time(enc_outputs)

# train
input_feed = to_variable(
np.zeros((self.batch_size, self.hidden_size), dtype='float32')
input_feed = paddle.zeros(
shape=[self.batch_size, self.hidden_size], dtype='float32'
)
# NOTE: set stop_gradient here, otherwise grad var is null
input_feed.stop_gradient = True
Expand Down Expand Up @@ -828,8 +819,8 @@ def forward(self, inputs):

dec_output = paddle.stack(dec_output)
dec_output = self.fc(self._transpose_batch_time(dec_output))
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
loss = paddle.nn.functional.cross_entropy(
input=dec_output, label=label, soft_label=False, reduction="none"
)
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = paddle.shape(tar)[1]
Expand Down
21 changes: 11 additions & 10 deletions test/dygraph_to_static/simnet_dygraph_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -298,15 +298,16 @@ class FC(paddle.nn.Layer):

Examples:
.. code-block:: python
from paddle.base.dygraph.base import to_variable
import paddle.base as base
from paddle.base.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with base.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
conv = fc(data)

import paddle
import paddle.base as base
from paddle.base.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with base.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = paddle.to_tensor(data)
conv = fc(data)
"""

def __init__(
Expand Down
27 changes: 1 addition & 26 deletions test/dygraph_to_static/test_basic_api_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,13 @@
)

import paddle
from paddle import base, to_tensor
from paddle.base import dygraph
from paddle.base.dygraph import to_variable
from paddle import to_tensor
from paddle.jit.api import to_static

SEED = 2020
np.random.seed(SEED)


def dyfunc_to_variable(x):
res = base.dygraph.to_variable(x, name=None, zero_copy=None)
return res


def dyfunc_to_variable_2(x):
res = dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
return res


def dyfunc_to_variable_3(x):
res = to_variable(x, name=None, zero_copy=None)
return res


def dyfunc_to_tensor(x):
res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
res2 = paddle.tensor.to_tensor(data=res1)
Expand Down Expand Up @@ -73,15 +56,7 @@ def setUp(self):
dyfunc_bool_to_tensor,
dyfunc_int_to_tensor,
dyfunc_float_to_tensor,
dyfunc_to_variable,
dyfunc_to_variable_2,
dyfunc_to_variable_3,
]
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def get_dygraph_output(self):
res = self.dygraph_func(self.input).numpy()
Expand Down
15 changes: 6 additions & 9 deletions test/dygraph_to_static/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
)
from test_basic_api_transformation import dyfunc_to_variable

import paddle
from paddle.framework import use_pir_api
Expand All @@ -35,6 +33,11 @@
from paddle.static import InputSpec


def call_to_tensor(x):
res = paddle.to_tensor(x)
return res


def create_simple_net():
class SimpleNet(Layer):
def __init__(self):
Expand Down Expand Up @@ -370,7 +373,7 @@ class TestDeclarativeAPI(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt_and_pir
def test_error(self):
func = paddle.jit.to_static(dyfunc_to_variable)
func = paddle.jit.to_static(call_to_tensor)

paddle.enable_static()

Expand All @@ -379,12 +382,6 @@ def test_error(self):
with self.assertRaises(RuntimeError):
func(np.ones(5).astype("int32"))

with enable_to_static_guard(False):
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use base.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))

paddle.disable_static()


Expand Down
Loading