Skip to content

Commit

Permalink
[Cleanup][A-5] clean some VarType for test (#61549)
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc authored Feb 4, 2024
1 parent cd99902 commit ca68a91
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 112 deletions.
2 changes: 1 addition & 1 deletion test/collective/fleet/test_fleet_amp_meta_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_pure_fp16_optimizer(self):

params = train_prog.all_parameters()
for param in train_prog.all_parameters():
self.assertEqual(param.dtype, base.core.VarDesc.VarType.FP16)
self.assertEqual(param.dtype, paddle.float16)

ops = [op.type for op in avg_cost.block.ops]
self.assertIn('cast', ops)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_gm_pure_fp16_optimizer(self):

params = train_prog.all_parameters()
for param in train_prog.all_parameters():
self.assertEqual(param.dtype, paddle.base.core.VarDesc.VarType.FP16)
self.assertEqual(param.dtype, paddle.float16)

vars = [x.name for x in train_prog.list_vars()]
self.assertIn('@GradientMerge', ''.join(vars))
Expand Down
12 changes: 6 additions & 6 deletions test/legacy_test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@
import unittest

import paddle
from paddle.base import Program, core, program_guard
from paddle.base import Program, program_guard


class TestApiStaticDataError(unittest.TestCase):
def test_dtype(self):
with program_guard(Program(), Program()):
x1 = paddle.static.data(name="x1", shape=[2, 25])
self.assertEqual(x1.dtype, core.VarDesc.VarType.FP32)
self.assertEqual(x1.dtype, paddle.float32)

x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="bool")
self.assertEqual(x2.dtype, core.VarDesc.VarType.BOOL)
self.assertEqual(x2.dtype, paddle.bool)

paddle.set_default_dtype("float64")
x3 = paddle.static.data(name="x3", shape=[2, 25])
self.assertEqual(x3.dtype, core.VarDesc.VarType.FP64)
self.assertEqual(x3.dtype, paddle.float64)

def test_0D(self):
with program_guard(Program(), Program()):
x1 = paddle.static.data(name="x1_0D", shape=[])
self.assertEqual(x1.dtype, core.VarDesc.VarType.FP32)
self.assertEqual(x1.dtype, paddle.float32)
x2 = paddle.static.data(name="x2_0D", shape=(), dtype="bool")
self.assertEqual(x2.dtype, core.VarDesc.VarType.BOOL)
self.assertEqual(x2.dtype, paddle.bool)

def test_error(self):
with program_guard(Program(), Program()):
Expand Down
Loading

0 comments on commit ca68a91

Please sign in to comment.