Fix VarType
co63oc committed Feb 4, 2024
1 parent 062b99e commit 81b6461
Showing 3 changed files with 21 additions and 20 deletions.
python/paddle/tensor/attribute.py (11 additions, 11 deletions)
@@ -169,8 +169,8 @@ def is_complex(x):
         raise TypeError(f"Expected Tensor, but received type of x: {type(x)}")
     dtype = x.dtype
     is_complex_dtype = (
-        dtype == core.VarDesc.VarType.COMPLEX64
-        or dtype == core.VarDesc.VarType.COMPLEX128
+        dtype == paddle.complex64
+        or dtype == paddle.complex128
         or dtype == core.DataType.COMPLEX64
         or dtype == core.DataType.COMPLEX128
     )
@@ -203,10 +203,10 @@ def is_floating_point(x):
         raise TypeError(f"Expected Tensor, but received type of x: {type(x)}")
     dtype = x.dtype
     is_fp_dtype = (
-        dtype == core.VarDesc.VarType.FP32
-        or dtype == core.VarDesc.VarType.FP64
-        or dtype == core.VarDesc.VarType.FP16
-        or dtype == core.VarDesc.VarType.BF16
+        dtype == paddle.float32
+        or dtype == paddle.float64
+        or dtype == paddle.float16
+        or dtype == paddle.bfloat16
     )
     return is_fp_dtype

@@ -246,11 +246,11 @@ def is_integer(x):
     is_int_dtype = False
     if not in_pir_mode():
         is_int_dtype = (
-            dtype == core.VarDesc.VarType.UINT8
-            or dtype == core.VarDesc.VarType.INT8
-            or dtype == core.VarDesc.VarType.INT16
-            or dtype == core.VarDesc.VarType.INT32
-            or dtype == core.VarDesc.VarType.INT64
+            dtype == paddle.uint8
+            or dtype == paddle.int8
+            or dtype == paddle.int16
+            or dtype == paddle.int32
+            or dtype == paddle.int64
         )
     else:
         is_int_dtype = (
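The diff above keeps the extra core.DataType comparisons in is_complex and the in_pir_mode() branch in is_integer, so the change is limited to replacing the internal core.VarDesc.VarType members with the public paddle.* dtype aliases on the legacy path. A minimal usage sketch of the public predicates whose bodies are shown above (the tensor values and dtypes are illustrative, not taken from the patch):

    import paddle

    x = paddle.to_tensor([1 + 2j], dtype='complex64')
    print(paddle.is_complex(x))          # True
    print(x.dtype == paddle.complex64)   # True: the comparison form the patch switches to

    y = paddle.ones([2, 2], dtype='float16')
    print(paddle.is_floating_point(y))   # True

    z = paddle.arange(4, dtype='int32')
    print(paddle.is_integer(z))          # True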
python/paddle/tensor/creation.py (4 additions, 4 deletions)
@@ -49,18 +49,18 @@


 def _complex_to_real_dtype(dtype):
-    if dtype == core.VarDesc.VarType.COMPLEX64:
+    if dtype == paddle.complex64:
         return core.VarDesc.VarType.FP32
-    elif dtype == core.VarDesc.VarType.COMPLEX128:
+    elif dtype == paddle.complex128:
         return core.VarDesc.VarType.FP64
     else:
         return dtype


 def _real_to_complex_dtype(dtype):
-    if dtype == core.VarDesc.VarType.FP32:
+    if dtype == paddle.float32:
         return core.VarDesc.VarType.COMPLEX64
-    elif dtype == core.VarDesc.VarType.FP64:
+    elif dtype == paddle.float64:
         return core.VarDesc.VarType.COMPLEX128
     else:
         return dtype
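Note that only the comparison side of each branch moves to the public aliases here; the values returned are still the internal core.VarDesc.VarType members. A rough stand-alone equivalent written entirely with public aliases (an illustrative sketch, not the helpers Paddle actually ships) would look like:

    import paddle

    def complex_to_real_dtype(dtype):
        # Map a complex dtype to the real dtype of matching width;
        # any other dtype passes through unchanged.
        if dtype == paddle.complex64:
            return paddle.float32
        if dtype == paddle.complex128:
            return paddle.float64
        return dtype

    def real_to_complex_dtype(dtype):
        # Inverse mapping: widen a real dtype to its complex counterpart.
        if dtype == paddle.float32:
            return paddle.complex64
        if dtype == paddle.float64:
            return paddle.complex128
        return dtype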
python/paddle/tensor/to_string.py (6 additions, 5 deletions)
@@ -14,6 +14,7 @@

 import numpy as np

+import paddle
 from paddle.base.data_feeder import check_type, convert_dtype

 from ..framework import core
@@ -238,7 +239,7 @@ def to_string(var, prefix='Tensor'):
     indent = len(prefix) + 1

     dtype = convert_dtype(var.dtype)
-    if var.dtype == core.VarDesc.VarType.BF16:
+    if var.dtype == paddle.bfloat16:
         dtype = 'bfloat16'

     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
@@ -247,7 +248,7 @@ def to_string(var, prefix='Tensor'):
     if not tensor._is_initialized():
         return "Tensor(Not initialized)"

-    if var.dtype == core.VarDesc.VarType.BF16:
+    if var.dtype == paddle.bfloat16:
         var = var.astype('float32')
     np_var = var.numpy(False)

@@ -280,7 +281,7 @@ def to_string(var, prefix='Tensor'):


 def _format_dense_tensor(tensor, indent):
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         tensor = tensor.astype('float32')

     # TODO(zhouwei): will remove 0-D Tensor.numpy() hack
@@ -360,7 +361,7 @@ def dist_tensor_to_string(tensor, prefix='Tensor'):
     # is ready.
     indent = len(prefix) + 1
     dtype = convert_dtype(tensor.dtype)
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         dtype = 'bfloat16'

     if not tensor._is_dense_tensor_hold_allocation():
@@ -395,7 +396,7 @@ def tensor_to_string(tensor, prefix='Tensor'):
     indent = len(prefix) + 1

     dtype = convert_dtype(tensor.dtype)
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         dtype = 'bfloat16'

     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
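All of the to_string.py changes touch the same pattern: the printed dtype label is forced to 'bfloat16', and before values are read out a bf16 tensor is cast to float32 (NumPy has no native bfloat16, so .numpy() needs a representable type). A quick way to exercise those branches, assuming bfloat16 tensor creation is available on the current device:

    import paddle

    # Illustrative only: printing a bfloat16 tensor goes through the branches
    # changed above, so the repr should report dtype=bfloat16 while the values
    # are materialized from a float32 copy.
    x = paddle.ones([2, 3], dtype='bfloat16')
    print(x)
    print(x.dtype == paddle.bfloat16)   # True: the comparison the patch now uses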
